Open-Llama/utils/speed_test/lightning/run.py
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-11 20:07:35
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-11 21:56:07
FilePath: /Open-Llama/speed_test/lightning/run.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import time

import torch
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam
from lightning.pytorch.strategies import DeepSpeedStrategy
from transformers import LlamaForCausalLM, LlamaConfig

# Benchmark configuration.
batch_size = 2
seq_length = 2048
vocab_size = 32000
total_step = 100
use_activation_ckpt = False
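
# Each step covers batch_size * seq_length = 2 * 2048 = 4096 tokens per GPU;
# across the 8 GPUs used below, that is 32768 tokens per optimizer step.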

class FakeSet(torch.utils.data.Dataset):
    """Synthetic dataset of random token ids, so the benchmark measures
    training speed without any data-loading or tokenization cost."""

    def __getitem__(self, idx):
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        return 1000000000

class SpeedTest(pl.LightningModule):
    """LLaMA wrapped in a LightningModule; records the wall-clock time of the
    first training step so setup and process spawn are excluded from timing."""

    def __init__(self):
        super().__init__()
        self.model = LlamaForCausalLM(
            LlamaConfig(
                vocab_size=vocab_size,
            )
        )
        if use_activation_ckpt:
            self.model.gradient_checkpointing_enable()
        self.start_time = None

    def training_step(self, batch, batch_idx):
        # Standard causal-LM objective: the inputs double as the labels.
        out = self.model(batch, labels=batch)
        loss = out.loss
        if self.start_time is None:
            print("start")
            self.start_time = time.time()
        return loss

    def configure_optimizers(self):
        # FusedAdam from DeepSpeed, matching the DeepSpeed strategy below.
        optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
        return optimizer

model = SpeedTest()
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)

# ZeRO stage 2: shard optimizer state and gradients across ranks,
# with no CPU offload of optimizer state or parameters.
strategy = DeepSpeedStrategy(
    stage=2,
    offload_optimizer=False,
    offload_parameters=False,
    process_group_backend="nccl",
)
trainer = pl.Trainer(
    limit_train_batches=total_step,
    max_epochs=1,
    devices=8,
    accelerator="gpu",
    strategy=strategy,
    precision=16,
    enable_checkpointing=False,
)
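
# Assumes a single 8-GPU node: Lightning's launcher normally spawns one
# process per device for this strategy, so `python run.py` starts all ranks.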

def train(model, train_loader):
    trainer.fit(model=model, train_dataloaders=train_loader)
    end_time = time.time()
    # Time from the first training step (model.start_time), not from the
    # trainer.fit() call, so startup overhead is not counted.
    return end_time - model.start_time


total_time = train(model, train_loader)
print("total time: {}".format(total_time))