'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob

# zstandard is only needed for .jsonl.zst shards; treat it as an optional
# dependency so plain .jsonl shards can be read (and this module imported)
# without it installed.
try:
    import zstandard as zstd
except ImportError:
    zstd = None


def _yield_records(fp, dataset_name, transform_dict):
    '''
    Parse one shard's lines as JSON and yield the (possibly transformed)
    records.

    Each line must be a JSON object; a 'dataset' key naming the source
    dataset is added before any transformation. A transform may return a
    single string (yielded as-is) or a list (yielded element by element);
    any other return type raises.
    '''
    for raw_line in fp:
        if isinstance(raw_line, bytes):
            raw_line = raw_line.decode('utf-8')
        record = json.loads(raw_line)
        record['dataset'] = dataset_name
        if transform_dict:
            record = transform_dict[dataset_name](record)
            if isinstance(record, str):
                yield record
            elif isinstance(record, list):
                for item in record:
                    yield item
            else:
                raise Exception(
                    'Unsupported type in Transformation: {}'.format(
                        transform_dict[dataset_name]))
        else:
            yield record


def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
    '''
    Currently, the allowed storage formats are jsonl and jsonl.zst.
    Each line of the data is a dictionary, which can be parsed as JSON for
    subsequent processing after reading.

    paths: list of (shard_index, path) pairs, e.g. from create_shard_kwargs().
    transform_dict: optional mapping of dataset name -> callable applied to
        each parsed record.
    process_index / num_processes: shards are assigned round-robin so each
        distributed process reads a disjoint subset.
    '''
    past = None
    for shard_index, path in paths:
        # Dataset name follows the file naming convention
        # 'part-<dataset>-<shard>.jsonl[.zst]'.
        dataset_name = path.split('-')[-2]
        if past != dataset_name:
            print('Loading data from {}'.format(path))
            # Fix: remember the dataset name (the original stored the path,
            # so the message was printed once per file, not once per dataset).
            past = dataset_name
        # Round-robin shard assignment across distributed processes.
        if num_processes > 1 and shard_index % num_processes != process_index:
            continue
        if path.endswith('jsonl.zst'):
            if zstd is None:
                raise ImportError(
                    'zstandard is required to read {}'.format(path))
            with zstd.open(path, 'r', encoding='utf-8') as fp:
                for record in _yield_records(fp, dataset_name, transform_dict):
                    yield record
        elif path.endswith('jsonl'):
            with open(path, 'r') as fp:
                for record in _yield_records(fp, dataset_name, transform_dict):
                    yield record
        else:
            raise Exception(
                'File format of {} is not supported yet.'.format(path))


def create_shard_kwargs(patterns, repeat=1):
    '''
    Assign numbers to different shards of data to ensure that data is not
    duplicated when allocated to different nodes during distributed training.

    patterns: list of glob patterns matching shard files.
    repeat: replicate the matched path list this many times (e.g. multiple
        epochs over the same shards).
    Returns a list of (index, path) pairs consumable by create_data_iter().
    '''
    all_path = []
    for pattern in patterns:
        all_path.extend(glob(pattern))
    all_path *= repeat
    return [(i, p) for i, p in enumerate(all_path)]


if __name__ == '__main__':
    patterns = [
        'data/pretrain_data/part-wudao*.jsonl.zst'
    ]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        'wudao': lambda x: x['title'],
        'pile': lambda x: [x['text']],
    }
    data_iter = create_data_iter(paths, transform_dict=transform_dict)
    for i, data in enumerate(data_iter):
        print(i, data)
        if i == 20:
            break