diff --git a/chat_server.py b/chat_server.py
index 213a8ea..411760e 100644
--- a/chat_server.py
+++ b/chat_server.py
@@ -2,13 +2,14 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:58
+LastEditTime: 2023-04-29 19:38:54
 FilePath: /Open-Llama/chat_server.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
+import logging
 import gradio as gr
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
@@ -39,7 +40,7 @@ ckpt = torch.load(
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
-print("ready")
+logging.warn("ready")
 
 
 def parse_codeblock(text):
@@ -70,7 +71,7 @@ with gr.Blocks() as demo:
     clear = gr.Button("Clear")
 
     def user(user_message, history):
-        print(user_message)
+        logging.warn(user_message)
         return "", history + [[user_message, None]]
 
     def bot(history):
@@ -103,7 +104,7 @@ with gr.Blocks() as demo:
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
         pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
-        print(pred)
+        logging.warn(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
         return history
diff --git a/solver/trainer.py b/solver/trainer.py
index a84f29f..08f0d8b 100644
--- a/solver/trainer.py
+++ b/solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 20:34:47
+LastEditTime: 2023-04-29 20:26:34
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -12,6 +12,7 @@ import os
 import time
 import wandb
 import torch
+import logging
 from torchinfo import summary
 from deepspeed.ops.adam import FusedAdam
 from transformers import get_cosine_schedule_with_warmup
@@ -26,6 +27,9 @@ class Trainer:
         self.train_loader = train_loader
         self.tokenizer = tokenizer
         self.accelerator = accelerator
+        self.gradient_accumulation_steps = config["train"].get(
+            "gradient_accumulation_steps", 1
+        )
         self.lr_scheduler_factor = (
             accelerator.num_processes / accelerator.gradient_accumulation_steps
         )
@@ -94,6 +98,20 @@ class Trainer:
         ) = self.accelerator.prepare(
             self.train_loader, self.raw_model, self.optim, self.scheduler
         )
+        self.optim.zero_grad()
+        self.global_step = 0
+        try:
+            self.accelerator.load_state(self.work_dir)
+            self.global_step = self.scheduler.scheduler._step_count - 1
+            self.global_step = self.global_step // self.accelerator.num_processes
+            logging.warn("Restored ckpt from {}".format(self.work_dir))
+        except:
+            logging.warn("No ckpt found in {}".format(self.work_dir))
+        if self.global_step > 0:
+            skip_steps = self.global_step * self.gradient_accumulation_steps
+            self.train_loader = self.accelerator.skip_first_batches(
+                self.train_loader, num_batches=skip_steps
+            )
 
     def train_step(self, batch):
         out = self.model(**batch)
@@ -109,41 +127,36 @@
         self.get_optimizer()
         self.get_lr_scheduler()
         self.prepare()
-        self.global_step = 0
-        self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
+            # end training
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
+            # data to device
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             self.model.train()
+            # train step
             with self.accelerator.accumulate(self.model):
                 losses = self.train_step(batch)
                 if self.accelerator.sync_gradients:
                     self.global_step += 1
+            # log
             if (
                 self.data_step % self.log_interval == 0
                 and self.data_step > 0
                 and self.accelerator.is_main_process
             ):
                 self.log(losses)
+            # eval/vis model output
             if (
                 self.data_step % self.eval_interval == 0
                 and self.accelerator.is_main_process
             ):
                 self.eval()
-            if (
-                self.data_step % self.save_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                if not os.path.isdir(self.work_dir):
-                    os.mkdir(self.work_dir)
-                torch.save(
-                    self.raw_model.state_dict(),
-                    "{}/{}.pt".format(self.work_dir, self.global_step),
-                )
+            # save state
+            if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                self.accelerator.save_state(self.work_dir)
         wandb.finish()
 
     def log(self, losses):
diff --git a/train_lm.py b/train_lm.py
index 99c5e55..31ba9bf 100644
--- a/train_lm.py
+++ b/train_lm.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-27 23:08:47
+LastEditTime: 2023-04-29 19:38:47
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -10,6 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import yaml
 import torch
+import logging
 from absl import app
 from absl import flags
 from accelerate import Accelerator
@@ -72,7 +73,7 @@ def main(argv):
     if config["train"]["ckpt"] is not None:
         ckpt = torch.load(config["train"]["ckpt"])
         raw_model.load_state_dict(ckpt)
-        print('Loaded ckpt from: {}'.format(config["train"]["ckpt"]))
+        logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
     trainer.train()
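
Note on the checkpointing change in solver/trainer.py: the old code saved raw model weights with torch.save on the main process only; the new code calls accelerator.save_state, which checkpoints model, optimizer, LR scheduler, and RNG state together (and must run on all processes, hence the dropped is_main_process guard). At startup, accelerator.load_state restores that state and skip_first_batches fast-forwards the dataloader so a resumed run does not revisit batches it already trained on. Below is a minimal standalone sketch of the same pattern, assuming Accelerate >= 0.17 (where Accelerator.skip_first_batches exists); the toy model, data, and work_dir path are illustrative stand-ins, and an explicit directory check replaces the diff's bare `except:` (which would also swallow KeyboardInterrupt):

    import os
    import logging
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)
    model = torch.nn.Linear(8, 2)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lambda step: 1.0)
    loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=4)
    loader, model, optim, scheduler = accelerator.prepare(
        loader, model, optim, scheduler
    )

    work_dir = "data/saved_ckpt"  # illustrative path
    global_step = 0
    if os.path.isdir(work_dir):  # explicit check instead of try/bare-except
        accelerator.load_state(work_dir)
        # Recover the optimizer-step count the way the diff does: the prepared
        # scheduler advances its internal counter num_processes times per
        # optimizer step, so divide the (private) _step_count back down.
        global_step = (
            scheduler.scheduler._step_count - 1
        ) // accelerator.num_processes
        logging.warning("Restored ckpt from %s", work_dir)

    if global_step > 0:
        # Each optimizer step consumed gradient_accumulation_steps batches,
        # so skip that many micro-batches in the resumed dataloader.
        loader = accelerator.skip_first_batches(
            loader,
            num_batches=global_step * accelerator.gradient_accumulation_steps,
        )

    # Later, every save_interval steps, on all processes:
    accelerator.save_state(work_dir)  # model/optim/scheduler/RNG in one call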
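Relatedly, global_step now counts optimizer steps rather than data steps: inside accelerator.accumulate(self.model), gradients only synchronize (and accelerator.sync_gradients is only True) on every gradient_accumulation_steps-th micro-batch, so the counter advances once per effective step. Continuing the sketch above, the loop body would look roughly like this; the loss and backward/step calls are a guess at what Trainer.train_step does (that code is outside the hunk), while the context manager and sync_gradients counting appear verbatim in the diff:

    for batch in loader:
        with accelerator.accumulate(model):
            (x,) = batch
            loss = model(x).sum()  # stand-in loss; the Trainer uses the LM loss
            accelerator.backward(loss)  # handles scaling across micro-batches
            optim.step()
            scheduler.step()
            optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1  # counts optimizer steps, not micro-batches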
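Finally, a nit on the print-to-logging swap: logging.warn is a deprecated alias of logging.warning (it emits a DeprecationWarning), and messages like "ready" are status output rather than warnings. WARNING was presumably chosen because the unconfigured root logger only shows WARNING and above; configuring the logger once at startup makes INFO the natural level. An illustrative setup (the ckpt_path value is a placeholder):

    import logging

    # Configure the root logger once so INFO-level messages are emitted.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    logger = logging.getLogger(__name__)

    logger.info("ready")
    ckpt_path = "data/some_ckpt.pt"  # placeholder
    # Lazy %-formatting defers string building until the record is emitted.
    logger.info("Loaded ckpt from: %s", ckpt_path)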