"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description: Iterators over sharded jsonl / jsonl.zst pretraining data.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
from glob import glob

import zstandard as zstd


def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
"""
Currently, the allowed storage formats are jsonl and jsonl.zst.
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
"""
    past = None
    for i, path in paths:
        # The dataset name is encoded in the file name,
        # e.g. "part-wudao-000.jsonl.zst" -> "wudao".
        dataset_name = path.split("-")[-2]
        if past != dataset_name:
            print("Loading data from {}".format(path))
            past = dataset_name
        # In a multi-process run, each process only reads its own shards.
        if num_processes > 1 and i % num_processes != process_index:
            continue
        if path.endswith("jsonl.zst"):
            fp = zstd.open(path, "r", encoding="utf-8")
        elif path.endswith("jsonl"):
            fp = open(path, "r", encoding="utf-8")
        else:
            raise Exception("File format of {} is not supported yet.".format(path))
        with fp:
            for raw_line in fp:
                if isinstance(raw_line, bytes):
                    raw_line = raw_line.decode("utf-8")
                record = json.loads(raw_line)
                record["dataset"] = dataset_name
                if transform_dict:
                    record = transform_dict[dataset_name](record)
                    # A transform may return a single string or a list of samples.
                    if isinstance(record, str):
                        yield record
                    elif isinstance(record, list):
                        for sample in record:
                            yield sample
                    else:
                        raise Exception(
                            "Unsupported type in Transformation: {}".format(
                                transform_dict[dataset_name]
                            )
                        )
                else:
                    yield record

def create_shard_kwargs(patterns, repeat=1):
    """
    Expand glob patterns into a list of shard files and assign an index to
    each shard, so that shards can be partitioned across nodes without
    duplication during distributed training. Setting repeat > 1 iterates
    over the file list multiple times (e.g. for multi-epoch training).
    """
    all_path = []
    for p in patterns:
        all_path.extend(glob(p))
    # Repeat the shard list to make multiple passes over the corpus.
    all_path *= repeat
    return [(i, p) for i, p in enumerate(all_path)]
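
# For example, with two matching files on disk (hypothetical file names;
# glob order is OS-dependent):
#
#   create_shard_kwargs(["data/part-demo-*.jsonl"], repeat=2)
#   -> [(0, "data/part-demo-0.jsonl"), (1, "data/part-demo-1.jsonl"),
#       (2, "data/part-demo-0.jsonl"), (3, "data/part-demo-1.jsonl")]
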

if __name__ == "__main__":
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
data_iter = create_data_iter(paths, transform_dict=transform_dict)
for i, data in enumerate(data_iter):
print(i, data)
if i == 20:
break