support multiple epochs

LiangSong 2023-05-03 00:02:01 +08:00
parent f05e929aad
commit c2184c6dd1

solver/trainer.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 21:59:51
+LastEditTime: 2023-05-02 23:55:37
 FilePath: /Open-Llama/solver/trainer.py
 Description:
@@ -109,9 +109,12 @@ class Trainer:
             logging.warn("No ckpt found in {}".format(self.work_dir))
         if self.global_step > 0:
             skip_steps = self.global_step * self.gradient_accumulation_steps
-            self.train_loader = self.accelerator.skip_first_batches(
+            logging.warn("Skiped {} steps.".format(skip_steps))
+            self.train_loader_skiped = self.accelerator.skip_first_batches(
                 self.train_loader, num_batches=skip_steps
             )
+        else:
+            self.train_loader_skiped = self.train_loader
         self.accelerator.wait_for_everyone()
 
     def train_step(self, batch):
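The hunk above keeps the mid-epoch resume behaviour but stores the skipping dataloader under a separate attribute, so the full `self.train_loader` remains available for later epochs. Below is a minimal, self-contained sketch of the same pattern around `accelerate`'s `skip_first_batches`; the toy dataset, the `global_step` value, and the local variable names (e.g. `train_loader_skipped`) are stand-ins for illustration, not the repository's code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical toy data standing in for the real pretraining dataloader.
dataset = TensorDataset(torch.arange(64).float().unsqueeze(1))
train_loader = accelerator.prepare(DataLoader(dataset, batch_size=4))

# Values that would normally come from the restored checkpoint / config.
global_step = 3
gradient_accumulation_steps = 2

if global_step > 0:
    # Optimizer steps * accumulation steps = raw batches already consumed.
    skip_steps = global_step * gradient_accumulation_steps
    # Loader that resumes the interrupted epoch where it left off.
    train_loader_skipped = accelerator.skip_first_batches(
        train_loader, num_batches=skip_steps
    )
else:
    train_loader_skipped = train_loader

# Epoch 0 iterates train_loader_skipped; later epochs reuse train_loader.
```

Keeping both loaders is what lets the reworked `train()` loop in the next hunk fall back to the full dataloader from the second epoch onwards.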
@@ -129,35 +132,45 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.start_time = time.time()
-        for self.data_step, batch in enumerate(self.train_loader):
-            # end training
+        self.epoch = 0
+        self.data_step = 0
+        while True:
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
-            # data to device
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
-            self.model.train()
-            # train step
-            with self.accelerator.accumulate(self.model):
-                losses = self.train_step(batch)
-                if self.accelerator.sync_gradients:
-                    self.global_step += 1
-            # log
-            if (
-                self.data_step % self.log_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                self.log(losses)
-            # eval/vis model output
-            if (
-                self.data_step % self.eval_interval == 0
-                and self.accelerator.is_main_process
-            ):
-                self.eval()
-            # save state
-            if self.data_step % self.save_interval == 0 and self.data_step > 0:
-                self.accelerator.save_state(self.work_dir)
+            if self.epoch == 0:
+                train_loader = self.train_loader_skiped
+            else:
+                train_loader = self.train_loader
+            for batch in train_loader:
+                # end training
+                if self.data_step >= self.config["train"]["num_training_steps"]:
+                    break
+                # data to device
+                for k, v in batch.items():
+                    batch[k] = v.to(self.accelerator.device, non_blocking=True)
+                self.model.train()
+                # train step
+                with self.accelerator.accumulate(self.model):
+                    losses = self.train_step(batch)
+                    if self.accelerator.sync_gradients:
+                        self.global_step += 1
+                # log
+                if (
+                    self.data_step % self.log_interval == 0
+                    and self.data_step > 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.log(losses)
+                # eval/vis model output
+                if (
+                    self.data_step % self.eval_interval == 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.eval()
+                # save state
+                if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                    self.accelerator.save_state(self.work_dir)
+                self.data_step += 1
         wandb.finish()
 
     def log(self, losses):
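Condensed sketch of the loop shape the reworked `train()` follows: manual `data_step` bookkeeping (the old `enumerate` is gone), a `while` loop over epochs bounded only by `num_training_steps`, and the skipped loader used for the first epoch after a resume. The free-standing names (`train_loader`, `train_loader_skipped`, `num_training_steps`, `run_step`) are stand-ins for the Trainer's members, and the trailing `epoch += 1` is part of the general multi-epoch pattern rather than a line shown in the hunk above.

```python
from typing import Callable, Iterable


def train_loop(
    train_loader: Iterable,          # full dataloader, reused from epoch 1 on
    train_loader_skipped: Iterable,  # dataloader with resumed batches skipped
    num_training_steps: int,
    run_step: Callable,              # stand-in for the per-batch train step
) -> None:
    epoch = 0
    data_step = 0
    while True:
        # Global stop condition: total data steps across epochs, not epoch count.
        if data_step >= num_training_steps:
            break
        # Only the first (possibly resumed) epoch skips already-seen batches.
        loader = train_loader_skipped if epoch == 0 else train_loader
        for batch in loader:
            if data_step >= num_training_steps:
                break
            run_step(batch)
            data_step += 1  # counted manually now that enumerate() is removed
        epoch += 1          # general pattern; hedged in the lead-in above
```

With this shape, `num_training_steps` can span several passes over the dataset, which is what the commit title "support multiple epochs" refers to.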