"""
|
|
Author: LiangSong(sl12160010@gmail.com)
|
|
Date: 2023-03-30 20:58:16
|
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
|
LastEditTime: 2023-03-30 21:00:49
|
|
FilePath: /Open-Llama/dataset/data_loader.py
|
|
Description:
|
|
|
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
|
"""
|
|
import math
|
|
import torch
|
|
|
|
|
|
def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Build a collate_fn that pads each batch of token-id lists into one tensor,
    either to the longest sample in the batch or to the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        if padding == "longest":
            max_length = max(len(i) for i in batch)
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise ValueError("Invalid argument for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            # Right-pad each sample with pad_id up to max_length.
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return pretrain_collate_fn


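# A minimal usage sketch (illustrative, not part of the original module),
# assuming a tokenizer whose pad_id is 0 (the actual value depends on the
# SentencePiece model): the generated collate_fn right-pads a batch of
# token-id lists to the longest sample.
#
#   collate = pretrain_collate_fn_gen(tokenizer, padding="longest")
#   out = collate([[5, 6, 7], [8, 9]])
#   # out["input_ids"] -> tensor([[5, 6, 7], [8, 9, 0]]), dtype torch.int64

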
class BySequenceLengthDataset(torch.utils.data.IterableDataset):
    """
    Experimental: an IterableDataset wrapper that groups samples of similar
    length into buckets, so that each batch can be padded with little waste.
    Samples are routed to buckets of width `bucket_size`; a bucket is drained
    once it has collected `batch_size` samples (in the distributed case, once
    every process has that many samples in the same bucket).
    """

    def __init__(
        self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
    ):
        super().__init__()
        self.generator = generator
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        # One bucket per bucket_size-wide length range, covering up to max_length.
        self.bucket_num = math.ceil(max_length / bucket_size)
        self.buckets = [[] for _ in range(self.bucket_num)]
        self.bucket_idx = None
        self.accelerator = accelerator
        if self.accelerator is not None:
            # Per-bucket element counts, kept on the accelerator device so they
            # can be gathered across processes.
            self.buckets_ele_num = torch.tensor(
                [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
            )
            self.buckets_indexes = torch.arange(
                self.bucket_num, device=accelerator.device
            )
        self.finished = False
        self.has_no_same_bucket = False
        self.rest = None

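    # Worked example (illustrative, not in the original code): with the default
    # bucket_size=16 and max_length=1024 there are ceil(1024 / 16) = 64 buckets;
    # a sample of 100 token ids has sample_len = 100 - 1 = 99 and lands in
    # bucket 99 // 16 = 6 together with samples of length 97..112, so a batch
    # drained from that bucket needs little padding.
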
    def __iter__(self):
        # Without bucketing there is nothing to do; expose the raw generator.
        if self.batch_size <= 1:
            return self.generator

        def bucket_iter():
            while True:
                # If a bucket has been selected for draining, emit from it first.
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    if len(self.buckets[self.bucket_idx]) == 0:
                        self.bucket_idx = None
                    yield sample
                try:
                    sample = next(self.generator)
                except StopIteration:
                    break
                sample_len = len(sample) - 1
                bucket_idx = sample_len // self.bucket_size
                if len(self.buckets[bucket_idx]) == self.batch_size - 1:
                    # Counting this sample, the bucket is full: yield the sample
                    # and mark the bucket for draining.
                    self.bucket_idx = bucket_idx
                    yield sample
                else:
                    self.buckets[bucket_idx].append(sample)

        def parallel_bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    self.buckets_ele_num[self.bucket_idx] -= 1
                    # Stop draining this bucket as soon as any process runs out
                    # of samples in it, so all processes stay in step.
                    buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
                    buckets_ele_num = buckets_ele_num.reshape(
                        self.accelerator.num_processes, self.bucket_num
                    )
                    min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                    if min_buckets_ele_num[self.bucket_idx] <= 0:
                        self.bucket_idx = None
                    yield sample
                else:
                    if self.finished:
                        if self.has_no_same_bucket:
                            if self.rest is None:
                                # Collect whatever is left in every bucket.
                                self.rest = []
                                for bucket in self.buckets:
                                    for i in bucket:
                                        self.rest.append(i)
                            elif len(self.rest) > 0:
                                yield self.rest.pop()
                            else:
                                # Raising StopIteration inside a generator is a
                                # RuntimeError since PEP 479; return instead.
                                return
                        else:
                            buckets_ele_num = self.accelerator.gather(
                                self.buckets_ele_num
                            )
                            buckets_ele_num = buckets_ele_num.view(
                                self.accelerator.num_processes, self.bucket_num
                            )
                            min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                            valid_bucket_idx = self.buckets_indexes[
                                min_buckets_ele_num >= self.batch_size
                            ]
                            if len(valid_bucket_idx) > 0:
                                self.bucket_idx = valid_bucket_idx[0].cpu().item()
                            else:
                                self.has_no_same_bucket = True
                    else:
                        try:
                            sample = next(self.generator)
                        except StopIteration:
                            self.finished = True
                            continue
                        sample_len = len(sample) - 1
                        bucket_idx = sample_len // self.bucket_size
                        self.buckets[bucket_idx].append(sample)
                        self.buckets_ele_num[bucket_idx] += 1
                        # A bucket may only be drained once every process has at
                        # least batch_size samples in it.
                        buckets_ele_num = self.accelerator.gather(
                            self.buckets_ele_num
                        ).cpu()
                        buckets_ele_num = buckets_ele_num.view(
                            self.accelerator.num_processes, self.bucket_num
                        )
                        min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                        valid_bucket_idx = self.buckets_indexes[
                            min_buckets_ele_num >= self.batch_size
                        ]
                        if len(valid_bucket_idx) > 0:
                            self.bucket_idx = valid_bucket_idx[0].cpu().item()

        if self.accelerator:
            return parallel_bucket_iter()
        else:
            return bucket_iter()


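# A minimal sketch of how the bucketed dataset might be wired up (illustrative;
# `sample_generator` is a hypothetical iterator over token-id lists and is not
# defined in this file):
#
#   bucketed = BySequenceLengthDataset(sample_generator, batch_size=8)
#   loader = torch.utils.data.DataLoader(
#       bucketed, batch_size=8, collate_fn=pretrain_collate_fn_gen(tokenizer)
#   )

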
if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=pretrain_collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
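    # With the settings above, each iteration should print a line such as
    # "input_ids torch.Size([8, L])", where L is the length of the longest
    # sample in that batch (batch_size=8; padding="longest" is the default).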