add trainer and utils
Commit a4aa109dd3 (parent ae0691c509)
@@ -6,25 +6,14 @@ deepspeed_config:
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false
-  zero_stage: 1
+  zero_stage: 2
 distributed_type: DEEPSPEED
-downcast_bf16: 'no'
-dynamo_backend: 'no'
-# dynamo_config:
-# dynamo_backend: INDUCTOR
-# dynamo_mode: default
-# dynamo_use_dynamic: true
-# dynamo_use_fullgraph: false
 fsdp_config: {}
 machine_rank: 0
+main_process_ip: null
+main_process_port: null
 main_training_function: main
-megatron_lm_config: {}
 mixed_precision: bf16
 num_machines: 1
 num_processes: 8
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
 use_cpu: false
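Context for the hunk above (not part of the diff): moving zero_stage from 1 to 2 means DeepSpeed ZeRO now partitions gradients as well as optimizer states across the data-parallel ranks. Below is a minimal Python sketch of the same settings expressed through accelerate's DeepSpeedPlugin instead of the YAML file; values not visible in the hunk (for example gradient clipping) are omitted, and the code is standard accelerate API usage, not code from this commit.

# Sketch only: the YAML hunk above, expressed programmatically with accelerate.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,                      # the change in this hunk: ZeRO stage 1 -> 2
    offload_optimizer_device="none",
    offload_param_device="none",
    zero3_init_flag=False,
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)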
(deleted file, 14 lines)
@@ -1,14 +0,0 @@
max_length = 1024
train_batch_size = 2
num_training_steps = 1000000
num_warmup_steps = 2000
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
# global step
log_interval = 5
eval_interval = 200
save_interval = 800
work_dir = "data/saved_ckpt/"
configs/pretrain_config.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
data:
  patterns: ["data/pretrain_data/part-*.jsonl.zst"]
  tokenizer_model_path: "configs/10w_vocab_wudao5_pile10.model"
model:
  initializer_range: 1.0e-2
  max_length: 1024
  hidden_dropout_prob: 0.1
  attention_dropout_prob: 0.1
  use_stable_embedding: True
  shared_input_output_embedding: True
train:
  train_batch_size: 2
  num_training_steps: 1000000
  num_warmup_steps: 2000
  initializer_range: 1.0e-2
  lr: 2.0e-4
  weight_decay: 1.0e-1
  ckpt: null
# global step
log_interval: 5
eval_interval: 200
save_interval: 800
work_dir: "data/saved_ckpt/"
project_name: "Llama Pretrain"
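A small sketch (not in the commit) of how this config is consumed. Note the layout assumed above: optimizer and schedule hyper-parameters live under train, model shape and dropout options under model, while log_interval, eval_interval, save_interval, work_dir and project_name sit at the top level, which is how the Trainer below reads them.

# Sketch: load the YAML and access a few keys the way pretrain.py / Trainer do.
import yaml

with open("configs/pretrain_config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.load(fp, Loader=yaml.FullLoader)

print(config["train"]["lr"])          # 0.0002
print(config["model"]["max_length"])  # 1024
print(config["log_interval"])         # 5, read from the top level by the Trainer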
pretrain.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import yaml
import torch
import random
from absl import app
from absl import flags
import sentencepiece as spm
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import LlamaForCausalLM, LlamaConfig

from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
    preprocess_the_pile_gen,
    preprocess_wudao_gen,
)
from solver.trainer import Trainer

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")


def main(argv):
    accelerator = Accelerator()

    with open(FLAGS.config, 'r', encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
    sp_model = spm.SentencePieceProcessor(model_file=config['data']['tokenizer_model_path'])
    tokenizer = Tokenizer(sp_model)

    paths = create_shard_kwargs(config['data']['patterns'])
    random.shuffle(paths)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
        "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
    }
    data_set = DataIter(
        paths,
        transform_dict=transform_dict,
        concat_docs=True,
        max_length=config['model']['max_length'],
        process_index=accelerator.process_index,
        num_processes=accelerator.num_processes,
    )
    train_loader = DataLoader(
        data_set,
        batch_size=config['train']['train_batch_size'],
        # If num_workers is greater than 1, duplicate data may occur.
        num_workers=0,
        collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
        drop_last=True,
    )
    # A smaller initializer_range makes training more stable.
    # Add stable embedding to the token embedding.
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config['model']['initializer_range'],
            pad_token_id=tokenizer.pad_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config['model']['hidden_dropout_prob'],
            attention_dropout_prob=config['model']['attention_dropout_prob'],
            use_stable_embedding=config['model']['use_stable_embedding'],
            shared_input_output_embedding=config['model']['shared_input_output_embedding'],
        )
    )
    if config['train']['ckpt'] is not None:
        ckpt = torch.load(config['train']['ckpt'])
        raw_model.load_state_dict(ckpt)
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()


if __name__ == '__main__':
    app.run(main)
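For orientation (not part of the commit): pretrain.py is driven by an absl flag, so the config path is passed on the command line, presumably with something like `accelerate launch pretrain.py --config configs/pretrain_config.yaml`, though that exact invocation is an assumption. Below is a minimal sketch of the flag pattern, with mark_flag_as_required added here as an optional hardening step the commit itself does not include.

# Sketch of the absl entry-point pattern used by pretrain.py (standalone, runnable).
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")
flags.mark_flag_as_required("config")  # not in the commit; fails fast if --config is missing


def main(argv):
    del argv  # unused
    print("loading config from", FLAGS.config)


if __name__ == "__main__":
    app.run(main)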
pretrain_llama.py (deleted file, 175 lines)
@@ -1,175 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:46:31
FilePath: /Open-Llama/pretrain_llama.py
Description:
pretrain GPT
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
    preprocess_the_pile_gen,
    preprocess_wudao_gen,
)
from configs.pretrain_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Pretrain")

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns)
random.shuffle(paths)
transform_dict = {
    "wudao": preprocess_wudao_gen(tokenizer, max_length),
    "pile": preprocess_the_pile_gen(tokenizer, max_length),
}
data_set = DataIter(
    paths,
    transform_dict=transform_dict,
    concat_docs=True,
    max_length=max_length,
    process_index=accelerator.process_index,
    num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# A smaller initializer_range makes training more stable.
# Add stable embedding to the token embedding.
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
solver/trainer.py (new file, 140 lines)
@@ -0,0 +1,140 @@
import os
import time
import wandb
import torch
from torchinfo import summary
from deepspeed.ops.adam import FusedAdam
from transformers import get_cosine_schedule_with_warmup

from dataset.validation import val_set


class Trainer:
    def __init__(self, config, raw_model, train_loader, tokenizer, accelerator):
        self.config = config
        self.raw_model = raw_model
        self.train_loader = train_loader
        self.tokenizer = tokenizer
        self.accelerator = accelerator
        self.lr_scheduler_factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
        self.log_interval = self.config['log_interval'] * accelerator.gradient_accumulation_steps
        self.eval_interval = self.config['eval_interval'] * accelerator.gradient_accumulation_steps
        self.save_interval = self.config['save_interval'] * accelerator.gradient_accumulation_steps
        self.work_dir = self.config['work_dir']
        self.get_model_info()
        if accelerator.is_main_process:
            wandb.init(project=self.config['project_name'])

    def get_model_info(self):
        with torch.no_grad():
            summary(self.raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())

    def get_optimizer(self):
        no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.raw_model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.config['train']['weight_decay'],
            },
            {
                "params": [
                    p
                    for n, p in self.raw_model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        self.optim = FusedAdam(optimizer_grouped_parameters, lr=self.config['train']['lr'], betas=(0.9, 0.95))

    def get_lr_scheduler(self):
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optim,
            num_warmup_steps=self.config['train']['num_warmup_steps'] * self.lr_scheduler_factor,
            num_training_steps=self.config['train']['num_training_steps'] * self.lr_scheduler_factor,
        )

    def prepare(self):
        _, self.model, self.optim, self.scheduler = self.accelerator.prepare(
            self.train_loader, self.raw_model, self.optim, self.scheduler
        )
        self.train_loader_iter = iter(self.train_loader)

    def train_step(self, batch):
        for k, v in batch.items():
            batch[k] = v.to(self.accelerator.device, non_blocking=True)
        out = self.model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        self.accelerator.backward(total_loss)
        self.optim.step()
        self.scheduler.step()
        self.optim.zero_grad()
        return losses

    def train(self):
        self.get_optimizer()
        self.get_lr_scheduler()
        self.prepare()
        self.global_step = 0
        self.start_time = time.time()
        self.optim.zero_grad()
        for self.data_step in range(self.config['train']['num_training_steps']):
            self.model.train()
            with self.accelerator.accumulate(self.model):
                batch = next(self.train_loader_iter)
                losses = self.train_step(batch)
                if self.accelerator.sync_gradients:
                    self.global_step += 1
            if self.data_step % self.log_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
                self.log(losses)
            if self.data_step % self.eval_interval == 0 and self.accelerator.is_main_process:
                self.eval()
            if self.data_step % self.save_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
                if not os.path.isdir(self.work_dir):
                    os.mkdir(self.work_dir)
                torch.save(self.raw_model.state_dict(), "{}/{}.pt".format(self.work_dir, self.global_step))
        wandb.finish()

    def log(self, losses):
        cost_time = time.time() - self.start_time
        self.start_time = time.time()
        tokens = (
            self.config['train']['train_batch_size']
            * self.log_interval
            * self.config['model']['max_length']
        )
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = self.optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if self.optim.scaler is not None:
            wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": self.data_step})
        wandb.log({"Training/Global Step": self.global_step})
        self.accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                self.global_step, self.data_step, losses["total_loss"], tokens / cost_time
            )
        )

    def eval(self):
        text_table = wandb.Table(columns=["question", "pred"])
        self.model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = self.tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(self.accelerator.device)
                pred = self.model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = self.tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(self.global_step): text_table})
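A worked example (numbers assumed, not taken from the diff) of why Trainer.__init__ multiplies the config intervals by gradient_accumulation_steps: the training loop counts data steps, but global_step only advances when accelerator.sync_gradients is true, i.e. once per accumulation window, so the scaling keeps the intervals expressed in optimizer steps.

# Assumed values for illustration only.
gradient_accumulation_steps = 4   # accelerate setting, not shown in this diff
log_interval_cfg = 5              # log_interval from the config

# Trainer scales the interval to data steps:
log_interval = log_interval_cfg * gradient_accumulation_steps  # 20 data steps

# One optimizer step happens per accumulation window, so logging still fires
# every log_interval_cfg optimizer steps.
optimizer_steps_between_logs = log_interval // gradient_accumulation_steps
assert optimizer_steps_between_logs == log_interval_cfg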
@@ -17,4 +17,46 @@ extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden
 extended_out_embeddings = extended_out_embeddings * 0.001
 ckpt['output.weight'] = torch.cat([ckpt['output.weight'], extended_out_embeddings], dim=0)
+
+rename_map = {
+    "tok_embeddings.weight": "model.embed_tokens.weight",
+    "norm.weight": "model.norm.weight",
+    "output.weight": "lm_head.weight",
+}
+
+for f, t in rename_map.items():
+    v = ckpt.pop(f)
+    ckpt[t] = v
+
+from_names = [
+    "layers.{}.attention.wq.weight",
+    "layers.{}.attention.wk.weight",
+    "layers.{}.attention.wv.weight",
+    "layers.{}.attention.wo.weight",
+    "layers.{}.feed_forward.w1.weight",
+    "layers.{}.feed_forward.w2.weight",
+    "layers.{}.feed_forward.w3.weight",
+    "layers.{}.attention_norm.weight",
+    "layers.{}.ffn_norm.weight",
+    "layers.{}.attention.inner_attention.rope.freqs"
+]
+
+to_names = [
+    "model.layers.{}.self_attn.q_proj.weight",
+    "model.layers.{}.self_attn.k_proj.weight",
+    "model.layers.{}.self_attn.v_proj.weight",
+    "model.layers.{}.self_attn.o_proj.weight",
+    "model.layers.{}.mlp.gate_proj.weight",
+    "model.layers.{}.mlp.down_proj.weight",
+    "model.layers.{}.mlp.up_proj.weight",
+    "model.layers.{}.input_layernorm.weight",
+    "model.layers.{}.post_attention_layernorm.weight",
+    "model.layers.{}.self_attn.rotary_emb.inv_freq",
+]
+
+for layer in range(32):
+    for f, t in zip(from_names, to_names):
+        f = f.format(layer)
+        t = t.format(layer)
+        v = ckpt.pop(f)
+        ckpt[t] = v
 torch.save(ckpt, 'data/llama_raw_ckpt/7B/extended.pth')
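Not part of the commit: a rough sanity check that the renamed keys match what transformers' LlamaForCausalLM expects, so the extended checkpoint can presumably be pointed to via train.ckpt in the YAML config. The default LlamaConfig dimensions correspond to the 7B model; strict=False tolerates version differences such as the rotary inv_freq buffers. This sketch instantiates a full model, so it is memory-heavy.

# Sketch: verify the converted state dict loads into the Hugging Face LLaMA class.
import torch
from transformers import LlamaForCausalLM, LlamaConfig

ckpt = torch.load("data/llama_raw_ckpt/7B/extended.pth", map_location="cpu")
model = LlamaForCausalLM(LlamaConfig(vocab_size=ckpt["lm_head.weight"].shape[0]))
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)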