diff --git a/solver/trainer.py b/solver/trainer.py
index 025bfd6..4ed68c9 100644
--- a/solver/trainer.py
+++ b/solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 21:59:51
+LastEditTime: 2023-05-02 23:55:37
 FilePath: /Open-Llama/solver/trainer.py
 Description:

@@ -109,9 +109,12 @@ class Trainer:
             logging.warn("No ckpt found in {}".format(self.work_dir))
         if self.global_step > 0:
             skip_steps = self.global_step * self.gradient_accumulation_steps
-            self.train_loader = self.accelerator.skip_first_batches(
+            logging.warn("Skiped {} steps.".format(skip_steps))
+            self.train_loader_skiped = self.accelerator.skip_first_batches(
                 self.train_loader, num_batches=skip_steps
             )
+        else:
+            self.train_loader_skiped = self.train_loader
         self.accelerator.wait_for_everyone()

     def train_step(self, batch):
@@ -129,35 +132,45 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.start_time = time.time()
-        for self.data_step, batch in enumerate(self.train_loader):
-            # end training
+        self.epoch = 0
+        self.data_step = 0
+        while True:
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
-            # data to device
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
-            self.model.train()
-            # train step
-            with self.accelerator.accumulate(self.model):
-                losses = self.train_step(batch)
-                if self.accelerator.sync_gradients:
-                    self.global_step += 1
-            # log
-            if (
-                self.data_step % self.log_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                self.log(losses)
-            # eval/vis model output
-            if (
-                self.data_step % self.eval_interval == 0
-                and self.accelerator.is_main_process
-            ):
-                self.eval()
-            # save state
-            if self.data_step % self.save_interval == 0 and self.data_step > 0:
-                self.accelerator.save_state(self.work_dir)
+            if self.epoch == 0:
+                train_loader = self.train_loader_skiped
+            else:
+                train_loader = self.train_loader
+            for batch in train_loader:
+                # end training
+                if self.data_step >= self.config["train"]["num_training_steps"]:
+                    break
+                # data to device
+                for k, v in batch.items():
+                    batch[k] = v.to(self.accelerator.device, non_blocking=True)
+                self.model.train()
+                # train step
+                with self.accelerator.accumulate(self.model):
+                    losses = self.train_step(batch)
+                    if self.accelerator.sync_gradients:
+                        self.global_step += 1
+                # log
+                if (
+                    self.data_step % self.log_interval == 0
+                    and self.data_step > 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.log(losses)
+                # eval/vis model output
+                if (
+                    self.data_step % self.eval_interval == 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.eval()
+                # save state
+                if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                    self.accelerator.save_state(self.work_dir)
+                self.data_step += 1
         wandb.finish()

     def log(self, losses):
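
For reference, the resume pattern this patch introduces (skip the batches already consumed before the restart on the first epoch, then fall back to the full loader on later epochs) can be reproduced in isolation. The sketch below is a minimal illustration assuming the HuggingFace Accelerate API; the toy model, dataset, and step counts are hypothetical placeholders, and only the skip_first_batches call mirrors the patch.

```python
# Minimal sketch of the resume logic in this patch: after restoring a
# checkpoint, epoch 0 consumes a loader with the already-seen batches
# skipped; later epochs iterate the full loader again.
# The model, dataset, and counts below are hypothetical placeholders.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=4)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

global_step = 3                 # pretend this was restored from a checkpoint
gradient_accumulation_steps = 2
skip_steps = global_step * gradient_accumulation_steps

if global_step > 0:
    # Same call as in the patch: drop the batches consumed before the restart.
    train_loader_skipped = accelerator.skip_first_batches(
        train_loader, num_batches=skip_steps
    )
else:
    train_loader_skipped = train_loader

for epoch in range(2):
    loader = train_loader_skipped if epoch == 0 else train_loader
    for x, y in loader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

The epoch check matters because only the first pass after a restart is partially consumed; reusing the skipped loader on later epochs would silently drop data every epoch, which is what the train_loader_skiped / train_loader split in the patch avoids.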