Open-Llama/dataset/collate_fn.py

70 lines
2.1 KiB
Python

"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:11:03
FilePath: /Open-Llama/dataset/collate_fn.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Build a collate function that right-pads token-id sequences into one tensor.

    Args:
        tokenizer: object exposing a ``pad_id`` attribute; that id is appended
            to shorter sequences as padding.
        segment_max_length: target row length when ``padding == "max_length"``.
            Unused when ``padding == "longest"``.
        padding: ``"longest"`` pads each batch to its longest sequence;
            ``"max_length"`` pads every batch to ``segment_max_length``.

    Returns:
        A callable mapping a batch (a sequence of token-id sequences) to
        ``{"input_ids": int64 tensor of shape (len(batch), max_length)}``.

    Raises:
        ValueError: immediately, if ``padding`` is not a supported mode; and at
            collation time, if a sequence is longer than ``segment_max_length``
            while ``padding == "max_length"``.
    """
    if padding not in ("longest", "max_length"):
        # Fail fast at construction rather than on the first batch, which may
        # be collated inside a DataLoader worker process.
        raise ValueError("Invalid argument for padding: {}".format(padding))
    pad_id = tokenizer.pad_id

    def collate_fn(batch):
        # Copy each sample into a list so tuples (or other sequences) also work
        # with the list concatenation below.
        rows = [list(seq) for seq in batch]
        if padding == "longest":
            # default=0 keeps an empty batch from crashing max().
            max_length = max((len(row) for row in rows), default=0)
        else:
            max_length = segment_max_length
            too_long = [len(row) for row in rows if len(row) > max_length]
            if too_long:
                # A negative pad count would leave a ragged (or silently wider)
                # row, so reject it with an explicit message instead.
                raise ValueError(
                    "Sequence length(s) {} exceed segment_max_length={}".format(
                        too_long, max_length
                    )
                )
        input_ids = [row + [pad_id] * (max_length - len(row)) for row in rows]
        # view() keeps the result 2-D for an empty batch as well, where
        # torch.tensor([]) alone would be 1-D.
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64).view(
                len(rows), max_length
            ),
        }

    return collate_fn
if __name__ == "__main__":
    # Smoke test: build the pretraining DataLoader and print the name and
    # shape of every tensor in the first collated batch.
    import sentencepiece as spm
    from torch.utils.data import DataLoader
    from dataset.pretrain_dataset import (
        preprocess_wudao_gen,
        preprocess_the_pile_gen,
    )
    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_processor = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tok = Tokenizer(sp_processor)

    shard_kwargs = create_shard_kwargs(["data/pretrain_data/part-*.jsonl.zst"])
    # One preprocessing function per source corpus, keyed by corpus name.
    transforms = dict(
        wudao=preprocess_wudao_gen(tok),
        pile=preprocess_the_pile_gen(tok),
    )
    dataset_iter = DataIter(shard_kwargs, transform_dict=transforms)
    loader = DataLoader(
        dataset_iter,
        batch_size=8,
        num_workers=4,
        collate_fn=collate_fn_gen(tok),
        drop_last=True,
    )

    # Only the first batch is inspected; nothing is printed if the loader is empty.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        for name, tensor in first_batch.items():
            print(name, tensor.shape)