'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob

import zstandard as zstd

def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
    '''
    Iterate over sharded pre-training data.

    Currently the supported storage formats are jsonl and jsonl.zst. Each line of a
    file is a JSON dictionary; it is parsed after reading and tagged with the name
    of the dataset it came from.

    paths: list of (shard_index, file_path) pairs, e.g. from create_shard_kwargs.
    transform_dict: optional mapping from dataset name to a transform function that
        takes the parsed dict and returns either a str or a list of str.
    process_index / num_processes: split shards across processes so that each shard
        is read by exactly one process.
    '''
    past = None
    for i, path in paths:
        # The dataset name is encoded in the file name,
        # e.g. 'part-wudao-0.jsonl.zst' -> 'wudao'.
        dataset_name = path.split('-')[-2]
        if past != dataset_name:
            print('Loading data from {}'.format(path))
            past = dataset_name
        # Skip shards that belong to other processes.
        if num_processes > 1 and i % num_processes != process_index:
            continue
        # Plain and zstd-compressed jsonl are both read line by line.
        if path.endswith('jsonl.zst'):
            fp = zstd.open(path, 'r', encoding='utf-8')
        elif path.endswith('jsonl'):
            fp = open(path, 'r', encoding='utf-8')
        else:
            raise Exception('File format of {} is not supported yet.'.format(path))
        with fp:
            for line in fp:
                if isinstance(line, bytes):
                    line = line.decode('utf-8')
                line = json.loads(line)
                line['dataset'] = dataset_name
                if transform_dict:
                    line = transform_dict[dataset_name](line)
                    if isinstance(line, str):
                        yield line
                    elif isinstance(line, list):
                        for elem in line:
                            yield elem
                    else:
                        raise Exception(
                            'Unsupported type in Transformation: {}'.format(
                                transform_dict[dataset_name]
                            )
                        )
                else:
                    yield line
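

# Illustrative sketch (not part of the original module): the contract that
# create_data_iter expects. Each line of a *.jsonl / *.jsonl.zst shard is one JSON
# object, and a transform_dict entry maps the parsed dict (plus the injected
# 'dataset' key) to either a single string or a list of strings, mirroring the
# example in __main__ below. The concrete field values are hypothetical.
_EXAMPLE_RAW_LINE = '{"title": "an example title", "text": "an example passage"}'
_EXAMPLE_TRANSFORMS = {
    'wudao': lambda x: x['title'],   # returns a single string
    'pile': lambda x: [x['text']],   # returns a list of strings
}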


def create_shard_kwargs(patterns, repeat=1):
    '''
    Assign numbers to different shards of data to ensure that data is not duplicated
    when allocated to different nodes during distributed training.
    '''
    all_path = []
    for p in patterns:
        all_path.extend(glob(p))
    all_path *= repeat
    return [(i, p) for i, p in enumerate(all_path)]
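

# A minimal sketch (not part of the original file) of how the shard indices produced
# by create_shard_kwargs interact with process_index / num_processes inside
# create_data_iter: shard i is read only by the process with rank i % num_processes,
# so no shard is read twice. The file names below are hypothetical.
def _demo_shard_assignment(num_processes=2):
    paths = list(enumerate(['part-wudao-0.jsonl', 'part-wudao-1.jsonl',
                            'part-pile-0.jsonl', 'part-pile-1.jsonl']))
    for process_index in range(num_processes):
        assigned = [p for i, p in paths if i % num_processes == process_index]
        print('rank {} reads {}'.format(process_index, assigned))
    # rank 0 reads ['part-wudao-0.jsonl', 'part-pile-0.jsonl']
    # rank 1 reads ['part-wudao-1.jsonl', 'part-pile-1.jsonl']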


if __name__ == '__main__':
    patterns = [
        'data/pretrain_data/part-wudao*.jsonl.zst'
    ]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        'wudao': lambda x: x['title'],
        'pile': lambda x: [x['text']]
    }
    data_iter = create_data_iter(paths, transform_dict=transform_dict)
    for i, data in enumerate(data_iter):
        print(i, data)
        if i == 20:
            break