Open-Llama/dataset/data_iter.py

"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:37:55
FilePath: /Open-Llama/dataset/data_iter.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
from glob import glob
import zstandard as zstd
from torch.utils.data import IterableDataset


class DataIter(IterableDataset):
"""
Currently, the allowed storage formats are jsonl.zst.
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
Currently, only single worker is supported.
"""

    def __init__(
        self,
        paths_with_index,
        transform_dict=None,
        max_length=None,
        concat_docs=False,
        process_index=0,
        num_processes=1,
    ):
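        # Hedged parameter notes, inferred from the code below rather than any
        # original documentation:
        #   paths_with_index: list of (index, path) tuples, e.g. from create_shard_kwargs.
        #   transform_dict: maps a dataset name to a per-line transform (sampling, tokenization, ...).
        #   max_length: chunk length used when concat_docs is enabled.
        #   concat_docs: if True, concatenate sequences across documents into fixed-length chunks.
        #   process_index / num_processes: shard the files across distributed processes.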
        super().__init__()
        self.paths_with_index = paths_with_index
        self.max_length = max_length
        self.transform_dict = transform_dict
        self.concat_docs = concat_docs
        self.process_index = process_index
        self.num_processes = num_processes
        if self.concat_docs:
            self.cache = []

    def __iter__(self):
        past = None
        for i, path in self.paths_with_index:
            # File names look like part-<dataset_name>-01.jsonl.zst.
            dataset_name = path.split("-")[-2]
            # Shard the files across multiple devices.
            if self.num_processes > 1 and i % self.num_processes != self.process_index:
                continue
            # Log the file name when encountering a new file.
            if past != path:
                print("Loading data from {}".format(path))
                past = path
            # Currently, the only supported storage format is jsonl.zst.
            assert path.endswith("jsonl.zst")
            with zstd.open(path, "r", encoding="utf-8") as fp:
                for line in fp:
                    # Once the cache holds at least max_length tokens, yield one
                    # fixed-length chunk and keep the remainder.
                    if self.concat_docs and len(self.cache) >= self.max_length:
                        seq = self.cache[: self.max_length]
                        self.cache = self.cache[self.max_length :]
                        yield seq
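                    # Depending on how zstd.open treats the "r" mode, lines may
                    # arrive as bytes or str; normalize to str before JSON parsing.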
                    if isinstance(line, bytes):
                        line = line.decode("utf-8")
                    line = json.loads(line)
                    line["dataset"] = dataset_name
                    # Transformation, including sampling, tokenization, etc.
                    if self.transform_dict:
                        line = self.transform_dict[dataset_name](line)
                        # Skip bad documents.
                        if line is None:
                            continue
                        elif isinstance(line, str):
                            yield line
                        # Otherwise the transform must return a list of lists.
                        elif isinstance(line, list) and line and isinstance(line[0], list):
                            for seq in line:
                                if self.concat_docs:
                                    # Concatenate sequences from multiple documents.
                                    self.cache += seq
                                else:
                                    yield seq
                        else:
                            raise Exception(
                                "Unsupported type in Transformation: {}".format(
                                    self.transform_dict[dataset_name]
                                )
                            )
                    else:
                        yield line
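        # Note: when concat_docs is enabled, any leftover tokens shorter than
        # max_length remain in self.cache when iteration ends; they are not
        # yielded here and carry over into a subsequent pass over the dataset.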


def create_shard_kwargs(patterns, repeat=1):
    """
    Assign numbers to different shards of data to ensure that data is not duplicated
    when allocated to different nodes during distributed training.
    """
    all_path = []
    for p in patterns:
        all_path.extend(glob(p))
    all_path *= repeat
    return [(i, p) for i, p in enumerate(all_path)]
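

# Illustrative output of create_shard_kwargs, assuming two matching files
# (hypothetical names) and repeat=2:
#   create_shard_kwargs(["part-wudao-*.jsonl.zst"], repeat=2)
#   -> [(0, "part-wudao-01.jsonl.zst"), (1, "part-wudao-02.jsonl.zst"),
#       (2, "part-wudao-01.jsonl.zst"), (3, "part-wudao-02.jsonl.zst")]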


if __name__ == "__main__":
    patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
    data_iter = DataIter(
        paths, transform_dict=transform_dict, max_length=16, concat_docs=True
    )
    for i, data in enumerate(data_iter):
        print(i, data)
        if i == 20:
            break
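
    # Hedged sketch of multi-process sharding (the `accelerator` object and its
    # attributes are assumed from HuggingFace Accelerate, not part of this file):
    #   accelerator = accelerate.Accelerator()
    #   data_iter = DataIter(
    #       paths,
    #       transform_dict=transform_dict,
    #       process_index=accelerator.process_index,
    #       num_processes=accelerator.num_processes,
    #   )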