"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-11 20:07:35
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-11 21:56:07
FilePath: /Open-Llama/speed_test/lightning/run.py
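Description: Speed test for training a randomly initialised LLaMA model with PyTorch Lightning + DeepSpeed (ZeRO stage 2) on synthetic random-token data; prints the elapsed training time.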

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
"""
import time
import torch
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig
from lightning.pytorch.strategies import DeepSpeedStrategy


# Benchmark configuration, read at module level by the dataset, model and trainer.
batch_size = 2  # DataLoader batch size
seq_length = 2_048  # token ids per fake sample
vocab_size = 32_000  # upper bound (exclusive) of fake token ids; also the Llama vocab size
total_step = 100  # passed to the Trainer as limit_train_batches
use_activation_ckpt = False  # when True, enable HF gradient checkpointing on the model


class FakeSet(torch.utils.data.Dataset):
    """Synthetic dataset of random token ids.

    Every index yields a fresh 1-D ``torch.randint`` tensor of ``seq_length``
    ids drawn from ``[0, vocab_size)``; the index value itself is ignored.
    """

    def __len__(self):
        # Very large nominal length; the number of batches actually consumed is
        # capped elsewhere (the Trainer's limit_train_batches).
        return 1_000_000_000

    def __getitem__(self, idx):
        sample_shape = (seq_length,)
        return torch.randint(0, vocab_size, sample_shape)


class SpeedTest(pl.LightningModule):
    """LightningModule wrapping a randomly initialised ``LlamaForCausalLM``.

    ``start_time`` stays ``None`` until the first ``training_step`` has run
    its forward pass; it is then set to ``time.time()`` so that callers can
    time training while excluding everything before that point.
    """

    def __init__(self):
        super().__init__()
        config = LlamaConfig(vocab_size=vocab_size)
        self.model = LlamaForCausalLM(config)
        if use_activation_ckpt:
            # Optional HF gradient checkpointing, controlled by the module flag.
            self.model.gradient_checkpointing_enable()
        self.start_time = None

    def training_step(self, batch, batch_idx):
        # The input ids are also passed as labels; the loss is taken from the
        # model output.
        loss = self.model(batch, labels=batch).loss
        first_step = self.start_time is None
        if first_step:
            print("start")
            self.start_time = time.time()
        return loss

    def configure_optimizers(self):
        # Parameters come from ``self.trainer.model`` (the module as held by the
        # trainer/strategy) rather than ``self.parameters()``; presumably for
        # DeepSpeed compatibility -- confirm before changing.
        return FusedAdam(self.trainer.model.parameters(), lr=1e-5)


# Model under test and an effectively endless loader of random token batches.
model = SpeedTest()
train_loader = torch.utils.data.DataLoader(dataset=FakeSet(), batch_size=batch_size)

# DeepSpeed ZeRO stage 2 over NCCL with no CPU offload of optimizer or parameters.
strategy = DeepSpeedStrategy(
    process_group_backend="nccl",
    offload_parameters=False,
    offload_optimizer=False,
    stage=2,
)
# One epoch on 8 GPUs, capped at ``total_step`` batches, no checkpoint saving.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=8,
    strategy=strategy,
    # NOTE(review): Lightning 2.x prefers the string "16-mixed"; confirm the
    # installed version before switching.
    precision=16,
    max_epochs=1,
    limit_train_batches=total_step,
    enable_checkpointing=False,
)


def train(model, train_loader, pl_trainer=None):
    """Run ``fit`` and return the elapsed wall-clock time in seconds.

    The clock normally starts at ``model.start_time``, which the model sets
    after the first training step's forward pass, so trainer/strategy setup
    and that first forward are excluded. If no training step ever updated
    ``model.start_time`` in this process (it is still ``None``), the time is
    measured from the call to ``fit`` instead of raising ``TypeError``.

    Args:
        model: module passed to ``fit``; must have a ``start_time`` attribute.
        train_loader: dataloader passed to ``fit`` as ``train_dataloaders``.
        pl_trainer: optional trainer to use; defaults to the module-level
            ``trainer``.

    Returns:
        Elapsed seconds as a float.
    """
    # Resolve the module-level trainer lazily, only when none is supplied.
    runner = trainer if pl_trainer is None else pl_trainer
    fit_start = time.time()
    runner.fit(model=model, train_dataloaders=train_loader)
    end_time = time.time()
    start = model.start_time if model.start_time is not None else fit_start
    return end_time - start


# Entry point: train once and report the measured time.
print(f"total time: {train(model, train_loader)}")