add split dataset by shard option to accelerate data loading
This commit is contained in:
parent f0d41f937b
commit 51686b5fb8
dataset/dataset.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-04 08:42:58
+LastEditTime: 2023-05-04 09:17:21
 FilePath: /Open-Llama/dataset/dataset.py
 Description:

@@ -16,6 +16,8 @@ from datasets import load_dataset

 random.seed(42)


 def pretrain_transform(batch):
     # wudao preprocess
     if "title" in batch and "content" in batch:
|
@ -153,12 +155,18 @@ def get_labels_gen(pad_token_id):
|
||||||
return get_labels
|
return get_labels
|
||||||
|
|
||||||
|
|
||||||
def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
|
def construct_dataset(
|
||||||
|
dataset_config, tokenizer, return_raw_text=False, world_size=None
|
||||||
|
):
|
||||||
all_data_files = []
|
all_data_files = []
|
||||||
for name, pattern in dataset_config["data"].items():
|
for name, pattern in dataset_config["data"].items():
|
||||||
data_files = glob(pattern)
|
data_files = glob(pattern)
|
||||||
assert len(data_files) > 0
|
assert len(data_files) > 0
|
||||||
all_data_files.extend(data_files)
|
all_data_files.extend(data_files)
|
||||||
|
random.shuffle(all_data_files)
|
||||||
|
if world_size is not None:
|
||||||
|
num_shards = len(all_data_files)
|
||||||
|
all_data_files = all_data_files[num_shards // world_size * world_size]
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
"json", data_files=all_data_files, split="train", streaming=True
|
"json", data_files=all_data_files, split="train", streaming=True
|
||||||
)
|
)
|
||||||

train_lm.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-02 18:26:50
+LastEditTime: 2023-05-04 09:19:15
 FilePath: /Open-Llama/train_lm.py
 Description:

@@ -41,7 +41,12 @@ def main(argv):
         add_eos_token=True,
     )
     data_config = config["data"]
-    train_dataset = construct_dataset(data_config, tokenizer)
+    if data_config.get("split_by_shard", False):
+        train_dataset = construct_dataset(
+            data_config, tokenizer, world_size=accelerator.num_processes
+        )
+    else:
+        train_dataset = construct_dataset(data_config, tokenizer)
     train_dataset = split_dataset_by_node(
         train_dataset,
         rank=accelerator.process_index,
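
In train_lm.py the behavior is opt-in through a split_by_shard flag read from the data section of the config (defaulting to False). When it is set, construct_dataset receives accelerator.num_processes, so the existing split_dataset_by_node call can hand each rank whole files instead of interleaving examples. A self-contained sketch of the two splitting modes, using small temporary files rather than the repo's real corpus:

# A sketch of shard-level vs example-level splitting in the datasets library.
# The temporary files and world_size below are hypothetical, not the repo's data.
import json
import os
import tempfile

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

world_size = 4
tmp_dir = tempfile.mkdtemp()
data_files = []
for shard_id in range(8):  # 8 shard files, divisible by world_size
    path = os.path.join(tmp_dir, f"part-{shard_id:04d}.jsonl")
    with open(path, "w") as f:
        for i in range(10):
            f.write(json.dumps({"text": f"shard {shard_id} doc {i}"}) + "\n")
    data_files.append(path)

dataset = load_dataset("json", data_files=data_files, split="train", streaming=True)

# 8 % 4 == 0, so rank 0 is assigned 2 whole files and never reads the other 6.
rank0 = split_dataset_by_node(dataset, rank=0, world_size=world_size)
print(rank0.n_shards)               # 2
print(next(iter(rank0))["text"])    # first document from rank 0's shards

# If len(data_files) were not a multiple of world_size, split_dataset_by_node
# would fall back to example-level splitting: every rank reads every file and
# keeps only one example out of world_size.

With example-level splitting every rank still streams and parses every file, keeping only one example out of world_size, so the shard-level path is what actually accelerates data loading.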