add instruction-tuning

LiangSong 2023-03-30 23:43:12 +08:00
parent e1bd1766bc
commit a62ac2658f
8 changed files with 539 additions and 65 deletions


@@ -0,0 +1,25 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:39:40
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 37500
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/instruction_data/part-*.jsonl.zst"]
# global step
log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/30000.pt"
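A quick sanity check of what these settings imply (not part of the commit): counting every sequence as if padded to max_length, the same upper bound the training script later uses for its token-throughput log, a minimal sketch assuming the repo root is on the import path:

# Hypothetical sketch: rough per-GPU token budget implied by the config above.
from configs.instruction_tuning_config import (
    max_length,
    num_training_steps,
    train_batch_size,
)

tokens_per_step = train_batch_size * max_length          # 2 * 1024 = 2048
total_tokens = tokens_per_step * num_training_steps      # 2048 * 37500 = 76,800,000
print(f"{tokens_per_step} tokens/step, ~{total_tokens / 1e6:.1f}M tokens per GPU")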


@@ -0,0 +1,61 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:52:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 20:52:12
FilePath: /Open-Llama/data/preprocess_instruction.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
import zstandard as zstd
from datasets import load_dataset
dataset = load_dataset("yizhongw/self_instruct")
write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
dataset = load_dataset("BelleGroup/generated_train_1M_CN")
write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
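The shards written above are zstd streams of UTF-8 JSON lines, one record per line. A minimal spot-check sketch (not part of the commit) that reads the first record of the first self_instruct shard back; the shard name simply follows the write_path pattern with file_num 0:

# Hypothetical spot-check: stream one shard back and print the first record.
import json
import zstandard as zstd

with zstd.open("data/instruction_data/part-self_instruct-0.jsonl.zst", "rt", encoding="utf-8") as fp:
    for raw in fp:
        print(json.loads(raw))
        break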

dataset/data_loader.py Normal file

@@ -0,0 +1,192 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch
def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
"""
Organize data into tensors, padding either to the longest sequence in the batch or to the preset maximum length, depending on the padding argument.
"""
pad_id = tokenizer.pad_id
def pretrain_collate_fn(batch):
if padding == "longest":
max_length = max([len(i) for i in batch])
elif padding == "max_length":
max_length = segment_max_length
else:
raise Exception("Invalid argument for padding: {}".format(padding))
input_ids = []
for i in batch:
input_len = len(i)
input_ids.append(i + [pad_id] * (max_length - input_len))
inputs = {
"input_ids": torch.tensor(input_ids, dtype=torch.int64),
}
return inputs
return pretrain_collate_fn
class BySequenceLengthDataset(torch.utils.data.IterableDataset):
"""
Experimental: bucket samples by sequence length so each batch is drawn from sequences of similar length and needs less padding.
"""
def __init__(
self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
):
super().__init__()
self.generator = generator
self.batch_size = batch_size
self.bucket_size = bucket_size
self.bucket_num = math.ceil(max_length / bucket_size)
self.buckets = [[] for _ in range(self.bucket_num)]
self.bucket_idx = None
self.accelerator = accelerator
if self.accelerator is not None:
self.buckets_ele_num = torch.tensor(
[0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
)
self.buckets_indexes = torch.arange(
self.bucket_num, device=accelerator.device
)
self.finished = False
self.has_no_same_bucket = False
self.rest = None
def __iter__(self):
if self.batch_size <= 1:
return self.generator
def bucket_iter():
while True:
if self.bucket_idx is not None:
sample = self.buckets[self.bucket_idx].pop()
if len(self.buckets[self.bucket_idx]) == 0:
self.bucket_idx = None
yield sample
try:
sample = next(self.generator)
except StopIteration:
break
sample_len = len(sample) - 1
bucket_idx = sample_len // self.bucket_size
if len(self.buckets[bucket_idx]) == self.batch_size - 1:
self.bucket_idx = bucket_idx
yield sample
else:
self.buckets[bucket_idx].append(sample)
def parallel_bucket_iter():
while True:
if self.bucket_idx is not None:
sample = self.buckets[self.bucket_idx].pop()
self.buckets_ele_num[self.bucket_idx] -= 1
buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
buckets_ele_num = buckets_ele_num.reshape(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
if min_buckets_ele_num[self.bucket_idx] <= 0:
self.bucket_idx = None
yield sample
else:
if self.finished:
if self.has_no_same_bucket:
if self.rest is None:
self.rest = []
for bucket in self.buckets:
for i in bucket:
self.rest.append(i)
elif len(self.rest) > 0:
yield self.rest.pop()
else:
return  # PEP 479: raising StopIteration inside a generator becomes a RuntimeError
else:
buckets_ele_num = self.accelerator.gather(
self.buckets_ele_num
)
buckets_ele_num = buckets_ele_num.view(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
valid_bucket_idx = self.buckets_indexes[
min_buckets_ele_num >= self.batch_size
]
if len(valid_bucket_idx) > 0:
self.bucket_idx = valid_bucket_idx[0].cpu().item()
else:
self.has_no_same_bucket = True
else:
try:
sample = next(self.generator)
except StopIteration:
self.finished = True
continue
sample_len = len(sample) - 1
bucket_idx = sample_len // self.bucket_size
self.buckets[bucket_idx].append(sample)
self.buckets_ele_num[bucket_idx] += 1
buckets_ele_num = self.accelerator.gather(
self.buckets_ele_num
).cpu()
buckets_ele_num = buckets_ele_num.view(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
valid_bucket_idx = self.buckets_indexes[
min_buckets_ele_num >= self.batch_size
]
if len(valid_bucket_idx) > 0:
self.bucket_idx = valid_bucket_idx[0].cpu().item()
if self.accelerator:
return parallel_bucket_iter()
else:
return bucket_iter()
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from torch.utils.data import DataLoader
from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"wudao": preprocess_wudao_gen(tokenizer),
"pile": preprocess_the_pile_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
)
train_loader = DataLoader(
data_set,
batch_size=8,
num_workers=4,
collate_fn=pretrain_collate_fn_gen(tokenizer),
drop_last=True,
)
for batch in train_loader:
for k, v in batch.items():
print(k, v.shape)
break
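The __main__ smoke test above only exercises the default configuration. A toy check (not part of the commit) of the two padding modes the rewritten collate function supports, using a stub tokenizer that only carries the pad_id attribute the closure reads:

# Hypothetical toy check: compare "longest" vs "max_length" padding.
class _StubTokenizer:
    pad_id = 0  # stand-in for the real SentencePiece-backed Tokenizer

batch = [[5, 6, 7], [8, 9]]  # two tokenized samples of different lengths
longest = pretrain_collate_fn_gen(_StubTokenizer(), segment_max_length=8)(batch)
fixed = pretrain_collate_fn_gen(
    _StubTokenizer(), segment_max_length=8, padding="max_length"
)(batch)
print(longest["input_ids"].shape)  # torch.Size([2, 3]): padded to the longest sample
print(fixed["input_ids"].shape)    # torch.Size([2, 8]): padded to segment_max_length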


@@ -0,0 +1,75 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:02:06
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
def preprocess_self_instruction(line):
"""
The format of the data is roughly as follows.
{'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
Concatenate the prompt and completion into a single "user:...<s>system:..." sequence and tokenize it as one sample.
"""
prompt = line["prompt"]
if prompt.endswith("Output:"):
prompt = prompt[:-7]
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total)
input_ids = out["input_ids"]
return [input_ids]
return preprocess_self_instruction
def preprocess_belle_gen(tokenizer, segment_max_length=1024):
def preprocess_belle(line):
"""
The format of the data is roughly as follows.
{'input': 'some instruction text', 'target': 'some response text'}
Concatenate input and target into a single "user:...<s>system:..." sequence and tokenize it as one sample.
"""
prompt = line["input"].replace("\\n", "")
prompt = prompt.strip()
completion = line["target"].replace("\\n", "")
completion = completion.strip()
total = "user:{}<s>system:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [input_ids]
return preprocess_belle
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
"self_instruct": preprocess_self_instruction_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
)
for i, sample in enumerate(data_set):
print(sample, sp_model.Decode(sample))
if i == 20:
break
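The __main__ check above needs the real SentencePiece model and data shards. A lighter sketch (not part of the commit) that only shows the "user:...<s>system:..." template the preprocessors build, using a hypothetical stub tokenizer whose __call__ echoes the formatted text instead of returning token ids:

# Hypothetical toy check: the Belle preprocessor joins input/target into one
# "user:...<s>system:..." string before tokenization.
class _EchoTokenizer:
    def __call__(self, text):
        return {"input_ids": text}  # echo the text so the template is visible

fn = preprocess_belle_gen(_EchoTokenizer())
sample = {"input": "Name three primary colors.", "target": "Red, yellow, and blue."}
print(fn(sample))  # ['user:Name three primary colors.<s>system:Red, yellow, and blue.']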


@@ -9,7 +9,6 @@ Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch
def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
@@ -48,59 +47,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
return preprocess_the_pile
def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
"""
Organize data into tensors by padding based on the preset maximum length.
"""
pad_id = tokenizer.pad_id
def pretrain_collate_fn(batch):
input_ids = []
for i in batch:
input_len = len(i)
input_ids.append(i + [pad_id] * (segment_max_length - input_len))
inputs = {
"input_ids": torch.tensor(input_ids, dtype=torch.int64),
}
return inputs
return pretrain_collate_fn
class BucketBySequenceLengthDataset(torch.utils.data.IterableDataset):
def __init__(self, generator, batch_size, bucket_size=32, max_length=1024):
super().__init__()
self.generator = generator
self.batch_size = batch_size
self.bucket_size = bucket_size
self.bucket_num = math.ceil(max_length / bucket_size)
self.buckets = [[] for _ in range(self.bucket_num)]
self.bucket_idx = None
def __iter__(self):
if self.batch_size <= 1:
return self.generator
def bucket_iter():
if self.bucket_idx is not None:
sample = self.buckets[self.bucket_idx].pop()
if len(self.buckets[self.bucket_idx]) == 0:
self.bucket_idx = None
yield sample
sample = next(self.generator) - 1
sample_len = len(sample)
bucket_idx = sample_len // self.bucket_size
if len(self.buckets[bucket_idx]) == self.batch_size - 1:
self.bucket_idx = bucket_idx
yield sample
else:
self.buckets[bucket_idx].append(sample)
return bucket_iter()
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from torch.utils.data import DataLoader
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
@@ -118,14 +67,6 @@ if __name__ == "__main__":
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
)
train_loader = DataLoader(
data_set,
batch_size=8,
num_workers=4,
collate_fn=pretrain_collate_fn_gen(tokenizer),
drop_last=True,
)
for batch in train_loader:
for k, v in batch.items():
print(k, v.shape)
for sample in data_set:
print(sample)
break

inctruction_tuning.py Normal file

@@ -0,0 +1,180 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:40:03
FilePath: /Open-Llama/inctruction_tuning.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.instruction_dataset import (
preprocess_belle_gen,
preprocess_self_instruction_gen,
)
from configs.instruction_tuning_config import *
accelerator = Accelerator()
if accelerator.is_main_process:
wandb.init(project="LLAMA Instruction")
log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps
sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)
paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
"belle_1M": preprocess_belle_gen(tokenizer, max_length),
"belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
"self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
}
data_set = IterableDataset.from_generator(
create_data_iter,
gen_kwargs={
"paths": paths,
"transform_dict": transform_dict,
"process_index": accelerator.process_index,
"num_processes": accelerator.num_processes,
},
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# a smaller initializer_range makes training more stable
# add stable embedding to the token embedding
raw_model = LlamaForCausalLM(
LlamaConfig(
vocab_size=tokenizer.vocab_size,
initializer_range=initializer_range,
pad_token_id=tokenizer.pad_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
)
)
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in raw_model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": weight_decay,
},
{
"params": [
p
for n, p in raw_model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
optim,
num_warmup_steps=num_warmup_steps * factor,
num_training_steps=num_training_steps * factor,
)
_, model, optim, scheduler = accelerator.prepare(
train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
model.train()
with accelerator.accumulate(model):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device, non_blocking=True)
out = model(**batch, labels=batch["input_ids"])
total_loss = out.loss
losses = {"total_loss": total_loss}
accelerator.backward(total_loss)
optim.step()
scheduler.step()
optim.zero_grad()
if accelerator.sync_gradients:
global_step += 1
if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
cost_time = time.time() - start_time
start_time = time.time()
tokens = train_batch_size * log_interval * max_length
wandb.log({"Training/Token per second per gpu": tokens / cost_time})
for k, v in losses.items():
wandb.log({"Losses/{}".format(k): v})
current_lr = optim.param_groups[0]["lr"]
wandb.log({"Training/LR": current_lr})
if optim.scaler is not None:
wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
wandb.log({"Training/Data Step": data_step})
wandb.log({"Training/Global Step": global_step})
accelerator.print(
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
global_step, data_step, losses["total_loss"], tokens / cost_time
)
)
if data_step % eval_interval == 0 and accelerator.is_main_process:
text_table = wandb.Table(columns=["question", "pred"])
model.eval()
with torch.no_grad():
for data in val_set:
raw_inputs = data
inputs_len = len(raw_inputs)
inputs = tokenizer(
raw_inputs, return_tensors=True, add_special_tokens=False
)
for k, v in inputs.items():
inputs[k] = v.to(accelerator.device)
pred = model.generate(
**inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
)
pred = tokenizer.decode(pred.cpu())[0]
pred = pred[inputs_len:]
text_table.add_data(raw_inputs, pred)
wandb.log({"Predictions on {}".format(global_step): text_table})
if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
if not os.path.isdir(work_dir):
os.mkdir(work_dir)
torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
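Once a step has been saved into work_dir by the save_interval branch above, it can be reloaded for a quick qualitative check. A hypothetical inference sketch (not part of the commit) that mirrors the model construction and generation settings of the script; the checkpoint name and prompt are placeholders:

# Hypothetical inference sketch: rebuild the model as above, load a saved
# checkpoint (file name is a placeholder), and prompt it with the same
# "user:...<s>system:" template used during instruction tuning.
import torch
import sentencepiece as spm
from transformers import LlamaForCausalLM, LlamaConfig

from dataset.tokenizer import Tokenizer
from configs.instruction_tuning_config import *

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
model.load_state_dict(torch.load("data/saved_ckpt/1000.pt", map_location="cpu"))  # placeholder step
model.eval()

prompt = "user:What is instruction tuning?<s>system:"
inputs = tokenizer(prompt, return_tensors=True, add_special_tokens=False)
with torch.no_grad():
    pred = model.generate(
        **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
    )
print(tokenizer.decode(pred.cpu())[0][len(prompt):])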


@@ -24,12 +24,12 @@ from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.pretrain_dataset import (
preprocess_the_pile_gen,
preprocess_wudao_gen,
pretrain_collate_fn_gen,
)
from configs.train_config import *
from configs.pretrain_config import *
accelerator = Accelerator()
@@ -62,7 +62,7 @@ train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=1,
num_workers=0,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
@@ -124,7 +124,7 @@ for data_step in range(num_training_steps):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device, non_blocking=True)
out = model(**batch, labels=batch['input_ids'])
out = model(**batch, labels=batch["input_ids"])
total_loss = out.loss
losses = {"total_loss": total_loss}
accelerator.backward(total_loss)