add split dataset by shard option to accelerate data loading

LiangSong 2023-05-04 09:20:23 +08:00
parent f0d41f937b
commit 51686b5fb8
2 changed files with 17 additions and 4 deletions

dataset/dataset.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-04 08:42:58
+LastEditTime: 2023-05-04 09:17:21
 FilePath: /Open-Llama/dataset/dataset.py
 Description:
@@ -16,6 +16,8 @@ from datasets import load_dataset
 
+random.seed(42)
+
 def pretrain_transform(batch):
     # wudao preprocess
     if "title" in batch and "content" in batch:
@@ -153,12 +155,18 @@ def get_labels_gen(pad_token_id):
     return get_labels
 
 
-def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
+def construct_dataset(
+    dataset_config, tokenizer, return_raw_text=False, world_size=None
+):
     all_data_files = []
     for name, pattern in dataset_config["data"].items():
         data_files = glob(pattern)
         assert len(data_files) > 0
         all_data_files.extend(data_files)
+    random.shuffle(all_data_files)
+    if world_size is not None:
+        num_shards = len(all_data_files)
+        all_data_files = all_data_files[: num_shards // world_size * world_size]
     dataset = load_dataset(
         "json", data_files=all_data_files, split="train", streaming=True
     )
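
The slice trims the shuffled file list to the largest multiple of world_size, so the shards can later be handed out as whole files, one even share per rank, at the cost of dropping up to world_size - 1 files. A minimal sketch of the arithmetic, with hypothetical shard names:

world_size = 4  # assumed number of data-loading ranks
all_data_files = [f"part-{i:05d}.jsonl" for i in range(10)]  # 10 hypothetical shards
num_shards = len(all_data_files)
all_data_files = all_data_files[: num_shards // world_size * world_size]
assert len(all_data_files) == 8  # largest multiple of 4; the last 2 shards are dropped
assert len(all_data_files) % world_size == 0  # 2 whole files per rank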

train_lm.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-02 18:26:50
+LastEditTime: 2023-05-04 09:19:15
 FilePath: /Open-Llama/train_lm.py
 Description:
@@ -41,7 +41,12 @@ def main(argv):
         add_eos_token=True,
     )
     data_config = config["data"]
-    train_dataset = construct_dataset(data_config, tokenizer)
+    if data_config.get("split_by_shard", False):
+        train_dataset = construct_dataset(
+            data_config, tokenizer, world_size=accelerator.num_processes
+        )
+    else:
+        train_dataset = construct_dataset(data_config, tokenizer)
     train_dataset = split_dataset_by_node(
         train_dataset,
         rank=accelerator.process_index,
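
The speedup comes from how split_dataset_by_node treats streaming datasets: per the Hugging Face datasets documentation, when dataset.n_shards is a multiple of world_size each node is assigned a subset of whole shards, whereas otherwise every node streams the full dataset and keeps only one example out of world_size. A minimal sketch of the fast path this option enables, assuming hypothetical shard files:

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

world_size, rank = 4, 0
files = [f"data/part-{i:04d}.jsonl" for i in range(8)]  # 8 hypothetical shards
ds = load_dataset("json", data_files=files, split="train", streaming=True)

assert ds.n_shards % world_size == 0  # 8 shards split evenly across 4 ranks
# Each rank reads only its own 2 files instead of streaming all 8 and
# keeping every 4th example.
ds_rank = split_dataset_by_node(ds, rank=rank, world_size=world_size)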