add non_blocking

This commit is contained in:
LiangSong 2023-03-27 23:29:42 +08:00
parent 9e6b12e41b
commit 23d307367f

View File

@ -61,6 +61,7 @@ data_set = IterableDataset.from_generator(
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur: an
# IterableDataset is not sharded across DataLoader workers, so each
# worker would replay the full generator stream.
num_workers=1,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
drop_last=True,
@ -122,7 +123,7 @@ for data_step in range(num_training_steps):
with accelerator.accumulate(model):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device)
batch[k] = v.to(accelerator.device, non_blocking=True)
out = model(**batch, labels=batch['input_ids'])
total_loss = out.loss
losses = {"total_loss": total_loss}