add continue training
parent 28b11a5bed
commit fc21a75d1e
chat_server.py
@@ -2,13 +2,14 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:58
+LastEditTime: 2023-04-29 19:38:54
 FilePath: /Open-Llama/chat_server.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
+import logging
 import gradio as gr
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
@@ -39,7 +40,7 @@ ckpt = torch.load(
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
-print("ready")
+logging.warn("ready")
 
 
 def parse_codeblock(text):
@@ -70,7 +71,7 @@ with gr.Blocks() as demo:
     clear = gr.Button("Clear")
 
     def user(user_message, history):
-        print(user_message)
+        logging.warn(user_message)
         return "", history + [[user_message, None]]
 
     def bot(history):
@@ -103,7 +104,7 @@ with gr.Blocks() as demo:
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
         pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
-        print(pred)
+        logging.warn(pred)
        bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
         return history
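The chat server now routes its startup and request traces through the logging module instead of print. How these logging.warn calls surface depends on the process's logging configuration, which this diff does not touch; a minimal sketch of an assumed setup (not part of the repository):

    import logging

    # WARNING is the root logger's default threshold, so the logging.warn(...) calls
    # above are emitted to stderr even with no configuration at all; basicConfig only
    # adds a timestamped format and lets lower levels through as well. Note that
    # logging.warn is a deprecated alias for logging.warning.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )
    logging.warning("ready")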
solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:47
+LastEditTime: 2023-04-29 20:26:34
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -12,6 +12,7 @@ import os
 import time
 import wandb
 import torch
+import logging
 from torchinfo import summary
 from deepspeed.ops.adam import FusedAdam
 from transformers import get_cosine_schedule_with_warmup
@@ -26,6 +27,9 @@ class Trainer:
         self.train_loader = train_loader
         self.tokenizer = tokenizer
         self.accelerator = accelerator
+        self.gradient_accumulation_steps = config["train"].get(
+            "gradient_accumulation_steps", 1
+        )
         self.lr_scheduler_factor = (
             accelerator.num_processes / accelerator.gradient_accumulation_steps
         )
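Keeping gradient_accumulation_steps on the Trainer matters for resumption: global_step counts optimizer updates, while the dataloader yields micro-batches, and converting one into the other needs the accumulation factor. A tiny illustration with assumed numbers (not repository code):

    # Hypothetical values, only to show the unit conversion used below in prepare().
    gradient_accumulation_steps = 4   # micro-batches folded into one optimizer update
    global_step = 1000                # optimizer updates completed before the restart

    # Dataloader batches already consumed, i.e. how far to fast-forward on resume.
    skip_steps = global_step * gradient_accumulation_steps
    assert skip_steps == 4000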
@@ -94,6 +98,20 @@ class Trainer:
         ) = self.accelerator.prepare(
             self.train_loader, self.raw_model, self.optim, self.scheduler
         )
+        self.optim.zero_grad()
+        self.global_step = 0
+        try:
+            self.accelerator.load_state(self.work_dir)
+            self.global_step = self.scheduler.scheduler._step_count - 1
+            self.global_step = self.global_step // self.accelerator.num_processes
+            logging.warn("Restored ckpt from {}".format(self.work_dir))
+        except:
+            logging.warn("No ckpt found in {}".format(self.work_dir))
+        if self.global_step > 0:
+            skip_steps = self.global_step * self.gradient_accumulation_steps
+            self.train_loader = self.accelerator.skip_first_batches(
+                self.train_loader, num_batches=skip_steps
+            )
 
     def train_step(self, batch):
         out = self.model(**batch)
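The block added to prepare is the resume path itself: reload whatever accelerator.save_state last wrote to the work directory, recover the number of completed optimizer updates from the LR scheduler's internal step counter, and fast-forward the dataloader by the corresponding number of micro-batches. A self-contained sketch of the same pattern with a toy model (assumed names and sizes, not the repository's code):

    import logging

    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import get_cosine_schedule_with_warmup

    accumulation = 4  # assumed gradient_accumulation_steps
    accelerator = Accelerator(gradient_accumulation_steps=accumulation)

    model = torch.nn.Linear(16, 1)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=10, num_training_steps=1000)
    loader = DataLoader(TensorDataset(torch.randn(512, 16), torch.randn(512, 1)), batch_size=8)

    loader, model, optim, scheduler = accelerator.prepare(loader, model, optim, scheduler)

    work_dir = "ckpt_demo"  # hypothetical checkpoint directory
    global_step = 0
    try:
        # Restores model, optimizer, scheduler and RNG states written by save_state().
        accelerator.load_state(work_dir)
        # The wrapped scheduler's _step_count advances once per process per update,
        # so dividing by num_processes recovers the number of optimizer updates.
        global_step = (scheduler.scheduler._step_count - 1) // accelerator.num_processes
        logging.warning("Restored ckpt from {}".format(work_dir))
    except Exception:
        logging.warning("No ckpt found in {}".format(work_dir))

    if global_step > 0:
        # Each optimizer update consumed `accumulation` micro-batches; skip them all.
        loader = accelerator.skip_first_batches(loader, num_batches=global_step * accumulation)

    for step, (x, y) in enumerate(loader):
        model.train()
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            optim.step()
            scheduler.step()
            optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
            if global_step % 8 == 0:
                accelerator.save_state(work_dir)  # resumable snapshot, overwritten in place

Note that fast-forwarding only lines up with the interrupted run if the dataloader order is reproducible across restarts; skip_first_batches merely drops the first num_batches batches of whatever order it is given.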
@@ -109,41 +127,36 @@
         self.get_optimizer()
         self.get_lr_scheduler()
         self.prepare()
-        self.global_step = 0
-        self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
+            # end training
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
+            # data to device
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             self.model.train()
+            # train step
             with self.accelerator.accumulate(self.model):
                 losses = self.train_step(batch)
                 if self.accelerator.sync_gradients:
                     self.global_step += 1
+            # log
             if (
                 self.data_step % self.log_interval == 0
                 and self.data_step > 0
                 and self.accelerator.is_main_process
             ):
                 self.log(losses)
+            # eval/vis model output
             if (
                 self.data_step % self.eval_interval == 0
                 and self.accelerator.is_main_process
             ):
                 self.eval()
-            if (
-                self.data_step % self.save_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                if not os.path.isdir(self.work_dir):
-                    os.mkdir(self.work_dir)
-                torch.save(
-                    self.raw_model.state_dict(),
-                    "{}/{}.pt".format(self.work_dir, self.global_step),
-                )
+            # save state
+            if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                self.accelerator.save_state(self.work_dir)
         wandb.finish()
 
     def log(self, losses):
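In the loop above, the per-interval torch.save of raw_model.state_dict() gives way to accelerator.save_state, which writes optimizer, scheduler and RNG state alongside the weights so that prepare's load_state can continue training exactly rather than merely reload parameters. A short sketch of the difference (hypothetical paths, not repository code):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(4, 4))

    # Weights-only snapshot: enough to reload the model for inference or a fresh run,
    # but optimizer/scheduler/RNG state is lost, so training cannot resume in place.
    torch.save(accelerator.unwrap_model(model).state_dict(), "model_only.pt")

    # Full training state: a directory that accelerator.load_state() can restore,
    # which is what the new periodic save in the training loop relies on.
    accelerator.save_state("full_state_dir")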
train_lm.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 23:08:47
+LastEditTime: 2023-04-29 19:38:47
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -10,6 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import yaml
 import torch
+import logging
 from absl import app
 from absl import flags
 from accelerate import Accelerator
@@ -72,7 +73,7 @@ def main(argv):
     if config["train"]["ckpt"] is not None:
         ckpt = torch.load(config["train"]["ckpt"])
         raw_model.load_state_dict(ckpt)
-        print('Loaded ckpt from: {}'.format(config["train"]["ckpt"]))
+        logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
     trainer.train()
 