add continue training

This commit is contained in:
parent 28b11a5bed
commit fc21a75d1e
--- a/chat_server.py
+++ b/chat_server.py
@@ -2,13 +2,14 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:58
+LastEditTime: 2023-04-29 19:38:54
 FilePath: /Open-Llama/chat_server.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
+import logging
 import gradio as gr
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
@@ -39,7 +40,7 @@ ckpt = torch.load(
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
-print("ready")
+logging.warn("ready")
 
 
 def parse_codeblock(text):
@@ -70,7 +71,7 @@ with gr.Blocks() as demo:
     clear = gr.Button("Clear")
 
     def user(user_message, history):
-        print(user_message)
+        logging.warn(user_message)
         return "", history + [[user_message, None]]
 
     def bot(history):
@@ -103,7 +104,7 @@ with gr.Blocks() as demo:
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
         pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
-        print(pred)
+        logging.warn(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
         return history
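The chat_server.py changes above only replace print calls with logging.warn. As a side note, logging.warn is a deprecated alias for logging.warning; a minimal sketch of the same idea with the current API (the format string and level here are illustrative assumptions, not taken from the repository):

# Minimal sketch: route chat-server status messages through the logging module
# instead of print. Format and level are illustrative assumptions.
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)

logging.warning("ready")                   # non-deprecated spelling of logging.warn
logging.info("user message: %s", "hello")  # lazy %-formatting skips work when the level is off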
--- a/solver/trainer.py
+++ b/solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:47
+LastEditTime: 2023-04-29 20:26:34
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -12,6 +12,7 @@ import os
 import time
 import wandb
 import torch
+import logging
 from torchinfo import summary
 from deepspeed.ops.adam import FusedAdam
 from transformers import get_cosine_schedule_with_warmup
@@ -26,6 +27,9 @@ class Trainer:
         self.train_loader = train_loader
         self.tokenizer = tokenizer
         self.accelerator = accelerator
+        self.gradient_accumulation_steps = config["train"].get(
+            "gradient_accumulation_steps", 1
+        )
         self.lr_scheduler_factor = (
             accelerator.num_processes / accelerator.gradient_accumulation_steps
         )
@@ -94,6 +98,20 @@ class Trainer:
         ) = self.accelerator.prepare(
             self.train_loader, self.raw_model, self.optim, self.scheduler
         )
+        self.optim.zero_grad()
+        self.global_step = 0
+        try:
+            self.accelerator.load_state(self.work_dir)
+            self.global_step = self.scheduler.scheduler._step_count - 1
+            self.global_step = self.global_step // self.accelerator.num_processes
+            logging.warn("Restored ckpt from {}".format(self.work_dir))
+        except:
+            logging.warn("No ckpt found in {}".format(self.work_dir))
+        if self.global_step > 0:
+            skip_steps = self.global_step * self.gradient_accumulation_steps
+            self.train_loader = self.accelerator.skip_first_batches(
+                self.train_loader, num_batches=skip_steps
+            )
 
     def train_step(self, batch):
         out = self.model(**batch)
@@ -109,41 +127,36 @@ class Trainer:
         self.get_optimizer()
         self.get_lr_scheduler()
         self.prepare()
-        self.global_step = 0
-        self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
+            # end training
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
+            # data to device
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             self.model.train()
+            # train step
            with self.accelerator.accumulate(self.model):
                 losses = self.train_step(batch)
                 if self.accelerator.sync_gradients:
                     self.global_step += 1
+            # log
             if (
                 self.data_step % self.log_interval == 0
                 and self.data_step > 0
                 and self.accelerator.is_main_process
             ):
                 self.log(losses)
+            # eval/vis model output
             if (
                 self.data_step % self.eval_interval == 0
                 and self.accelerator.is_main_process
             ):
                 self.eval()
-            if (
-                self.data_step % self.save_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                if not os.path.isdir(self.work_dir):
-                    os.mkdir(self.work_dir)
-                torch.save(
-                    self.raw_model.state_dict(),
-                    "{}/{}.pt".format(self.work_dir, self.global_step),
-                )
+            # save state
+            if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                self.accelerator.save_state(self.work_dir)
         wandb.finish()
 
     def log(self, losses):
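The hunks above implement the continue-training feature: Accelerate's save_state periodically writes model, optimizer, scheduler, and RNG state to work_dir, load_state restores it on startup, and skip_first_batches fast-forwards the dataloader past batches already consumed. Below is a minimal, self-contained sketch of that pattern, not the project's exact code; the function name, checkpoint directory, and hyperparameters are illustrative, and the model is assumed to be a Hugging Face-style module that returns a .loss.

# Hedged sketch of resumable training with Hugging Face Accelerate, mirroring
# the bookkeeping this commit adds in solver/trainer.py. Names are illustrative.
import logging
from accelerate import Accelerator


def train_resumable(model, optimizer, scheduler, train_loader,
                    ckpt_dir="work_dir", save_interval=1000, max_steps=100000):
    accelerator = Accelerator(gradient_accumulation_steps=4)
    train_loader, model, optimizer, scheduler = accelerator.prepare(
        train_loader, model, optimizer, scheduler
    )

    global_step = 0
    try:
        # Restores model/optimizer/scheduler/RNG state written by save_state.
        accelerator.load_state(ckpt_dir)
        # Recover completed optimizer steps from the wrapped scheduler, as the
        # commit does via scheduler.scheduler._step_count.
        global_step = (scheduler.scheduler._step_count - 1) // accelerator.num_processes
        logging.warning("Restored state from %s at step %d", ckpt_dir, global_step)
    except Exception:
        logging.warning("No checkpoint found in %s, starting fresh", ckpt_dir)

    if global_step > 0:
        # Fast-forward the dataloader past batches consumed before the restart.
        skip = global_step * accelerator.gradient_accumulation_steps
        train_loader = accelerator.skip_first_batches(train_loader, num_batches=skip)

    for step, batch in enumerate(train_loader):
        if step >= max_steps:
            break
        model.train()
        with accelerator.accumulate(model):
            loss = model(**batch).loss  # assumes a HF-style output with .loss
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
        if step % save_interval == 0 and step > 0:
            # save_state runs on every rank; Accelerate/DeepSpeed coordinate what
            # each process writes, which is presumably why the commit drops the
            # is_main_process guard from the save branch.
            accelerator.save_state(ckpt_dir)

With this in place, re-launching the same training command should pick up from the last state saved in the work directory rather than starting over.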
--- a/train_lm.py
+++ b/train_lm.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 23:08:47
+LastEditTime: 2023-04-29 19:38:47
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -10,6 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import yaml
 import torch
+import logging
 from absl import app
 from absl import flags
 from accelerate import Accelerator
@@ -72,7 +73,7 @@ def main(argv):
     if config["train"]["ckpt"] is not None:
         ckpt = torch.load(config["train"]["ckpt"])
         raw_model.load_state_dict(ckpt)
-        print('Loaded ckpt from: {}'.format(config["train"]["ckpt"]))
+        logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
     trainer.train()
 