update label prepare to speedup
This commit is contained in:
parent
0be3091b19
commit
9e6b12e41b
|
@ -123,9 +123,7 @@ for data_step in range(num_training_steps):
|
|||
batch = next(train_loader_iter)
|
||||
for k, v in batch.items():
|
||||
batch[k] = v.to(accelerator.device)
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == tokenizer.pad_id] = -100
|
||||
out = model(**batch, labels=labels)
|
||||
out = model(**batch, labels=batch['input_ids'])
|
||||
total_loss = out.loss
|
||||
losses = {"total_loss": total_loss}
|
||||
accelerator.backward(total_loss)
|
||||
|
|
Loading…
Reference in New Issue
Block a user