""" Author: LiangSong(sl12160010@gmail.com) Date: 2023-03-17 20:41:25 LastEditors: LiangSong(sl12160010@gmail.com) LastEditTime: 2023-03-26 23:07:56 FilePath: /Open-Llama/dataset/pretrain_dataset.py Description: Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. """ import math import torch def preprocess_wudao_gen(tokenizer, segment_max_length=1024): def preprocess_wudao(line): """ The format of the data is roughly as follows. {'id': 1, 'dataType': '百科', 'title': 'some title', 'content': 'some content'} Split the data based on the tokenized length according to the maximum length. """ total = line["title"] + "\n" + line["content"] out = tokenizer(total) input_ids = out["input_ids"] return [ input_ids[i * segment_max_length : (i + 1) * segment_max_length] for i in range(math.ceil(len(input_ids) / segment_max_length)) ] return preprocess_wudao def preprocess_the_pile_gen(tokenizer, segment_max_length=1024): def preprocess_the_pile(line): """ The format of the data is roughly as follows. {'text': 'some text', 'meta': {'pile_set_name': 'Github'}} Split the data based on the tokenized length according to the maximum length. """ total = line["text"] out = tokenizer(total) input_ids = out["input_ids"] return [ input_ids[i * segment_max_length : (i + 1) * segment_max_length] for i in range(math.ceil(len(input_ids) / segment_max_length)) ] return preprocess_the_pile def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024): """ Organize data into tensors by padding based on the preset maximum length. """ pad_id = tokenizer.pad_id def pretrain_collate_fn(batch): input_ids = [] for i in batch: input_len = len(i) input_ids.append(i + [pad_id] * (segment_max_length - input_len)) inputs = { "input_ids": torch.tensor(input_ids, dtype=torch.int64), } return inputs return pretrain_collate_fn class BucketBySequenceLengthDataset(torch.utils.data.IterableDataset): def __init__(self, generator, batch_size, bucket_size=32, max_length=1024): super().__init__() self.generator = generator self.batch_size = batch_size self.bucket_size = bucket_size self.bucket_num = math.ceil(max_length / bucket_size) self.buckets = [[] for _ in range(self.bucket_num)] self.bucket_idx = None def __iter__(self): if self.batch_size <= 1: return self.generator def bucket_iter(): if self.bucket_idx is not None: sample = self.buckets[self.bucket_idx].pop() if len(self.buckets[self.bucket_idx]) == 0: self.bucket_idx = None yield sample sample = next(self.generator) - 1 sample_len = len(sample) bucket_idx = sample_len // self.bucket_size if len(self.buckets[bucket_idx]) == self.batch_size - 1: self.bucket_idx = bucket_idx yield sample else: self.buckets[bucket_idx].append(sample) return bucket_iter() if __name__ == "__main__": import sentencepiece as spm from datasets import IterableDataset from torch.utils.data import DataLoader from dataset.tokenizer import Tokenizer from dataset.data_iter import create_shard_kwargs, create_data_iter sp_model = spm.SentencePieceProcessor( model_file="configs/10w_vocab_wudao5_pile10.model" ) tokenizer = Tokenizer(sp_model) patterns = ["data/pretrain_data/part-*.jsonl.zst"] paths = create_shard_kwargs(patterns) transform_dict = { "wudao": preprocess_wudao_gen(tokenizer), "pile": preprocess_the_pile_gen(tokenizer), } data_set = IterableDataset.from_generator( create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict} ) train_loader = DataLoader( data_set, batch_size=8, num_workers=4, collate_fn=pretrain_collate_fn_gen(tokenizer), drop_last=True, ) for 
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
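    # With the settings above (batch_size=8 and the collate function's default
    # segment_max_length=1024), the loop is expected to print something like:
    #   input_ids torch.Size([8, 1024])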