add split dataset by shard option to accelerate data loading

LiangSong 2023-05-04 09:20:23 +08:00
parent f0d41f937b
commit 51686b5fb8
2 changed files with 17 additions and 4 deletions
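Context for the change: `datasets.distributed.split_dataset_by_node` splits a streaming dataset at shard granularity only when the number of shards is divisible by `world_size`; otherwise every rank streams the whole dataset and keeps just one example out of every `world_size`, which is much slower. A minimal sketch of the fast path, with hypothetical shard files:

```python
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Hypothetical shard files; with streaming, each JSON file is one shard.
shard_files = ["part-%02d.jsonl" % i for i in range(8)]
dataset = load_dataset("json", data_files=shard_files, split="train", streaming=True)

# 8 shards and 4 ranks: 8 % 4 == 0, so each rank reads 2 whole shards
# instead of streaming all 8 files and discarding 3 of every 4 examples.
rank0 = split_dataset_by_node(dataset, rank=0, world_size=4)
```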

dataset/dataset.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-04 08:42:58
+LastEditTime: 2023-05-04 09:17:21
 FilePath: /Open-Llama/dataset/dataset.py
 Description:
@@ -16,6 +16,8 @@ from datasets import load_dataset
 random.seed(42)
 
 
 def pretrain_transform(batch):
     # wudao preprocess
     if "title" in batch and "content" in batch:
@@ -153,12 +155,18 @@ def get_labels_gen(pad_token_id):
     return get_labels
 
 
-def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
+def construct_dataset(
+    dataset_config, tokenizer, return_raw_text=False, world_size=None
+):
     all_data_files = []
     for name, pattern in dataset_config["data"].items():
         data_files = glob(pattern)
         assert len(data_files) > 0
         all_data_files.extend(data_files)
+    random.shuffle(all_data_files)
+    if world_size is not None:
+        num_shards = len(all_data_files)
+        all_data_files = all_data_files[: num_shards // world_size * world_size]
     dataset = load_dataset(
         "json", data_files=all_data_files, split="train", streaming=True
     )
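A standalone sketch of the shard-truncation logic added above (file names hypothetical). Slicing to `num_shards // world_size * world_size` keeps the largest multiple of `world_size`, and the preceding shuffle randomizes which shards fall in the dropped tail:

```python
import random

random.seed(42)
# Hypothetical shard list; in the diff it comes from glob() over the config patterns.
all_data_files = ["part-%02d.jsonl.zst" % i for i in range(10)]
world_size = 4

random.shuffle(all_data_files)
num_shards = len(all_data_files)
# 10 // 4 * 4 == 8: keep 8 shards and drop 2, so each of the 4 ranks gets exactly 2.
all_data_files = all_data_files[: num_shards // world_size * world_size]
assert len(all_data_files) % world_size == 0
```

The trade-off is that up to `world_size - 1` shards are dropped from each epoch; the shuffle spreads that loss across the corpus instead of always cutting the same files.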

train_lm.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-02 18:26:50
+LastEditTime: 2023-05-04 09:19:15
 FilePath: /Open-Llama/train_lm.py
 Description:
@@ -41,6 +41,11 @@ def main(argv):
         add_eos_token=True,
     )
     data_config = config["data"]
-    train_dataset = construct_dataset(data_config, tokenizer)
+    if data_config.get("split_by_shard", False):
+        train_dataset = construct_dataset(
+            data_config, tokenizer, world_size=accelerator.num_processes
+        )
+    else:
+        train_dataset = construct_dataset(data_config, tokenizer)
     train_dataset = split_dataset_by_node(
         train_dataset,
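A sketch of how the new option might be toggled. Only the `split_by_shard` key is introduced by this commit; the surrounding config shape and the glob pattern are illustrative, and the real project loads its config from a file:

```python
config = {
    "data": {
        # name -> glob pattern, as iterated by construct_dataset (pattern hypothetical)
        "data": {"wudao": "data/pretrain_data/part-*.jsonl.zst"},
        # New flag: pass accelerator.num_processes into construct_dataset so the
        # shard count is truncated to a multiple of the number of ranks.
        "split_by_shard": True,
    }
}
data_config = config["data"]
```

In both branches the dataset still passes through `split_dataset_by_node` afterward; the flag only guarantees that the shard count divides the world size, so the split happens per shard instead of per example.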