add continue training

LiangSong 2023-04-29 20:28:39 +08:00
parent 28b11a5bed
commit fc21a75d1e
3 changed files with 37 additions and 22 deletions

chat_server.py

@@ -2,13 +2,14 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:58
+LastEditTime: 2023-04-29 19:38:54
 FilePath: /Open-Llama/chat_server.py
 Description:
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
+import logging
 import gradio as gr
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
@@ -39,7 +40,7 @@ ckpt = torch.load(
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
-print("ready")
+logging.warn("ready")
 
 
 def parse_codeblock(text):
@@ -70,7 +71,7 @@ with gr.Blocks() as demo:
     clear = gr.Button("Clear")
 
     def user(user_message, history):
-        print(user_message)
+        logging.warn(user_message)
         return "", history + [[user_message, None]]
 
     def bot(history):
@@ -103,7 +104,7 @@ with gr.Blocks() as demo:
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
         pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
-        print(pred)
+        logging.warn(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
         return history
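Note on the print-to-logging change above: logging.warn is a deprecated alias for logging.warning, and without any handler configured the root logger only emits WARNING-and-above in a bare format. A minimal sketch of an explicit setup (not part of this commit) looks like:

# Not from the commit: one possible logger configuration; logging.warning is
# the non-deprecated spelling of the logging.warn call used above.
import logging

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO
)
logging.warning("ready")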

solver/trainer.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:47
+LastEditTime: 2023-04-29 20:26:34
 FilePath: /Open-Llama/solver/trainer.py
 Description:
@@ -12,6 +12,7 @@ import os
 import time
 import wandb
 import torch
+import logging
 from torchinfo import summary
 from deepspeed.ops.adam import FusedAdam
 from transformers import get_cosine_schedule_with_warmup
@@ -26,6 +27,9 @@ class Trainer:
         self.train_loader = train_loader
         self.tokenizer = tokenizer
         self.accelerator = accelerator
+        self.gradient_accumulation_steps = config["train"].get(
+            "gradient_accumulation_steps", 1
+        )
         self.lr_scheduler_factor = (
             accelerator.num_processes / accelerator.gradient_accumulation_steps
         )
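The constructor change above gives the Trainer its own copy of gradient_accumulation_steps, defaulting to 1 when the config omits the key. A tiny worked example with placeholder numbers shows why it is kept: when training resumes, completed optimizer steps must be converted back into dataloader batches before skipping.

# Placeholder values, not the repository's actual config.
config = {"train": {"gradient_accumulation_steps": 4}}
gradient_accumulation_steps = config["train"].get("gradient_accumulation_steps", 1)

resumed_optimizer_steps = 250
skip_steps = resumed_optimizer_steps * gradient_accumulation_steps  # 250 * 4 = 1000 batches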
@@ -94,6 +98,20 @@ class Trainer:
         ) = self.accelerator.prepare(
             self.train_loader, self.raw_model, self.optim, self.scheduler
         )
+        self.optim.zero_grad()
+        self.global_step = 0
+        try:
+            self.accelerator.load_state(self.work_dir)
+            self.global_step = self.scheduler.scheduler._step_count - 1
+            self.global_step = self.global_step // self.accelerator.num_processes
+            logging.warn("Restored ckpt from {}".format(self.work_dir))
+        except:
+            logging.warn("No ckpt found in {}".format(self.work_dir))
+        if self.global_step > 0:
+            skip_steps = self.global_step * self.gradient_accumulation_steps
+            self.train_loader = self.accelerator.skip_first_batches(
+                self.train_loader, num_batches=skip_steps
+            )
 
     def train_step(self, batch):
         out = self.model(**batch)
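For readers unfamiliar with the Accelerate calls the new prepare() code relies on, here is a self-contained sketch of the same resume pattern. The model, optimizer, dataset, and the "ckpt_dir" directory are placeholders rather than Open-Llama code; scheduler.scheduler._step_count is the same private PyTorch attribute the commit reads.

# Sketch of resuming from an Accelerate checkpoint (placeholder model and data).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=2)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=4)

loader, model, optimizer, scheduler = accelerator.prepare(
    loader, model, optimizer, scheduler
)

work_dir = "ckpt_dir"  # placeholder checkpoint directory
global_step = 0
try:
    # restores model, optimizer, scheduler and RNG state written by save_state
    accelerator.load_state(work_dir)
    # the prepared scheduler steps once per process, hence the division
    global_step = (scheduler.scheduler._step_count - 1) // accelerator.num_processes
except Exception:
    pass  # no checkpoint yet: start from step 0

if global_step > 0:
    # each finished optimizer step consumed gradient_accumulation_steps batches
    loader = accelerator.skip_first_batches(
        loader, num_batches=global_step * accelerator.gradient_accumulation_steps
    )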
@@ -109,41 +127,36 @@ class Trainer:
         self.get_optimizer()
         self.get_lr_scheduler()
         self.prepare()
-        self.global_step = 0
-        self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
+            # end training
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
+            # data to device
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             self.model.train()
+            # train step
             with self.accelerator.accumulate(self.model):
                 losses = self.train_step(batch)
                 if self.accelerator.sync_gradients:
                     self.global_step += 1
+            # log
             if (
                 self.data_step % self.log_interval == 0
                 and self.data_step > 0
                 and self.accelerator.is_main_process
             ):
                 self.log(losses)
+            # eval/vis model output
             if (
                 self.data_step % self.eval_interval == 0
                 and self.accelerator.is_main_process
             ):
                 self.eval()
-            if (
-                self.data_step % self.save_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                if not os.path.isdir(self.work_dir):
-                    os.mkdir(self.work_dir)
-                torch.save(
-                    self.raw_model.state_dict(),
-                    "{}/{}.pt".format(self.work_dir, self.global_step),
-                )
+            # save state
+            if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                self.accelerator.save_state(self.work_dir)
         wandb.finish()
 
     def log(self, losses):
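Continuing the placeholder sketch above (loader, model, optimizer, scheduler, accelerator, work_dir and global_step are defined there), the reworked loop can be exercised as follows. save_state checkpoints the full training state and also writes per-process RNG state, which is presumably why the new branch drops the is_main_process guard the old torch.save path had; the interval values below are made up.

# Continuation of the earlier sketch; save_interval / num_training_steps are placeholders.
save_interval = 100
num_training_steps = 500

for data_step, (x, y) in enumerate(loader):
    if data_step >= num_training_steps:
        break
    model.train()
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()       # no-op on micro-batches where gradients are not synced
        scheduler.step()
        optimizer.zero_grad()
    if accelerator.sync_gradients:
        global_step += 1       # counts completed optimizer steps, as in the trainer
    if data_step % save_interval == 0 and data_step > 0:
        accelerator.save_state(work_dir)  # full state: model, optimizer, scheduler, RNG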

train_lm.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 23:08:47
+LastEditTime: 2023-04-29 19:38:47
 FilePath: /Open-Llama/train_lm.py
 Description:
@@ -10,6 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import yaml
 import torch
+import logging
 from absl import app
 from absl import flags
 from accelerate import Accelerator
@@ -72,7 +73,7 @@ def main(argv):
     if config["train"]["ckpt"] is not None:
         ckpt = torch.load(config["train"]["ckpt"])
         raw_model.load_state_dict(ckpt)
-        print('Loaded ckpt from: {}'.format(config["train"]["ckpt"]))
+        logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
     trainer.train()