"""
|
|
Author: LiangSong(sl12160010@gmail.com)
|
|
Date: 2023-04-11 20:07:35
|
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
|
LastEditTime: 2023-04-11 21:56:07
|
|
FilePath: /Open-Llama/speed_test/lightning/run.py
|
|
Description:
|
|
|
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
|
"""
|
|
import time

import torch
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig
from lightning.pytorch.strategies import DeepSpeedStrategy

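# Benchmark configuration. batch_size is per device; with the 8 GPUs
# configured below, the global batch is 16 sequences of 2048 tokens per step.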
batch_size = 2
seq_length = 2048
vocab_size = 32000
total_step = 100
use_activation_ckpt = False


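# Synthetic dataset of random token ids, so data loading never bottlenecks
# the measurement.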
class FakeSet(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        return 1000000000


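# LightningModule wrapping a LLaMA model built from the default LlamaConfig
# (roughly the 7B architecture: 32 layers, hidden size 4096), with only the
# vocabulary size set explicitly.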
class SpeedTest(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = LlamaForCausalLM(
            LlamaConfig(
                vocab_size=vocab_size,
            )
        )
        if use_activation_ckpt:
            self.model.gradient_checkpointing_enable()
        self.start_time = None

    def training_step(self, batch, batch_idx):
        out = self.model(batch, labels=batch)
        loss = out.loss
        # Start the clock at the first step so that process startup, model
        # construction, and DeepSpeed initialization are excluded.
        if self.start_time is None:
            print("start")
            self.start_time = time.time()
        return loss

    def configure_optimizers(self):
        # DeepSpeed's FusedAdam requires the DeepSpeed CUDA extensions to be
        # built on this machine.
        optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
        return optimizer


model = SpeedTest()
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)

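# ZeRO stage 2 shards optimizer states and gradients across the 8 ranks;
# precision=16 enables fp16 mixed precision.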
strategy = DeepSpeedStrategy(
    stage=2,
    offload_optimizer=False,
    offload_parameters=False,
    process_group_backend="nccl",
)
trainer = pl.Trainer(
    limit_train_batches=total_step,
    max_epochs=1,
    devices=8,
    accelerator="gpu",
    strategy=strategy,
    precision=16,
    enable_checkpointing=False,
)


def train(model, train_loader):
    trainer.fit(model=model, train_dataloaders=train_loader)
    end_time = time.time()
    # Elapsed time from the first training step (set in training_step) to the
    # end of fit, so one-off setup costs are not counted.
    return end_time - model.start_time


print("total time: {}".format(train(model, train_loader)))
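
# Run directly with `python run.py`; Lightning's launcher spawns one process
# per GPU. A rough throughput figure can be derived from the reported time.
# This is a sketch: `total_time` below is illustrative (the value printed
# above), and it assumes all 8 ranks process batch_size sequences every step:
#
#     total_tokens = total_step * batch_size * seq_length * 8
#     print("tokens/s: {}".format(total_tokens / total_time))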