add split dataset by shard option to accelerate data loading

commit 51686b5fb8
parent f0d41f937b
dataset/dataset.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-04 08:42:58
+LastEditTime: 2023-05-04 09:17:21
 FilePath: /Open-Llama/dataset/dataset.py
 Description:
@@ -16,6 +16,8 @@ from datasets import load_dataset

 random.seed(42)


 def pretrain_transform(batch):
     # wudao preprocess
     if "title" in batch and "content" in batch:
@@ -153,12 +155,18 @@ def get_labels_gen(pad_token_id):
     return get_labels


-def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
+def construct_dataset(
+    dataset_config, tokenizer, return_raw_text=False, world_size=None
+):
     all_data_files = []
     for name, pattern in dataset_config["data"].items():
         data_files = glob(pattern)
         assert len(data_files) > 0
         all_data_files.extend(data_files)
     random.shuffle(all_data_files)
+    if world_size is not None:
+        num_shards = len(all_data_files)
+        all_data_files = all_data_files[: num_shards // world_size * world_size]
     dataset = load_dataset(
         "json", data_files=all_data_files, split="train", streaming=True
     )
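The new world_size argument truncates the shuffled shard list to the largest multiple of world_size before it is handed to load_dataset. The snippet below is a minimal standalone sketch of that arithmetic, not code from the repo; the helper name and file names are made up for illustration, and only the len(files) // world_size * world_size truncation mirrors the change above.

import random

def trim_shards_to_world_size(data_files, world_size):
    # Hypothetical helper (not part of the repo): keep only as many shard
    # files as divide evenly across world_size ranks, as construct_dataset
    # now does when world_size is passed.
    data_files = list(data_files)
    random.shuffle(data_files)
    usable = len(data_files) // world_size * world_size
    return data_files[:usable]

# e.g. 10 shard files across 4 processes -> 8 files kept, 2 whole shards per rank
files = [f"part-{i:05d}.jsonl" for i in range(10)]
print(len(trim_shards_to_world_size(files, world_size=4)))  # 8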

train_lm.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-02 18:26:50
+LastEditTime: 2023-05-04 09:19:15
 FilePath: /Open-Llama/train_lm.py
 Description:
@@ -41,7 +41,12 @@ def main(argv):
         add_eos_token=True,
     )
     data_config = config["data"]
-    train_dataset = construct_dataset(data_config, tokenizer)
+    if data_config.get("split_by_shard", False):
+        train_dataset = construct_dataset(
+            data_config, tokenizer, world_size=accelerator.num_processes
+        )
+    else:
+        train_dataset = construct_dataset(data_config, tokenizer)
     train_dataset = split_dataset_by_node(
         train_dataset,
         rank=accelerator.process_index,
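With split_by_shard enabled in the data config, the trimmed shard list lets split_dataset_by_node assign each rank whole shards, so every process only opens its own files instead of streaming all shards and keeping one example in num_processes. The following is a self-contained sketch of that behavior, not the repo's training code; the temporary shard files and counts are invented for illustration, while streaming load_dataset and split_dataset_by_node are the same datasets-library calls used above.

import json, os, tempfile
from glob import glob

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Write 8 tiny JSON-lines shards; 8 is divisible by the pretend world_size of 4,
# which is the condition the trimming in construct_dataset guarantees.
tmp = tempfile.mkdtemp()
for i in range(8):
    with open(os.path.join(tmp, f"part-{i}.jsonl"), "w") as f:
        for j in range(4):
            f.write(json.dumps({"text": f"shard {i} example {j}"}) + "\n")

files = sorted(glob(os.path.join(tmp, "*.jsonl")))
stream = load_dataset("json", data_files=files, split="train", streaming=True)

# Pretend we are rank 0 of 4: this rank is assigned 2 whole shards and only
# ever reads those 2 files while iterating.
rank0 = split_dataset_by_node(stream, rank=0, world_size=4)
print(rank0.n_shards)             # 2
print(next(iter(rank0))["text"])  # an example from one of this rank's shards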