add instruction-tuning
commit a62ac2658f
parent e1bd1766bc
configs/instruction_tuning_config.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:39:40
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 37500
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/instruction_data/part-*.jsonl.zst"]
# global step
log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/30000.pt"
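For orientation, a quick back-of-envelope estimate of the token budget implied by these settings, assuming every batch is padded to max_length and looking at a single process only (no gradient accumulation or data parallelism); this sketch is illustrative and not part of the commit.

# Rough token budget implied by the config above (single process, full-length sequences).
max_length = 1024
train_batch_size = 2
num_training_steps = 37500

tokens_per_step = train_batch_size * max_length       # 2048 tokens per optimizer data step
total_tokens = tokens_per_step * num_training_steps   # ~76.8M tokens per process
print(f"{tokens_per_step=} {total_tokens=:,}")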
data/preprocess_instruction.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:52:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 20:52:12
FilePath: /Open-Llama/data/preprocess_instruction.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
import zstandard as zstd
from datasets import load_dataset

dataset = load_dataset("yizhongw/self_instruct")
write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()

dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()

dataset = load_dataset("BelleGroup/generated_train_1M_CN")
write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
    line = json.dumps(line)
    if total_num % 1024 == 0 and total_num > 0:
        file_num += 1
        wfp.close()
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    wfp.write(line.encode("utf-8"))
    wfp.write(b"\n")
    total_num += 1
wfp.close()
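The three dump loops above differ only in the dataset name and output prefix. A possible refactor, not part of this commit, is sketched below with a hypothetical shard_dataset helper; it keeps the same sharding scheme of 1024 JSON lines per zstd-compressed .jsonl.zst part (the encoding argument is dropped because the files are opened in binary mode).

import json
import zstandard as zstd
from datasets import load_dataset


def shard_dataset(dataset_name, write_path, lines_per_shard=1024):
    # Hypothetical helper equivalent to the repeated loops above: stream the
    # train split and write it out as zstd-compressed JSONL shards.
    dataset = load_dataset(dataset_name)
    total_num = 0
    file_num = 0
    wfp = zstd.open(write_path.format(file_num), "wb")
    for line in dataset["train"]:
        if total_num % lines_per_shard == 0 and total_num > 0:
            file_num += 1
            wfp.close()
            wfp = zstd.open(write_path.format(file_num), "wb")
        wfp.write(json.dumps(line).encode("utf-8"))
        wfp.write(b"\n")
        total_num += 1
    wfp.close()


shard_dataset("yizhongw/self_instruct", "data/instruction_data/part-self_instruct-{}.jsonl.zst")
shard_dataset("BelleGroup/generated_train_0.5M_CN", "data/instruction_data/part-belle_0.5M-{}.jsonl.zst")
shard_dataset("BelleGroup/generated_train_1M_CN", "data/instruction_data/part-belle_1M-{}.jsonl.zst")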
dataset/data_loader.py (new file, 192 lines)
@@ -0,0 +1,192 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch


def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argumet for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return pretrain_collate_fn


class BySequenceLengthDataset(torch.utils.data.IterableDataset):
    """
    experimental
    """

    def __init__(
        self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
    ):
        super().__init__()
        self.generator = generator
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.bucket_num = math.ceil(max_length / bucket_size)
        self.buckets = [[] for _ in range(self.bucket_num)]
        self.bucket_idx = None
        self.accelerator = accelerator
        if self.accelerator is not None:
            self.buckets_ele_num = torch.tensor(
                [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
            )
            self.buckets_indexes = torch.arange(
                self.bucket_num, device=accelerator.device
            )
        self.finished = False
        self.has_no_same_bucket = False
        self.rest = None

    def __iter__(self):
        if self.batch_size <= 1:
            return self.generator

        def bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    if len(self.buckets[self.bucket_idx]) == 0:
                        self.bucket_idx = None
                    yield sample
                try:
                    sample = next(self.generator)
                except StopIteration:
                    break
                sample_len = len(sample) - 1
                bucket_idx = sample_len // self.bucket_size
                if len(self.buckets[bucket_idx]) == self.batch_size - 1:
                    self.bucket_idx = bucket_idx
                    yield sample
                else:
                    self.buckets[bucket_idx].append(sample)

        def parallel_bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    self.buckets_ele_num[self.bucket_idx] -= 1
                    buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
                    buckets_ele_num = buckets_ele_num.reshape(
                        self.accelerator.num_processes, self.bucket_num
                    )
                    min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                    if min_buckets_ele_num[self.bucket_idx] <= 0:
                        self.bucket_idx = None
                    yield sample
                else:
                    if self.finished:
                        if self.has_no_same_bucket:
                            if self.rest is None:
                                self.rest = []
                                for bucket in self.buckets:
                                    for i in bucket:
                                        self.rest.append(i)
                            elif len(self.rest) > 0:
                                yield self.rest.pop()
                            else:
                                raise StopIteration()
                        else:
                            buckets_ele_num = self.accelerator.gather(
                                self.buckets_ele_num
                            )
                            buckets_ele_num = buckets_ele_num.view(
                                self.accelerator.num_processes, self.bucket_num
                            )
                            min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                            valid_bucket_idx = self.buckets_indexes[
                                min_buckets_ele_num >= self.batch_size
                            ]
                            if len(valid_bucket_idx) > 0:
                                self.bucket_idx = valid_bucket_idx[0].cpu().item()
                            else:
                                self.has_no_same_bucket = True
                    else:
                        try:
                            sample = next(self.generator)
                        except StopIteration:
                            self.finished = True
                            continue
                        sample_len = len(sample) - 1
                        bucket_idx = sample_len // self.bucket_size
                        self.buckets[bucket_idx].append(sample)
                        self.buckets_ele_num[bucket_idx] += 1
                        buckets_ele_num = self.accelerator.gather(
                            self.buckets_ele_num
                        ).cpu()
                        buckets_ele_num = buckets_ele_num.view(
                            self.accelerator.num_processes, self.bucket_num
                        )
                        min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                        valid_bucket_idx = self.buckets_indexes[
                            min_buckets_ele_num >= self.batch_size
                        ]
                        if len(valid_bucket_idx) > 0:
                            self.bucket_idx = valid_bucket_idx[0].cpu().item()

        if self.accelerator:
            return parallel_bucket_iter()
        else:
            return bucket_iter()


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=pretrain_collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
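A minimal smoke test for the two padding modes of pretrain_collate_fn_gen, assuming the repository root is on PYTHONPATH; the dummy tokenizer is a stand-in introduced only for this sketch (the collate function reads nothing but tokenizer.pad_id) and is not part of the repo.

import torch

from dataset.data_loader import pretrain_collate_fn_gen


class DummyTokenizer:
    # Stand-in: only pad_id is consumed by the collate function.
    pad_id = 0


batch = [[5, 6, 7], [8, 9], [10]]

# padding="longest" (default) pads every sample to the longest sequence in the batch.
collate_longest = pretrain_collate_fn_gen(DummyTokenizer())
print(collate_longest(batch)["input_ids"].shape)  # torch.Size([3, 3])

# padding="max_length" always pads to segment_max_length.
collate_fixed = pretrain_collate_fn_gen(DummyTokenizer(), segment_max_length=8, padding="max_length")
print(collate_fixed(batch)["input_ids"].shape)  # torch.Size([3, 8])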
dataset/instruction_dataset.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:02:06
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""


def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
    def preprocess_self_instruction(line):
        """
        The format of the data is roughly as follows.
        {'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["prompt"]
        if prompt.endswith("Output:"):
            prompt = prompt[:-7]
        total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [input_ids]

    return preprocess_self_instruction


def preprocess_belle_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle(line):
        """
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["input"].replace("\\n", "")
        prompt = prompt.strip("")

        completion = line["target"].replace("\\n", "")
        completion = completion.strip("")
        total = "user:{}<s>system:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [input_ids]

    return preprocess_belle


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "belle_1M": preprocess_belle_gen(tokenizer),
        "belle_0.5M": preprocess_belle_gen(tokenizer),
        "self_instruct": preprocess_self_instruction_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    for i, sample in enumerate(data_set):
        print(sample, sp_model.Decode(sample))
        if i == 20:
            break
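Both preprocessors build the same chat template, joining the user prompt and the model response around the <s> separator (the Belle preprocessor reads the input and target fields of the BelleGroup records). A standalone illustration of the self_instruct path, with the tokenizer step omitted so it runs by itself:

# Illustration of the prompt template used by the preprocessors above; tokenization omitted.
self_instruct_line = {
    "prompt": "Explain the origin of life on earth. Output:",
    "completion": "Life on Earth is believed to have...",
}

prompt = self_instruct_line["prompt"]
if prompt.endswith("Output:"):
    prompt = prompt[:-7]  # drop the trailing "Output:" marker
total = "user:{}<s>system:{}".format(prompt.strip(), self_instruct_line["completion"].strip())
print(total)
# user:Explain the origin of life on earth.<s>system:Life on Earth is believed to have...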
@@ -9,7 +9,6 @@ Description:
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import math
-import torch
 
 
 def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
@@ -48,59 +47,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
     return preprocess_the_pile
 
 
-def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
-    """
-    Organize data into tensors by padding based on the preset maximum length.
-    """
-    pad_id = tokenizer.pad_id
-
-    def pretrain_collate_fn(batch):
-        input_ids = []
-        for i in batch:
-            input_len = len(i)
-            input_ids.append(i + [pad_id] * (segment_max_length - input_len))
-        inputs = {
-            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
-        }
-        return inputs
-
-    return pretrain_collate_fn
-
-
-class BucketBySequenceLengthDataset(torch.utils.data.IterableDataset):
-    def __init__(self, generator, batch_size, bucket_size=32, max_length=1024):
-        super().__init__()
-        self.generator = generator
-        self.batch_size = batch_size
-        self.bucket_size = bucket_size
-        self.bucket_num = math.ceil(max_length / bucket_size)
-        self.buckets = [[] for _ in range(self.bucket_num)]
-        self.bucket_idx = None
-
-    def __iter__(self):
-        if self.batch_size <= 1:
-            return self.generator
-        def bucket_iter():
-            if self.bucket_idx is not None:
-                sample = self.buckets[self.bucket_idx].pop()
-                if len(self.buckets[self.bucket_idx]) == 0:
-                    self.bucket_idx = None
-                yield sample
-            sample = next(self.generator) - 1
-            sample_len = len(sample)
-            bucket_idx = sample_len // self.bucket_size
-            if len(self.buckets[bucket_idx]) == self.batch_size - 1:
-                self.bucket_idx = bucket_idx
-                yield sample
-            else:
-                self.buckets[bucket_idx].append(sample)
-        return bucket_iter()
-
-
 if __name__ == "__main__":
     import sentencepiece as spm
     from datasets import IterableDataset
-    from torch.utils.data import DataLoader
 
     from dataset.tokenizer import Tokenizer
     from dataset.data_iter import create_shard_kwargs, create_data_iter
@@ -118,14 +67,6 @@ if __name__ == "__main__":
     data_set = IterableDataset.from_generator(
         create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
     )
-    train_loader = DataLoader(
-        data_set,
-        batch_size=8,
-        num_workers=4,
-        collate_fn=pretrain_collate_fn_gen(tokenizer),
-        drop_last=True,
-    )
-    for batch in train_loader:
-        for k, v in batch.items():
-            print(k, v.shape)
+    for sample in data_set:
+        print(sample)
         break
inctruction_tuning.py (new file, 180 lines)
@@ -0,0 +1,180 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:40:03
FilePath: /Open-Llama/inctruction_tuning.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.instruction_dataset import (
    preprocess_belle_gen,
    preprocess_self_instruction_gen,
)
from configs.instruction_tuning_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Instruction")

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
    "belle_1M": preprocess_belle_gen(tokenizer, max_length),
    "belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
    "self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
}
data_set = IterableDataset.from_generator(
    create_data_iter,
    gen_kwargs={
        "paths": paths,
        "transform_dict": transform_dict,
        "process_index": accelerator.process_index,
        "num_processes": accelerator.num_processes,
    },
)
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# smaller initializer_range make training more stable
# add stabel embedding to token embedding
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
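The optimizer setup above follows the common transformers recipe of excluding bias and norm parameters from weight decay. A self-contained illustration on a toy module is sketched below; the toy module and its attribute names are invented for this sketch, and FusedAdam is swapped for torch.optim.AdamW only so it runs without DeepSpeed or a GPU.

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical module whose parameter names mirror the patterns matched above
    # ("bias" and "layernorm.weight" should land in the no-decay group).
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.input_layernorm = nn.LayerNorm(8)


model = ToyBlock()
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.1,
    },
    {
        "params": [
            p
            for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
# FusedAdam accepts the same grouped-parameter structure; AdamW is a stand-in here.
optim = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-4, betas=(0.9, 0.95))
print([len(g["params"]) for g in optim.param_groups])  # [1, 3]: only proj.weight is decayed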
@@ -24,12 +24,12 @@ from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with
 from dataset.validation import val_set
 from dataset.tokenizer import Tokenizer
 from dataset.data_iter import create_shard_kwargs, create_data_iter
+from dataset.data_loader import pretrain_collate_fn_gen
 from dataset.pretrain_dataset import (
     preprocess_the_pile_gen,
     preprocess_wudao_gen,
-    pretrain_collate_fn_gen,
 )
-from configs.train_config import *
+from configs.pretrain_config import *
 
 accelerator = Accelerator()
 
@@ -62,7 +62,7 @@ train_loader = DataLoader(
     data_set,
     batch_size=train_batch_size,
     # If num_workers is greater than 1, duplicate data may occur.
-    num_workers=1,
+    num_workers=0,
     collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
     drop_last=True,
 )
@@ -124,7 +124,7 @@ for data_step in range(num_training_steps):
         batch = next(train_loader_iter)
         for k, v in batch.items():
             batch[k] = v.to(accelerator.device, non_blocking=True)
-        out = model(**batch, labels=batch['input_ids'])
+        out = model(**batch, labels=batch["input_ids"])
         total_loss = out.loss
         losses = {"total_loss": total_loss}
         accelerator.backward(total_loss)