add non_blocking
This commit is contained in:
parent
9e6b12e41b
commit
23d307367f
|
@ -61,6 +61,7 @@ data_set = IterableDataset.from_generator(
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
data_set,
|
data_set,
|
||||||
batch_size=train_batch_size,
|
batch_size=train_batch_size,
|
||||||
|
# If num_workers is greater than 1, duplicate data may occur.
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
|
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
@ -122,7 +123,7 @@ for data_step in range(num_training_steps):
|
||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
batch = next(train_loader_iter)
|
batch = next(train_loader_iter)
|
||||||
for k, v in batch.items():
|
for k, v in batch.items():
|
||||||
batch[k] = v.to(accelerator.device)
|
batch[k] = v.to(accelerator.device, non_blocking=True)
|
||||||
out = model(**batch, labels=batch['input_ids'])
|
out = model(**batch, labels=batch['input_ids'])
|
||||||
total_loss = out.loss
|
total_loss = out.loss
|
||||||
losses = {"total_loss": total_loss}
|
losses = {"total_loss": total_loss}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user