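# Pretraining entry point: trains a LLaMA-style causal LM with Hugging Face
# Accelerate. Example launch (illustrative script/config names, assuming
# accelerate has already been configured):
#   accelerate launch pretrain_llama.py --config configs/pretrain_config.yaml
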
import yaml
import torch
import random
from absl import app
from absl import flags
import sentencepiece as spm
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import LlamaForCausalLM, LlamaConfig

from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
    preprocess_the_pile_gen,
    preprocess_wudao_gen,
)
from solver.trainer import Trainer

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")
# Fail fast if --config is omitted, instead of crashing later on open(None).
flags.mark_flag_as_required("config")


def main(argv):
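    # Accelerate detects the launch environment (single GPU, multi-GPU, ...)
    # and exposes process_index / num_processes, used below to shard the data.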
    accelerator = Accelerator()

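    # The YAML config must provide at least the keys read in this script:
    #   data:  tokenizer_model_path, patterns
    #   model: max_length, initializer_range, hidden_dropout_prob,
    #          attention_dropout_prob, use_stable_embedding,
    #          shared_input_output_embedding
    #   train: train_batch_size, ckpt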
    with open(FLAGS.config, 'r', encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
    sp_model = spm.SentencePieceProcessor(model_file=config['data']['tokenizer_model_path'])
    tokenizer = Tokenizer(sp_model)

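    # Expand the shard patterns into concrete file paths and shuffle them so
    # the corpora are interleaved. Note: under multiple processes an unseeded
    # shuffle gives every rank a different order; seed random first if the
    # ranks must agree on shard assignment.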
    paths = create_shard_kwargs(config['data']['patterns'])
    random.shuffle(paths)
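    # One preprocessing generator per corpus; presumably DataIter selects the
    # transform whose key matches the shard's dataset name.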
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
        "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
    }
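    # Each rank iterates only its own slice of the shards (process_index of
    # num_processes). concat_docs=True packs several documents into a single
    # max_length sequence (reading of the flag name), which reduces padding.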
    data_set = DataIter(
        paths,
        transform_dict=transform_dict,
        concat_docs=True,
        max_length=config['model']['max_length'],
        process_index=accelerator.process_index,
        num_processes=accelerator.num_processes,
    )
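    # collate_fn_gen is assumed to build a collate function that pads each
    # batch to max_length with the tokenizer's pad id; drop_last avoids a
    # ragged final batch.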
    train_loader = DataLoader(
        data_set,
        batch_size=config['train']['train_batch_size'],
        # If num_workers is greater than 1, duplicate data may occur.
        num_workers=0,
        collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
        drop_last=True,
    )
    # A smaller initializer_range makes training more stable.
    # Add a stable embedding to the token embedding.
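    # Note: hidden_dropout_prob, attention_dropout_prob, use_stable_embedding
    # and shared_input_output_embedding are not fields of the stock
    # transformers LlamaConfig; they assume the patched Llama implementation
    # bundled with this repo.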
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config['model']['initializer_range'],
            pad_token_id=tokenizer.pad_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config['model']['hidden_dropout_prob'],
            attention_dropout_prob=config['model']['attention_dropout_prob'],
            use_stable_embedding=config['model']['use_stable_embedding'],
            shared_input_output_embedding=config['model']['shared_input_output_embedding'],
        )
    )
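    # Optionally warm-start from a saved state_dict before training. If the
    # checkpoint was written from GPU, torch.load(..., map_location="cpu")
    # avoids allocating it on the saved device (optional hardening, not in
    # the original).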
    if config['train']['ckpt'] is not None:
        ckpt = torch.load(config['train']['ckpt'])
        raw_model.load_state_dict(ckpt)
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()


if __name__ == '__main__':
    app.run(main)