update dataset, add concat sequence from multiple docs
parent 0b0028097d
commit 562067230f
dataset/collate_fn.py (new file, 69 lines)

@@ -0,0 +1,69 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:11:03
FilePath: /Open-Llama/dataset/collate_fn.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch


def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argument for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return collate_fn


if __name__ == "__main__":
    import sentencepiece as spm
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = DataIter(paths, transform_dict=transform_dict)
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
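Editor's note, a minimal sketch of what the new collate_fn_gen produces for a toy batch. This is not part of the commit; DummyTokenizer below is a hypothetical stand-in that only supplies the pad_id attribute the closure reads.

from dataset.collate_fn import collate_fn_gen


class DummyTokenizer:
    # hypothetical stand-in for dataset.tokenizer.Tokenizer; only pad_id is used
    pad_id = 0


collate_fn = collate_fn_gen(DummyTokenizer(), padding="longest")
batch = [[5, 6, 7], [8, 9]]  # two tokenized samples of different lengths
print(collate_fn(batch)["input_ids"])
# tensor([[5, 6, 7],
#         [8, 9, 0]])  <- the shorter sample is right-padded with pad_id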
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
LastEditTime: 2023-04-05 22:36:45
FilePath: /Open-Llama/dataset/data_iter.py
Description:

@@ -11,67 +11,81 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
import json
from glob import glob
import zstandard as zstd
from torch.utils.data import IterableDataset


def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
class DataIter(IterableDataset):
    """
    Currently, the allowed storage formats are jsonl and jsonl.zst.
    Currently, the allowed storage formats are jsonl.zst.
    Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
    Currently, only a single worker is supported.
    """
    past = None
    for i, path in paths:
        dataset_name = path.split("-")[-2]
        if num_processes > 1 and i % num_processes != process_index:
            continue
        if past != dataset_name:
            print("Loading data from {}".format(path))
            past = path
        if path.endswith("jsonl.zst"):

    def __init__(
        self,
        paths_with_index,
        transform_dict=None,
        max_length=None,
        concat_docs=False,
        process_index=0,
        num_processes=1,
    ):
        super().__init__()
        self.paths_with_index = paths_with_index
        self.max_length = max_length
        self.transform_dict = transform_dict
        self.concat_docs = concat_docs
        self.process_index = process_index
        self.num_processes = num_processes
        if self.concat_docs:
            self.cache = []

    def __iter__(self):
        past = None
        for i, path in self.paths_with_index:
            # part-dataset_name-01.jsonl.zst
            dataset_name = path.split("-")[-2]
            # shard to multiple devices
            if self.num_processes > 1 and i % self.num_processes != self.process_index:
                continue
            # Log the file name when encountering a new file.
            if past != dataset_name:
                print("Loading data from {}".format(path))
                past = path
            # Currently, the allowed storage formats are jsonl.zst.
            assert path.endswith("jsonl.zst")
            with zstd.open(path, "r", encoding="utf-8") as fp:
                for line in fp:
                    # Once the cache reaches max_length, emit one fixed-length segment.
                    if self.concat_docs and len(self.cache) >= self.max_length:
                        seq = self.cache[: self.max_length]
                        self.cache = self.cache[self.max_length :]
                        yield seq
                    if isinstance(line, bytes):
                        line = line.decode("utf-8")
                    line = json.loads(line)
                    line["dataset"] = dataset_name
                    if transform_dict:
                        line = transform_dict[dataset_name](line)
                    # Transformation, including sample, tokenize, etc.
                    if self.transform_dict:
                        line = self.transform_dict[dataset_name](line)
                        if isinstance(line, str):
                            yield line
                        elif isinstance(line, list):
                            for i in line:
                                yield i
                        # must be list of list
                        elif isinstance(line, list) and isinstance(line[0], list):
                            for seq in line:
                                if self.concat_docs:
                                    # concat seq from multiple docs
                                    self.cache += seq
                                else:
                                    yield seq
                        else:
                            raise Exception(
                                "Unsupported type in Transformation: {}".format(
                                    transform_dict[dataset_name]
                                    self.transform_dict[dataset_name]
                                )
                            )
                    else:
                        yield line
        elif path.endswith("jsonl"):
            with open(path, "r") as fp:
                for line in fp:
                    if isinstance(line, bytes):
                        line = line.decode("utf-8")
                    line = json.loads(line)
                    line["dataset"] = dataset_name
                    if transform_dict:
                        line = transform_dict[dataset_name](line)
                        if isinstance(line, str):
                            yield line
                        elif isinstance(line, list):
                            for i in line:
                                yield i
                        else:
                            raise Exception(
                                "Unsupported type in Transformation: {}".format(
                                    transform_dict[dataset_name]
                                )
                            )
                    else:
                        yield line
        else:
            raise Exception("File format of {} is not supported yet.".format(path))


def create_shard_kwargs(patterns, repeat=1):
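Editor's note on the concat_docs path added above: when concat_docs=True, each tokenized document (a list of token ids) is appended to self.cache, and whenever the cache reaches max_length a fixed-size window is yielded while the remainder stays buffered. A simplified, self-contained sketch of that buffering idea follows (illustration only, not the class itself; DataIter drains at most one window per incoming line, while this sketch drains greedily, which produces the same windows in the same order):

def concat_windows(doc_token_lists, max_length):
    # Concatenate token lists from multiple docs and emit fixed-length windows,
    # keeping the leftover tokens buffered, as DataIter.__iter__ does above.
    cache = []
    for tokens in doc_token_lists:
        cache += tokens
        while len(cache) >= max_length:
            yield cache[:max_length]
            cache = cache[max_length:]


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(list(concat_windows(docs, max_length=4)))
# [[1, 2, 3, 4], [5, 6, 7, 8]]  <- the tail [9] stays in the buffer, as in DataIter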
@@ -90,7 +104,9 @@ if __name__ == "__main__":
    patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
    data_iter = create_data_iter(paths, transform_dict=transform_dict)
    data_iter = DataIter(
        paths, transform_dict=transform_dict, max_length=16, concat_docs=True
    )
    for i, data in enumerate(data_iter):
        print(i, data)
        if i == 20:
@@ -1,192 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch


def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argumet for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return pretrain_collate_fn


class BySequenceLengthDataset(torch.utils.data.IterableDataset):
    """
    experimental
    """

    def __init__(
        self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
    ):
        super().__init__()
        self.generator = generator
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.bucket_num = math.ceil(max_length / bucket_size)
        self.buckets = [[] for _ in range(self.bucket_num)]
        self.bucket_idx = None
        self.accelerator = accelerator
        if self.accelerator is not None:
            self.buckets_ele_num = torch.tensor(
                [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
            )
            self.buckets_indexes = torch.arange(
                self.bucket_num, device=accelerator.device
            )
        self.finished = False
        self.has_no_same_bucket = False
        self.rest = None

    def __iter__(self):
        if self.batch_size <= 1:
            return self.generator

        def bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    if len(self.buckets[self.bucket_idx]) == 0:
                        self.bucket_idx = None
                    yield sample
                try:
                    sample = next(self.generator)
                except StopIteration:
                    break
                sample_len = len(sample) - 1
                bucket_idx = sample_len // self.bucket_size
                if len(self.buckets[bucket_idx]) == self.batch_size - 1:
                    self.bucket_idx = bucket_idx
                    yield sample
                else:
                    self.buckets[bucket_idx].append(sample)

        def parallel_bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    self.buckets_ele_num[self.bucket_idx] -= 1
                    buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
                    buckets_ele_num = buckets_ele_num.reshape(
                        self.accelerator.num_processes, self.bucket_num
                    )
                    min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                    if min_buckets_ele_num[self.bucket_idx] <= 0:
                        self.bucket_idx = None
                    yield sample
                else:
                    if self.finished:
                        if self.has_no_same_bucket:
                            if self.rest is None:
                                self.rest = []
                                for bucket in self.buckets:
                                    for i in bucket:
                                        self.rest.append(i)
                            elif len(self.rest) > 0:
                                yield self.rest.pop()
                            else:
                                raise StopIteration()
                        else:
                            buckets_ele_num = self.accelerator.gather(
                                self.buckets_ele_num
                            )
                            buckets_ele_num = buckets_ele_num.view(
                                self.accelerator.num_processes, self.bucket_num
                            )
                            min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                            valid_bucket_idx = self.buckets_indexes[
                                min_buckets_ele_num >= self.batch_size
                            ]
                            if len(valid_bucket_idx) > 0:
                                self.bucket_idx = valid_bucket_idx[0].cpu().item()
                            else:
                                self.has_no_same_bucket = True
                    else:
                        try:
                            sample = next(self.generator)
                        except StopIteration:
                            self.finished = True
                            continue
                        sample_len = len(sample) - 1
                        bucket_idx = sample_len // self.bucket_size
                        self.buckets[bucket_idx].append(sample)
                        self.buckets_ele_num[bucket_idx] += 1
                        buckets_ele_num = self.accelerator.gather(
                            self.buckets_ele_num
                        ).cpu()
                        buckets_ele_num = buckets_ele_num.view(
                            self.accelerator.num_processes, self.bucket_num
                        )
                        min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                        valid_bucket_idx = self.buckets_indexes[
                            min_buckets_ele_num >= self.batch_size
                        ]
                        if len(valid_bucket_idx) > 0:
                            self.bucket_idx = valid_bucket_idx[0].cpu().item()

        if self.accelerator:
            return parallel_bucket_iter()
        else:
            return bucket_iter()


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=pretrain_collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
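Editor's note on the file removed above: BySequenceLengthDataset grouped incoming samples into buckets by rough sequence length and only released a bucket once it held a full batch, so that each batch pads to a similar length. A much-simplified, single-process sketch of that idea (illustration only; the deleted class additionally synchronized bucket counts across accelerator processes and flushed leftovers at the end):

import math


def bucket_by_length(samples, batch_size=2, bucket_size=16, max_length=1024):
    # Assign each sample to a bucket by length; emit a bucket's samples
    # together once it holds batch_size of them.
    bucket_num = math.ceil(max_length / bucket_size)
    buckets = [[] for _ in range(bucket_num)]
    for sample in samples:
        idx = (len(sample) - 1) // bucket_size
        buckets[idx].append(sample)
        if len(buckets[idx]) == batch_size:
            yield from buckets[idx]
            buckets[idx] = []


samples = [list(range(n)) for n in (3, 40, 5, 41, 7)]
print([len(s) for s in bucket_by_length(samples)])
# [3, 5, 40, 41]  <- similar lengths are released together; the length-7 sample
# stays buffered here, whereas the original class drained leftovers at the end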
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:02:06
LastEditTime: 2023-04-05 22:35:24
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:

@@ -60,7 +60,7 @@ if __name__ == "__main__":
    from datasets import IterableDataset

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
@@ -73,8 +73,8 @@ if __name__ == "__main__":
        "belle_0.5M": preprocess_belle_gen(tokenizer),
        "self_instruct": preprocess_self_instruction_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    data_set = DataIter(
        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
    )
    for i, sample in enumerate(data_set):
        print(sample, sp_model.Decode(sample))
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 20:41:25
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:07:56
LastEditTime: 2023-04-05 22:32:39
FilePath: /Open-Llama/dataset/pretrain_dataset.py
Description:

@@ -49,10 +49,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):

if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
@@ -64,8 +63,8 @@ if __name__ == "__main__":
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    data_set = DataIter(
        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
    )
    for sample in data_set:
        print(sample)
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:09:39
LastEditTime: 2023-04-05 22:35:01
FilePath: /Open-Llama/dataset/tokenizer.py
Description:

@@ -183,11 +183,11 @@ if __name__ == "__main__":
    for i, j in zip(tmp, out):
        assert normalize("NFKC", i) == j

    from dataset.data_iter import create_shard_kwargs, create_data_iter
    from dataset.data_iter import create_shard_kwargs, DataIter

    patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    data_iter = create_data_iter(paths)
    data_iter = DataIter(paths)
    for i, data in enumerate(data_iter):
        assert (
            normalize("NFKC", data["content"])
@@ -2,14 +2,14 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-24 20:49:03
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:43:59
LastEditTime: 2023-04-05 22:40:29
FilePath: /Open-Llama/dataset/train_tokenizer.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import random
from dataset.data_iter import create_data_iter, create_shard_kwargs
from dataset.data_iter import DataIter, create_shard_kwargs

wudao_patterns = [
    "data/pretrain_data/part-wudao-*.jsonl.zst",
@@ -24,10 +24,10 @@ pile_paths = create_shard_kwargs(pile_patterns)
random.shuffle(pile_paths)
paths = wudao_paths[:5] + pile_paths[:10]
transform_dict = {
    "wudao": lambda line: [(line["title"] + "\n" + line["content"])],
    "pile": lambda line: [line["text"]],
    "wudao": lambda line: line["title"] + "\n" + line["content"],
    "pile": lambda line: line["text"],
}
data_iter = create_data_iter(paths, transform_dict)
data_iter = iter(DataIter(paths, transform_dict))

import io
import sentencepiece as spm
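Editor's note on where this hunk is heading (the import io / import sentencepiece lines above): an iterator of plain strings such as iter(DataIter(paths, transform_dict)) can be fed straight to the SentencePiece trainer. A hedged sketch using the standard sentence_iterator and model_writer options of the sentencepiece package; the corpus and option values below are illustrative placeholders, not taken from this commit:

import io

import sentencepiece as spm

# stand-in corpus; in train_tokenizer.py the strings would come from data_iter
sentences = ("open llama example sentence number {}".format(i) for i in range(2000))

model = io.BytesIO()
spm.SentencePieceTrainer.train(
    sentence_iterator=sentences,  # consumes the strings one by one
    model_writer=model,           # serialized model bytes end up here
    vocab_size=100,               # placeholder; the repo trains a much larger vocab
    model_type="bpe",
)
with open("example_tokenizer.model", "wb") as f:  # hypothetical output path
    f.write(model.getvalue())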
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-31 13:26:15
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-31 14:05:35
LastEditTime: 2023-04-05 21:47:54
FilePath: /Open-Llama/server.py
Description:

@@ -32,7 +32,9 @@ raw_model = LlamaForCausalLM(
        shared_input_output_embedding=True,
    )
)
ckpt = torch.load("data/saved_ckpt/instruction_tuning_3_epochs/23001.pt", map_location="cpu")
ckpt = torch.load(
    "data/saved_ckpt/instruction_tuning_3_epochs/37001.pt", map_location="cpu"
)
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()