"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-27 23:08:47
FilePath: /Open-Llama/train_lm.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import yaml
import torch
from absl import app
from absl import flags
from accelerate import Accelerator
from torch.utils.data import DataLoader
from datasets.distributed import split_dataset_by_node
from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer

from dataset.dataset import construct_dataset
from solver.trainer import Trainer

FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")

def main(argv):
    with open(FLAGS.config, "r", encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
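    # Minimal sketch of the keys this script reads from the YAML config
    # (inferred from the lookups below; the real config may carry more fields,
    # e.g. whatever Trainer consumes):
    #   data:
    #     tokenizer_model_path: <sentencepiece model path>
    #   train:
    #     train_batch_size: <int>
    #     train_num_workers: <int>
    #     gradient_accumulation_steps: <int, optional, defaults to 1>
    #     prefetch_factor: <int, optional, defaults to 2>
    #     ckpt: <state-dict path or null>
    #   model:
    #     initializer_range: <float>
    #     hidden_dropout_prob: <float>
    #     attention_dropout_prob: <float>
    #     use_stable_embedding: <bool>
    #     shared_input_output_embedding: <bool>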
    accelerator = Accelerator(
        gradient_accumulation_steps=config["train"].get(
            "gradient_accumulation_steps", 1
        )
    )
    tokenizer = LlamaTokenizer(
        config["data"]["tokenizer_model_path"],
        pad_token="<pad>",
        add_bos_token=False,
        add_eos_token=True,
    )
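    # LlamaTokenizer defines no pad token by default, so "<pad>" is set
    # explicitly above; its id is reused as pad_token_id in the model config below.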
    data_config = config["data"]
    train_dataset = construct_dataset(data_config, tokenizer)
    # Shard the dataset so that each process only reads its own slice.
    train_dataset = split_dataset_by_node(
        train_dataset,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config["train"]["train_batch_size"],
        num_workers=config["train"]["train_num_workers"],
        prefetch_factor=config["train"].get("prefetch_factor", 2),
        pin_memory=True,
    )
    # A smaller initializer_range makes training more stable.
    # Using a stable embedding for the token embedding also improves stability.
    raw_model = OpenLlamaForCausalLM(
        OpenLlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config["model"]["initializer_range"],
            pad_token_id=tokenizer.pad_token_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config["model"]["hidden_dropout_prob"],
            attention_dropout_prob=config["model"]["attention_dropout_prob"],
            use_stable_embedding=config["model"]["use_stable_embedding"],
            shared_input_output_embedding=config["model"][
                "shared_input_output_embedding"
            ],
        )
    )
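    # Note: the dropout, stable-embedding and shared-embedding options above are
    # specific to the OpenLlama variant of the config; the standard LlamaConfig
    # does not expose parameters under these names.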
    if config["train"]["ckpt"] is not None:
        ckpt = torch.load(config["train"]["ckpt"])
        raw_model.load_state_dict(ckpt)
        print("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()


if __name__ == "__main__":
    app.run(main)