update label prepare to speedup

This commit is contained in:
LiangSong 2023-03-27 17:13:59 +08:00
parent 0be3091b19
commit 9e6b12e41b

View File

@ -123,9 +123,7 @@ for data_step in range(num_training_steps):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device)
labels = batch["input_ids"].clone()
labels[labels == tokenizer.pad_id] = -100
out = model(**batch, labels=labels)
out = model(**batch, labels=batch['input_ids'])
total_loss = out.loss
losses = {"total_loss": total_loss}
accelerator.backward(total_loss)