From 23d307367fa8510b268304b6ec214d98be218a63 Mon Sep 17 00:00:00 2001
From: LiangSong
Date: Mon, 27 Mar 2023 23:29:42 +0800
Subject: [PATCH] add non_blocking

---
 pretrain_llama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pretrain_llama.py b/pretrain_llama.py
index 8c07586..2970bef 100644
--- a/pretrain_llama.py
+++ b/pretrain_llama.py
@@ -61,6 +61,7 @@ data_set = IterableDataset.from_generator(
 train_loader = DataLoader(
     data_set,
     batch_size=train_batch_size,
+    # If num_workers is greater than 1, duplicate data may occur.
     num_workers=1,
     collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
     drop_last=True,
@@ -122,7 +123,7 @@ for data_step in range(num_training_steps):
     with accelerator.accumulate(model):
         batch = next(train_loader_iter)
         for k, v in batch.items():
-            batch[k] = v.to(accelerator.device)
+            batch[k] = v.to(accelerator.device, non_blocking=True)
         out = model(**batch, labels=batch['input_ids'])
         total_loss = out.loss
         losses = {"total_loss": total_loss}
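
A hedged sketch of the idea behind the change (not part of the patch): non_blocking=True only lets the
host-to-device copy overlap with GPU work when the source tensors sit in pinned (page-locked) host
memory, which the DataLoader can provide via pin_memory=True. The toy dataset, batch size, and device
selection below are placeholders for illustration, not values taken from pretrain_llama.py.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Toy stand-in for the pretraining data; the patch itself uses an
    # IterableDataset, where num_workers > 1 can yield duplicate samples.
    data_set = TensorDataset(torch.randint(0, 1000, (64, 16)))

    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=1,
        pin_memory=True,   # pinned host buffers make non_blocking copies truly asynchronous
        drop_last=True,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    for (input_ids,) in train_loader:
        # Queue the copy on the device without blocking the Python thread;
        # with pinned memory this can overlap with ongoing GPU compute.
        input_ids = input_ids.to(device, non_blocking=True)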