diff --git a/dataset/pretrain_dataset.py b/dataset/pretrain_dataset.py
index 55de7e3..18e0013 100644
--- a/dataset/pretrain_dataset.py
+++ b/dataset/pretrain_dataset.py
@@ -67,6 +67,36 @@ def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
     return pretrain_collate_fn
 
 
+class BucketBySequenceLengthDataset(torch.utils.data.IterableDataset):
+    def __init__(self, generator, batch_size, bucket_size=32, max_length=1024):
+        super().__init__()
+        self.generator = generator
+        self.batch_size = batch_size
+        self.bucket_size = bucket_size
+        self.bucket_num = math.ceil(max_length / bucket_size)
+        self.buckets = [[] for _ in range(self.bucket_num)]
+        self.bucket_idx = None
+
+    def __iter__(self):
+        if self.batch_size <= 1:
+            return self.generator
+        def bucket_iter():
+            if self.bucket_idx is not None:
+                sample = self.buckets[self.bucket_idx].pop()
+                if len(self.buckets[self.bucket_idx]) == 0:
+                    self.bucket_idx = None
+                yield sample
+            sample = next(self.generator)
+            sample_len = len(sample)
+            bucket_idx = sample_len // self.bucket_size
+            if len(self.buckets[bucket_idx]) == self.batch_size - 1:
+                self.bucket_idx = bucket_idx
+                yield sample
+            else:
+                self.buckets[bucket_idx].append(sample)
+        return bucket_iter()
+
+
 if __name__ == "__main__":
     import sentencepiece as spm
     from datasets import IterableDataset
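
Below is a minimal standalone sketch of the length-bucketing idea the new class implements, assuming each sample is a list of token ids whose length is at most max_length. toy_generator, the random lengths, and the constants are invented for illustration and are not part of the repo:

# Sketch: group incoming samples into buckets by sequence length so that
# each emitted batch contains sequences whose lengths differ by less than
# bucket_size (same bucket arithmetic as the diff above).
import random

def toy_generator():
    # Fake sample stream; only the lengths matter for bucketing.
    while True:
        yield [0] * random.randint(1, 1023)

max_length = 1024
bucket_size = 32
batch_size = 4
buckets = [[] for _ in range((max_length + bucket_size - 1) // bucket_size)]

gen = toy_generator()
for _ in range(200):
    sample = next(gen)
    idx = len(sample) // bucket_size       # bucket index, as in the diff
    buckets[idx].append(sample)
    if len(buckets[idx]) == batch_size:    # bucket full: emit one batch
        batch, buckets[idx] = buckets[idx], []
        print(f"bucket {idx}: batch lengths {[len(s) for s in batch]}")

Because every batch is drawn from a single bucket, lengths within a batch differ by at most bucket_size - 1 tokens, which bounds the padding the collate function has to add per sample.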