""" Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-24 20:05:21 LastEditors: LiangSong(sl12160010@gmail.com) LastEditTime: 2023-05-04 08:41:37 FilePath: /Open-Llama/solver/trainer.py Description: Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. """ import time import wandb import torch import logging from torchinfo import summary from deepspeed.ops.adam import FusedAdam from transformers import get_cosine_schedule_with_warmup from dataset.validation import val_set class Trainer: def __init__(self, config, raw_model, train_loader, tokenizer, accelerator): self.config = config self.raw_model = raw_model self.train_loader = train_loader self.tokenizer = tokenizer self.accelerator = accelerator self.gradient_accumulation_steps = config["train"].get( "gradient_accumulation_steps", 1 ) self.lr_scheduler_factor = ( accelerator.num_processes / accelerator.gradient_accumulation_steps ) self.log_interval = ( self.config["log_interval"] * accelerator.gradient_accumulation_steps ) self.eval_interval = ( self.config["eval_interval"] * accelerator.gradient_accumulation_steps ) self.save_interval = ( self.config["save_interval"] * accelerator.gradient_accumulation_steps ) self.work_dir = self.config["work_dir"] self.get_model_info() if accelerator.is_main_process: wandb.init(project=self.config["project_name"]) def get_model_info(self): with torch.no_grad(): summary( self.raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda(), ) def get_optimizer(self): no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in self.raw_model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": self.config["train"]["weight_decay"], }, { "params": [ p for n, p in self.raw_model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] self.optim = FusedAdam( optimizer_grouped_parameters, lr=self.config["train"]["lr"], betas=(0.9, 0.95), ) def get_lr_scheduler(self): self.scheduler = get_cosine_schedule_with_warmup( self.optim, num_warmup_steps=self.config["train"]["num_warmup_steps"] * self.lr_scheduler_factor, num_training_steps=self.config["train"]["num_training_steps"] * self.lr_scheduler_factor, ) def prepare(self): ( _, self.model, self.optim, self.scheduler, ) = self.accelerator.prepare( self.train_loader, self.raw_model, self.optim, self.scheduler ) self.optim.zero_grad() self.global_step = 0 try: self.accelerator.load_state(self.work_dir) self.global_step = self.scheduler.scheduler._step_count - 1 self.global_step = self.global_step // self.accelerator.num_processes logging.warn("Restored ckpt from {}".format(self.work_dir)) except: logging.warn("No ckpt found in {}".format(self.work_dir)) if self.global_step > 0: skip_steps = self.global_step * self.gradient_accumulation_steps logging.warn("Skiped {} steps.".format(skip_steps)) self.train_loader_skiped = self.accelerator.skip_first_batches( self.train_loader, num_batches=skip_steps ) else: self.train_loader_skiped = self.train_loader self.accelerator.wait_for_everyone() def train_step(self, batch): out = self.model(**batch) total_loss = out.loss losses = {"total_loss": total_loss} self.accelerator.backward(total_loss) self.optim.step() self.scheduler.step() self.optim.zero_grad() return losses def train(self): self.get_optimizer() self.get_lr_scheduler() self.prepare() self.start_time = time.time() self.epoch = 0 self.data_step = 0 while True: if self.data_step >= 
self.config["train"]["num_training_steps"]: break if self.epoch == 0: train_loader = self.train_loader_skiped else: train_loader = self.train_loader for batch in train_loader: # end training if self.data_step >= self.config["train"]["num_training_steps"]: break # data to device for k, v in batch.items(): batch[k] = v.to(self.accelerator.device, non_blocking=True) self.model.train() # train step with self.accelerator.accumulate(self.model): losses = self.train_step(batch) if self.accelerator.sync_gradients: self.global_step += 1 # log if ( self.data_step % self.log_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process ): self.log(losses) # eval/vis model output if ( self.data_step % self.eval_interval == 0 and self.accelerator.is_main_process ): self.eval() # save state if self.data_step % self.save_interval == 0 and self.data_step > 0: self.accelerator.save_state(self.work_dir) self.data_step += 1 wandb.finish() def log(self, losses): cost_time = time.time() - self.start_time self.start_time = time.time() tokens = ( self.config["train"]["train_batch_size"] * self.log_interval * self.config["data"]["seq_length"] ) wandb.log({"Training/Token per second per gpu": tokens / cost_time}) for k, v in losses.items(): wandb.log({"Losses/{}".format(k): v}) current_lr = self.optim.param_groups[0]["lr"] wandb.log({"Training/LR": current_lr}) if self.optim.scaler is not None: wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()}) wandb.log({"Training/Data Step": self.data_step}) wandb.log({"Training/Global Step": self.global_step}) self.accelerator.print( "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format( self.global_step, self.data_step, losses["total_loss"], tokens / cost_time, ) ) def eval(self): text_table = wandb.Table(columns=["question", "pred"]) self.model.eval() with torch.no_grad(): for data in val_set: raw_inputs = data inputs = self.tokenizer( raw_inputs, return_tensors="pt", add_special_tokens=False, return_attention_mask=False, ) input_length = inputs["input_ids"].shape[1] for k, v in inputs.items(): inputs[k] = v.to(self.accelerator.device) pred = self.model.generate( **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0 ) pred = pred[0, input_length:] pred = self.tokenizer.decode(pred.cpu(), skip_special_tokens=True) text_table.add_data(raw_inputs, pred) wandb.log({"Predictions on {}".format(self.global_step): text_table})