add non_blocking
This commit is contained in:
parent
9e6b12e41b
commit
23d307367f
|
@ -61,6 +61,7 @@ data_set = IterableDataset.from_generator(
|
|||
train_loader = DataLoader(
|
||||
data_set,
|
||||
batch_size=train_batch_size,
|
||||
# If num_workers is greater than 1, duplicate data may occur.
|
||||
num_workers=1,
|
||||
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
|
||||
drop_last=True,
|
||||
|
@ -122,7 +123,7 @@ for data_step in range(num_training_steps):
|
|||
with accelerator.accumulate(model):
|
||||
batch = next(train_loader_iter)
|
||||
for k, v in batch.items():
|
||||
batch[k] = v.to(accelerator.device)
|
||||
batch[k] = v.to(accelerator.device, non_blocking=True)
|
||||
out = model(**batch, labels=batch['input_ids'])
|
||||
total_loss = out.loss
|
||||
losses = {"total_loss": total_loss}
|
||||
|
|
Loading…
Reference in New Issue
Block a user