"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description:
Data iterator utilities for streaming jsonl / jsonl.zst pretraining shards.

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
|
|
import json
|
|
from glob import glob
|
|
import zstandard as zstd
|
|
|
|
|
|
def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
    """
    Iterate over records stored in sharded jsonl / jsonl.zst files.

    Each line of a shard is a JSON object; after parsing, a "dataset" key is
    added so an optional per-dataset transform can be looked up and applied.

    Args:
        paths: list of (index, path) tuples, e.g. from create_shard_kwargs.
        transform_dict: optional mapping of dataset name -> callable applied
            to each parsed record. The callable may return a str (yielded
            as-is) or a list (yielded element by element).
        process_index: rank of this worker; shards whose index satisfies
            index % num_processes != process_index are skipped so that
            distributed workers read disjoint shards.
        num_processes: total number of workers.

    Yields:
        dict records tagged with "dataset" (no transform_dict), or whatever
        the matching transform returns (str, or list elements one by one).

    Raises:
        Exception: for an unsupported file extension, or when a transform
            returns something other than str or list.
    """
    past = None
    for i, path in paths:
        # Skip shards that belong to other ranks before doing any work on
        # them (the original printed a "Loading" message even for skipped
        # shards, which was misleading).
        if num_processes > 1 and i % num_processes != process_index:
            continue
        # Dataset name is encoded in the filename,
        # e.g. "part-wudao-000.jsonl" -> "wudao".
        dataset_name = path.split("-")[-2]
        if past != dataset_name:
            print("Loading data from {}".format(path))
            # Bug fix: remember the dataset name, not the path. The original
            # stored the path, so the comparison above was always true and
            # the message printed for every shard instead of once per dataset.
            past = dataset_name
        if path.endswith("jsonl.zst"):
            with zstd.open(path, "r", encoding="utf-8") as fp:
                yield from _iter_records(fp, dataset_name, transform_dict)
        elif path.endswith("jsonl"):
            with open(path, "r") as fp:
                yield from _iter_records(fp, dataset_name, transform_dict)
        else:
            raise Exception("File format of {} is not supported yet.".format(path))


def _iter_records(fp, dataset_name, transform_dict):
    """Yield (optionally transformed) records from one open jsonl file object.

    Shared by the jsonl and jsonl.zst branches of create_data_iter, which
    previously duplicated this logic line for line.
    """
    for line in fp:
        # zstd streams may yield bytes; normalize to str before JSON parsing.
        if isinstance(line, bytes):
            line = line.decode("utf-8")
        record = json.loads(line)
        record["dataset"] = dataset_name
        if transform_dict:
            record = transform_dict[dataset_name](record)
            if isinstance(record, str):
                yield record
            elif isinstance(record, list):
                # A transform may expand one raw record into several samples.
                yield from record
            else:
                raise Exception(
                    "Unsupported type in Transformation: {}".format(
                        transform_dict[dataset_name]
                    )
                )
        else:
            yield record
|
|
|
|
|
|
def create_shard_kwargs(patterns, repeat=1):
    """
    Expand glob patterns into an indexed list of shard paths.

    The index attached to each path lets distributed workers claim disjoint
    shards, so the same data is not duplicated across nodes.

    Args:
        patterns: iterable of glob patterns.
        repeat: how many times the matched path list is repeated.

    Returns:
        list of (index, path) tuples.
    """
    matched = []
    for pattern in patterns:
        matched.extend(glob(pattern))
    return list(enumerate(matched * repeat))
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: stream the first 21 samples from the wudao shards and
    # print them for manual inspection.
    shard_patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
    shard_paths = create_shard_kwargs(shard_patterns)
    sample_transforms = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
    stream = create_data_iter(shard_paths, transform_dict=sample_transforms)
    for sample_index, sample in enumerate(stream):
        print(sample_index, sample)
        if sample_index == 20:
            break
|