Add support for continuing (resuming) training from a saved state

This commit is contained in:
LiangSong 2023-04-29 20:28:39 +08:00
parent 28b11a5bed
commit fc21a75d1e
3 changed files with 37 additions and 22 deletions

View File

@ -2,13 +2,14 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-27 20:34:58
LastEditTime: 2023-04-29 19:38:54
FilePath: /Open-Llama/chat_server.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
import logging
import gradio as gr
from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
@ -39,7 +40,7 @@ ckpt = torch.load(
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()
print("ready")
logging.warn("ready")
def parse_codeblock(text):
@ -70,7 +71,7 @@ with gr.Blocks() as demo:
clear = gr.Button("Clear")
def user(user_message, history):
    """Record a submitted chat message in the conversation history.

    Args:
        user_message: Text the user submitted.
        history: List of ``[user_message, bot_message]`` pairs so far.

    Returns:
        A tuple ``("", new_history)``. ``new_history`` is a new list made of
        ``history`` followed by ``[user_message, None]``; the ``None`` slot is
        the one ``bot`` later fills in with the generated reply. The empty
        string is presumably routed back to the input textbox to clear it;
        confirm against the event wiring, which is outside this view.
    """
    print(user_message)
    # logging.warn is a deprecated alias (removed in Python 3.13); the
    # supported spelling is logging.warning.
    logging.warning(user_message)
    # Concatenate rather than append so the caller's list is left untouched.
    return "", history + [[user_message, None]]
def bot(history):
@ -103,7 +104,7 @@ with gr.Blocks() as demo:
pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
pred = pred[:, inputs_len:]
pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
print(pred)
logging.warn(pred)
bot_message = parse_codeblock(pred)
history[-1][1] = bot_message
return history

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-24 20:05:21
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-27 20:34:47
LastEditTime: 2023-04-29 20:26:34
FilePath: /Open-Llama/solver/trainer.py
Description:
@ -12,6 +12,7 @@ import os
import time
import wandb
import torch
import logging
from torchinfo import summary
from deepspeed.ops.adam import FusedAdam
from transformers import get_cosine_schedule_with_warmup
@ -26,6 +27,9 @@ class Trainer:
self.train_loader = train_loader
self.tokenizer = tokenizer
self.accelerator = accelerator
self.gradient_accumulation_steps = config["train"].get(
"gradient_accumulation_steps", 1
)
self.lr_scheduler_factor = (
accelerator.num_processes / accelerator.gradient_accumulation_steps
)
@ -94,6 +98,20 @@ class Trainer:
) = self.accelerator.prepare(
self.train_loader, self.raw_model, self.optim, self.scheduler
)
self.optim.zero_grad()
self.global_step = 0
try:
self.accelerator.load_state(self.work_dir)
self.global_step = self.scheduler.scheduler._step_count - 1
self.global_step = self.global_step // self.accelerator.num_processes
logging.warn("Restored ckpt from {}".format(self.work_dir))
except:
logging.warn("No ckpt found in {}".format(self.work_dir))
if self.global_step > 0:
skip_steps = self.global_step * self.gradient_accumulation_steps
self.train_loader = self.accelerator.skip_first_batches(
self.train_loader, num_batches=skip_steps
)
def train_step(self, batch):
out = self.model(**batch)
@ -109,41 +127,36 @@ class Trainer:
self.get_optimizer()
self.get_lr_scheduler()
self.prepare()
self.global_step = 0
self.optim.zero_grad()
self.start_time = time.time()
for self.data_step, batch in enumerate(self.train_loader):
for k, v in batch.items():
batch[k] = v.to(self.accelerator.device, non_blocking=True)
# end training
if self.data_step >= self.config["train"]["num_training_steps"]:
break
# data to device
for k, v in batch.items():
batch[k] = v.to(self.accelerator.device, non_blocking=True)
self.model.train()
# train step
with self.accelerator.accumulate(self.model):
losses = self.train_step(batch)
if self.accelerator.sync_gradients:
self.global_step += 1
# log
if (
self.data_step % self.log_interval == 0
and self.data_step > 0
and self.accelerator.is_main_process
):
self.log(losses)
# eval/vis model output
if (
self.data_step % self.eval_interval == 0
and self.accelerator.is_main_process
):
self.eval()
if (
self.data_step % self.save_interval == 0
and self.data_step > 0
and self.accelerator.is_main_process
):
if not os.path.isdir(self.work_dir):
os.mkdir(self.work_dir)
torch.save(
self.raw_model.state_dict(),
"{}/{}.pt".format(self.work_dir, self.global_step),
)
# save state
if self.data_step % self.save_interval == 0 and self.data_step > 0:
self.accelerator.save_state(self.work_dir)
wandb.finish()
def log(self, losses):

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-27 23:08:47
LastEditTime: 2023-04-29 19:38:47
FilePath: /Open-Llama/train_lm.py
Description:
@ -10,6 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import yaml
import torch
import logging
from absl import app
from absl import flags
from accelerate import Accelerator
@ -72,7 +73,7 @@ def main(argv):
if config["train"]["ckpt"] is not None:
ckpt = torch.load(config["train"]["ckpt"])
raw_model.load_state_dict(ckpt)
print('Loaded ckpt from: {}'.format(config["train"]["ckpt"]))
logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
trainer.train()