support multiple epochs
parent f05e929aad
commit c2184c6dd1
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 21:59:51
+LastEditTime: 2023-05-02 23:55:37
 FilePath: /Open-Llama/solver/trainer.py
 Description:
@@ -109,9 +109,12 @@ class Trainer:
             logging.warn("No ckpt found in {}".format(self.work_dir))
         if self.global_step > 0:
             skip_steps = self.global_step * self.gradient_accumulation_steps
-            self.train_loader = self.accelerator.skip_first_batches(
+            logging.warn("Skiped {} steps.".format(skip_steps))
+            self.train_loader_skiped = self.accelerator.skip_first_batches(
                 self.train_loader, num_batches=skip_steps
             )
+        else:
+            self.train_loader_skiped = self.train_loader
         self.accelerator.wait_for_everyone()
 
     def train_step(self, batch):
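For context on the resume path changed above: the hunk relies on accelerate's Accelerator.skip_first_batches to fast-forward the dataloader past batches that were already consumed before the last checkpoint. Below is a minimal, self-contained sketch of the same idea, assuming a toy dataloader and a global_step already restored from a checkpoint; all names are illustrative stand-ins, not code from this repository.

# Hypothetical sketch of the resume/skip logic (illustrative, not repo code).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

gradient_accumulation_steps = 4
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# Toy loader standing in for the real pretraining dataloader.
train_loader = DataLoader(TensorDataset(torch.randn(256, 16)), batch_size=8)
train_loader = accelerator.prepare(train_loader)

# Assume global_step was restored via accelerator.load_state(...); hard-coded here.
global_step = 10

if global_step > 0:
    # Each optimizer step consumed `gradient_accumulation_steps` micro-batches,
    # so skip that many batches of the first epoch when resuming.
    skip_steps = global_step * gradient_accumulation_steps
    train_loader_skiped = accelerator.skip_first_batches(
        train_loader, num_batches=skip_steps
    )
else:
    # Fresh run: the first epoch iterates the full loader.
    train_loader_skiped = train_loader

Only the first epoch after a resume needs the skipped variant; later epochs should go back to the full loader, which is exactly what the train_loader_skiped / train_loader split introduced by the next hunk does.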
@@ -129,35 +132,45 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.start_time = time.time()
-        for self.data_step, batch in enumerate(self.train_loader):
-            # end training
-            if self.data_step >= self.config["train"]["num_training_steps"]:
-                break
-            # data to device
-            for k, v in batch.items():
-                batch[k] = v.to(self.accelerator.device, non_blocking=True)
-            self.model.train()
-            # train step
-            with self.accelerator.accumulate(self.model):
-                losses = self.train_step(batch)
-                if self.accelerator.sync_gradients:
-                    self.global_step += 1
-            # log
-            if (
-                self.data_step % self.log_interval == 0
-                and self.data_step > 0
-                and self.accelerator.is_main_process
-            ):
-                self.log(losses)
-            # eval/vis model output
-            if (
-                self.data_step % self.eval_interval == 0
-                and self.accelerator.is_main_process
-            ):
-                self.eval()
-            # save state
-            if self.data_step % self.save_interval == 0 and self.data_step > 0:
-                self.accelerator.save_state(self.work_dir)
+        self.epoch = 0
+        self.data_step = 0
+        while True:
+            if self.epoch == 0:
+                train_loader = self.train_loader_skiped
+            else:
+                train_loader = self.train_loader
+            for batch in train_loader:
+                # end training
+                if self.data_step >= self.config["train"]["num_training_steps"]:
+                    break
+                # data to device
+                for k, v in batch.items():
+                    batch[k] = v.to(self.accelerator.device, non_blocking=True)
+                self.model.train()
+                # train step
+                with self.accelerator.accumulate(self.model):
+                    losses = self.train_step(batch)
+                    if self.accelerator.sync_gradients:
+                        self.global_step += 1
+                # log
+                if (
+                    self.data_step % self.log_interval == 0
+                    and self.data_step > 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.log(losses)
+                # eval/vis model output
+                if (
+                    self.data_step % self.eval_interval == 0
+                    and self.accelerator.is_main_process
+                ):
+                    self.eval()
+                # save state
+                if self.data_step % self.save_interval == 0 and self.data_step > 0:
+                    self.accelerator.save_state(self.work_dir)
+                self.data_step += 1
         wandb.finish()
 
     def log(self, losses):
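The hunk above is the heart of the change: the single pass over self.train_loader becomes an outer while-loop over epochs, so training keeps running until num_training_steps is reached even when the dataset is exhausted, and the skip-ahead loader is used only for the first (resumed) epoch. The following self-contained sketch mirrors that control flow with a dummy model and loader; every name is an illustrative stand-in for what the real trainer wires in from its config.

# Hypothetical sketch of the multi-epoch training loop (illustrative, not repo code).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
train_loader = DataLoader(dataset, batch_size=8)
model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

# When resuming this would come from accelerator.skip_first_batches(...);
# for a fresh run it is just the full loader.
train_loader_skiped = train_loader

num_training_steps = 50
epoch = 0
data_step = 0
global_step = 0

while True:
    # End training once the step budget is spent, regardless of epoch boundaries.
    if data_step >= num_training_steps:
        break
    # Only the first epoch uses the loader that skips already-seen batches.
    loader = train_loader_skiped if epoch == 0 else train_loader
    for x, y in loader:
        if data_step >= num_training_steps:
            break
        model.train()
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
        if accelerator.sync_gradients:
            # Count real optimizer steps separately from per-batch data steps.
            global_step += 1
        data_step += 1
    epoch += 1

print(f"finished after {data_step} data steps / {global_step} optimizer steps")

As in the diff, data_step counts consumed batches and gates logging, evaluation, and checkpointing, while global_step only advances when gradients are actually synchronized; that is what lets the resume logic in the earlier hunk recover the number of batches to skip by multiplying global_step by gradient_accumulation_steps.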