add trainer and utils

parent ae0691c509
commit a4aa109dd3
@@ -6,25 +6,14 @@ deepspeed_config:
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 1
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'no'
# dynamo_config:
#   dynamo_backend: INDUCTOR
#   dynamo_mode: default
#   dynamo_use_dynamic: true
#   dynamo_use_fullgraph: false
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
use_cpu: false
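Note: the hunk above moves zero_stage from 1 to 2 in the accelerate/DeepSpeed launcher config (ZeRO-2 partitions gradients in addition to optimizer states). A minimal sketch, not part of this commit, for confirming at runtime which settings a launched job actually picked up; it assumes the script is started with accelerate launch and this config file:

# Sketch only: print the ZeRO stage and mixed-precision mode seen by the process.
from accelerate import Accelerator

accelerator = Accelerator()
plugin = accelerator.state.deepspeed_plugin
if plugin is not None:
    accelerator.print("ZeRO stage:", plugin.zero_stage)
accelerator.print("mixed precision:", accelerator.mixed_precision)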
@@ -1,14 +0,0 @@
max_length = 1024
train_batch_size = 2
num_training_steps = 1000000
num_warmup_steps = 2000
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
# global step
log_interval = 5
eval_interval = 200
save_interval = 800
work_dir = "data/saved_ckpt/"
configs/pretrain_config.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
data:
  patterns: ["data/pretrain_data/part-*.jsonl.zst"]
  tokenizer_model_path: "configs/10w_vocab_wudao5_pile10.model"
model:
  initializer_range: 1.0e-2
  max_length: 1024
  hidden_dropout_prob: 0.1
  attention_dropout_prob: 0.1
  use_stable_embedding: True
  shared_input_output_embedding: True
train:
  train_batch_size: 2
  num_training_steps: 1000000
  num_warmup_steps: 2000
  initializer_range: 1.0e-2
  lr: 2.0e-4
  weight_decay: 1.0e-1
  ckpt: null
  # global step
  log_interval: 5
  eval_interval: 200
  save_interval: 800
  work_dir: "data/saved_ckpt/"
  project_name: "Llama Pretrain"
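The flat Python constants are replaced by this nested YAML config, which pretrain.py and solver/trainer.py read section by section. A small sanity-check sketch (a hypothetical helper, not part of the commit) for catching missing keys before starting a long run:

# Sketch only: verify the nested layout introduced by this commit.
import yaml

REQUIRED = {
    "data": ["patterns", "tokenizer_model_path"],
    "model": ["initializer_range", "max_length", "hidden_dropout_prob",
              "attention_dropout_prob", "use_stable_embedding",
              "shared_input_output_embedding"],
    "train": ["train_batch_size", "num_training_steps", "num_warmup_steps",
              "lr", "weight_decay", "ckpt", "log_interval", "eval_interval",
              "save_interval", "work_dir", "project_name"],
}

with open("configs/pretrain_config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.load(fp, Loader=yaml.FullLoader)

for section, keys in REQUIRED.items():
    missing = [k for k in keys if k not in config.get(section, {})]
    assert not missing, "missing keys in '{}': {}".format(section, missing)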
pretrain.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import yaml
import torch
import random
from absl import app
from absl import flags
import sentencepiece as spm
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import LlamaForCausalLM, LlamaConfig

from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
    preprocess_the_pile_gen,
    preprocess_wudao_gen,
)
from solver.trainer import Trainer

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")


def main(argv):
    accelerator = Accelerator()

    with open(FLAGS.config, 'r', encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
    sp_model = spm.SentencePieceProcessor(model_file=config['data']['tokenizer_model_path'])
    tokenizer = Tokenizer(sp_model)

    paths = create_shard_kwargs(config['data']['patterns'])
    random.shuffle(paths)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
        "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
    }
    data_set = DataIter(
        paths,
        transform_dict=transform_dict,
        concat_docs=True,
        max_length=config['model']['max_length'],
        process_index=accelerator.process_index,
        num_processes=accelerator.num_processes,
    )
    train_loader = DataLoader(
        data_set,
        batch_size=config['train']['train_batch_size'],
        # If num_workers is greater than 1, duplicate data may occur.
        num_workers=0,
        collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
        drop_last=True,
    )
    # A smaller initializer_range makes training more stable.
    # Stable embedding is applied to the token embedding.
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config['model']['initializer_range'],
            pad_token_id=tokenizer.pad_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config['model']['hidden_dropout_prob'],
            attention_dropout_prob=config['model']['attention_dropout_prob'],
            use_stable_embedding=config['model']['use_stable_embedding'],
            shared_input_output_embedding=config['model']['shared_input_output_embedding'],
        )
    )
    if config['train']['ckpt'] is not None:
        ckpt = torch.load(config['train']['ckpt'])
        raw_model.load_state_dict(ckpt)
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()


if __name__ == '__main__':
    app.run(main)
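pretrain.py is the new entry point; it would presumably be launched as accelerate launch pretrain.py --config configs/pretrain_config.yaml, with the accelerate/DeepSpeed config from the top of this diff passed via --config_file (its path is not visible here). One small hardening sketch, not part of the commit: absl can mark --config as required so a missing flag fails fast instead of crashing later inside open(None).

# Hardening sketch (assumption, not the committed code).
from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")
flags.mark_flag_as_required("config")


def main(argv):
    print("using config:", FLAGS.config)


if __name__ == "__main__":
    app.run(main)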
pretrain_llama.py (deleted, 175 lines)

@@ -1,175 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:46:31
FilePath: /Open-Llama/pretrain_llama.py
Description:
pretrain GPT
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
    preprocess_the_pile_gen,
    preprocess_wudao_gen,
)
from configs.pretrain_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Pretrain")

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns)
random.shuffle(paths)
transform_dict = {
    "wudao": preprocess_wudao_gen(tokenizer, max_length),
    "pile": preprocess_the_pile_gen(tokenizer, max_length),
}
data_set = DataIter(
    paths,
    transform_dict=transform_dict,
    concat_docs=True,
    max_length=max_length,
    process_index=accelerator.process_index,
    num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# smaller initializer_range make training more stable
# add stabel embedding to token embedding
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
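Both the deleted script above and the new Trainer below build the same two optimizer parameter groups, so biases and norm weights (both "LayerNorm.weight" and lowercase "layernorm.weight", which also matches names like input_layernorm.weight) are excluded from weight decay. A minimal stand-alone sketch of that grouping, substituting torch.optim.AdamW for DeepSpeed's FusedAdam so it runs without DeepSpeed installed; the toy module is hypothetical:

# Sketch of the weight-decay grouping used by this commit (AdamW stands in for FusedAdam).
import torch
from torch import nn


def build_optimizer(model, lr=2e-4, weight_decay=1e-1):
    no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
    grouped = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return torch.optim.AdamW(grouped, lr=lr, betas=(0.9, 0.95))


class Toy(nn.Module):
    # Hypothetical module: proj.weight gets weight decay, the rest does not.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)


optimizer = build_optimizer(Toy())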
solver/trainer.py (new file, 140 lines)
@@ -0,0 +1,140 @@
import os
import time
import wandb
import torch
from torchinfo import summary
from deepspeed.ops.adam import FusedAdam
from transformers import get_cosine_schedule_with_warmup

from dataset.validation import val_set


class Trainer:
    def __init__(self, config, raw_model, train_loader, tokenizer, accelerator):
        self.config = config
        self.raw_model = raw_model
        self.train_loader = train_loader
        self.tokenizer = tokenizer
        self.accelerator = accelerator
        self.lr_scheduler_factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
        # log/eval/save intervals, work_dir and project_name live in the train
        # section of configs/pretrain_config.yaml.
        self.log_interval = self.config['train']['log_interval'] * accelerator.gradient_accumulation_steps
        self.eval_interval = self.config['train']['eval_interval'] * accelerator.gradient_accumulation_steps
        self.save_interval = self.config['train']['save_interval'] * accelerator.gradient_accumulation_steps
        self.work_dir = self.config['train']['work_dir']
        self.get_model_info()
        if accelerator.is_main_process:
            wandb.init(project=self.config['train']['project_name'])

    def get_model_info(self):
        with torch.no_grad():
            summary(self.raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())

    def get_optimizer(self):
        no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.raw_model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.config['train']['weight_decay'],
            },
            {
                "params": [
                    p
                    for n, p in self.raw_model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        self.optim = FusedAdam(optimizer_grouped_parameters, lr=self.config['train']['lr'], betas=(0.9, 0.95))

    def get_lr_scheduler(self):
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optim,
            num_warmup_steps=self.config['train']['num_warmup_steps'] * self.lr_scheduler_factor,
            num_training_steps=self.config['train']['num_training_steps'] * self.lr_scheduler_factor,
        )

    def prepare(self):
        _, self.model, self.optim, self.scheduler = self.accelerator.prepare(
            self.train_loader, self.raw_model, self.optim, self.scheduler
        )
        self.train_loader_iter = iter(self.train_loader)

    def train_step(self, batch):
        for k, v in batch.items():
            batch[k] = v.to(self.accelerator.device, non_blocking=True)
        out = self.model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        self.accelerator.backward(total_loss)
        self.optim.step()
        self.scheduler.step()
        self.optim.zero_grad()
        return losses

    def train(self):
        self.get_optimizer()
        self.get_lr_scheduler()
        self.prepare()
        self.global_step = 0
        self.start_time = time.time()
        self.optim.zero_grad()
        for self.data_step in range(self.config['train']['num_training_steps']):
            self.model.train()
            with self.accelerator.accumulate(self.model):
                batch = next(self.train_loader_iter)
                losses = self.train_step(batch)
                if self.accelerator.sync_gradients:
                    self.global_step += 1
            if self.data_step % self.log_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
                self.log(losses)
            if self.data_step % self.eval_interval == 0 and self.accelerator.is_main_process:
                self.eval()
            if self.data_step % self.save_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
                if not os.path.isdir(self.work_dir):
                    os.mkdir(self.work_dir)
                torch.save(self.raw_model.state_dict(), "{}/{}.pt".format(self.work_dir, self.global_step))
        wandb.finish()

    def log(self, losses):
        cost_time = time.time() - self.start_time
        self.start_time = time.time()
        tokens = self.config['train']['train_batch_size'] * \
            self.log_interval * self.config['model']['max_length']
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = self.optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if self.optim.scaler is not None:
            wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": self.data_step})
        wandb.log({"Training/Global Step": self.global_step})
        self.accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                self.global_step, self.data_step, losses["total_loss"], tokens / cost_time
            )
        )

    def eval(self):
        text_table = wandb.Table(columns=["question", "pred"])
        self.model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = self.tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(self.accelerator.device)
                pred = self.model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = self.tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(self.global_step): text_table})
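The Trainer scales the warmup and total step counts by lr_scheduler_factor = num_processes / gradient_accumulation_steps before building the cosine schedule. A short worked example with hypothetical numbers (gradient_accumulation_steps comes from the accelerate/DeepSpeed config and is not shown in this diff):

# Hypothetical numbers only; values marked "from config" come from the files above.
num_processes = 8                 # from the accelerate config
gradient_accumulation_steps = 4   # assumption for illustration
num_warmup_steps = 2000           # from configs/pretrain_config.yaml
num_training_steps = 1000000      # from configs/pretrain_config.yaml

lr_scheduler_factor = num_processes / gradient_accumulation_steps  # 2.0
print(num_warmup_steps * lr_scheduler_factor)    # 4000.0 scheduler warmup steps
print(num_training_steps * lr_scheduler_factor)  # 2000000.0 scheduler decay horizon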
@@ -17,4 +17,46 @@ extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden
extended_out_embeddings = extended_out_embeddings * 0.001
ckpt['output.weight'] = torch.cat([ckpt['output.weight'], extended_out_embeddings], dim=0)

rename_map = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}

for f, t in rename_map.items():
    v = ckpt.pop(f)
    ckpt[t] = v

from_names = [
    "layers.{}.attention.wq.weight",
    "layers.{}.attention.wk.weight",
    "layers.{}.attention.wv.weight",
    "layers.{}.attention.wo.weight",
    "layers.{}.feed_forward.w1.weight",
    "layers.{}.feed_forward.w2.weight",
    "layers.{}.feed_forward.w3.weight",
    "layers.{}.attention_norm.weight",
    "layers.{}.ffn_norm.weight",
    "layers.{}.attention.inner_attention.rope.freqs",
]

to_names = [
    "model.layers.{}.self_attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight",
    "model.layers.{}.mlp.gate_proj.weight",
    "model.layers.{}.mlp.down_proj.weight",
    "model.layers.{}.mlp.up_proj.weight",
    "model.layers.{}.input_layernorm.weight",
    "model.layers.{}.post_attention_layernorm.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq",
]

for layer in range(32):
    for f, t in zip(from_names, to_names):
        f = f.format(layer)
        t = t.format(layer)
        v = ckpt.pop(f)
        ckpt[t] = v
torch.save(ckpt, 'data/llama_raw_ckpt/7B/extended.pth')
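The renamed keys are meant to line up with the LlamaForCausalLM parameter names used by pretrain.py. A quick verification sketch, not part of the commit, that loads the converted checkpoint non-strictly and reports mismatched keys; it assumes a 7B-shaped config, the merged vocabulary size, and enough host RAM to instantiate the model:

# Sketch only: check key-name compatibility of the converted checkpoint.
import torch
from transformers import LlamaForCausalLM, LlamaConfig

ckpt = torch.load("data/llama_raw_ckpt/7B/extended.pth", map_location="cpu")
model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=ckpt["model.embed_tokens.weight"].shape[0],
        hidden_size=4096,          # 7B shape, assumed
        intermediate_size=11008,   # 7B shape, assumed
        num_hidden_layers=32,
        num_attention_heads=32,
    )
)
# strict=False so extra buffers such as rotary inv_freq do not abort the check.
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)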