"""
|
|
Author: LiangSong(sl12160010@gmail.com)
|
|
Date: 2023-03-30 21:35:01
|
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
|
LastEditTime: 2023-04-06 03:35:31
|
|
FilePath: /Open-Llama/inctruction_tuning.py
|
|
Description:
|
|
|
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
|
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.instruction_dataset import (
    preprocess_belle_gen,
    preprocess_self_instruction_gen,
    preprocess_belle_multiturn_chat_gen,
    preprocess_instruct_code_gen,
    preprocess_sharegpt_gen,
)
from configs.instruction_tuning_config import *
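# NOTE: the wildcard config import above is expected to provide the
# hyperparameters referenced below (tokenizer_model_path, patterns, max_length,
# train_batch_size, initializer_range, ckpt_path, weight_decay, lr,
# num_warmup_steps, num_training_steps, log_interval, eval_interval,
# save_interval, work_dir).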
accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Instruction")
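# The log/eval/save intervals are presumably configured in optimizer updates;
# multiplying by gradient_accumulation_steps converts them into data
# (micro-batch) steps, which is what the training loop below counts.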
log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps
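# Load the SentencePiece model and wrap it in the project's Tokenizer helper;
# its vocab_size and pad_id are reused below when building the model config.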
sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)
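# Expand the data file patterns into shard paths (repeated 3 times so the
# instruction data can be traversed several times) and shuffle their order.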
paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
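# Map dataset names to their preprocessing generators; each key is presumably
# matched against the shard path so every source gets its own prompt format.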
transform_dict = {
    "self_instruct": preprocess_self_instruction_gen(tokenizer),
    "belle_1M": preprocess_belle_gen(tokenizer),
    "belle_0.5M": preprocess_belle_gen(tokenizer),
    "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
    "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
    "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
    "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
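# DataIter streams tokenized samples shard by shard, concatenating documents up
# to max_length (concat_docs=True) and splitting the shard list across ranks
# via process_index/num_processes so each process reads a distinct slice.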
data_set = DataIter(
    paths,
    transform_dict=transform_dict,
    concat_docs=True,
    max_length=max_length,
    process_index=accelerator.process_index,
    num_processes=accelerator.num_processes,
)
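# collate_fn_gen(tokenizer, max_length) builds the batch collator, which is
# expected to pad/truncate each sequence to max_length and produce the tensors
# consumed by the model below.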
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# A smaller initializer_range makes training more stable.
# Add a stable embedding layer to the token embedding.
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
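# NOTE: hidden_dropout_prob, attention_dropout_prob, use_stable_embedding and
# shared_input_output_embedding are not fields of the stock transformers
# LlamaConfig; this script appears to rely on the project's patched Llama
# implementation that understands these extra options.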
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
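# Split parameters into two groups: weight decay is applied to all weights
# except biases and normalization weights.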
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
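# FusedAdam is DeepSpeed's fused CUDA Adam kernel; beta2=0.95 follows the
# LLaMA training recipe.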
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
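# The warmup/training lengths are scaled because, once prepared by Accelerate,
# the scheduler advances num_processes ticks per optimizer update, and an
# update only happens every gradient_accumulation_steps data steps; the
# num_processes / gradient_accumulation_steps factor makes the cosine schedule
# line up with the intended number of updates.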
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)
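# Wrap the model, optimizer and scheduler for distributed training. The
# prepared DataLoader is discarded (assigned to _) and the raw train_loader is
# iterated below; per-rank sharding is handled by DataIter instead.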
_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
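# Training loop: each data_step consumes one micro-batch. Accelerate
# accumulates gradients and performs an optimizer update every
# gradient_accumulation_steps micro-batches; global_step counts those updates.
# Labels are the input_ids themselves (standard causal-LM objective).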
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
    if accelerator.sync_gradients:
        global_step += 1
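    # Every log_interval data steps, the main process logs throughput, losses,
    # learning rate and (when AMP is active) the loss scale to wandb.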
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
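    # Every eval_interval data steps, generate completions for the validation
    # prompts on the main process and log them to a wandb table.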
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
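    # Every save_interval data steps, save the unwrapped raw_model weights to
    # work_dir (created on first use), named after the current global_step.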
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()