add speed test

LiangSong 2023-04-11 21:59:18 +08:00
parent 4cb94d2687
commit 0ee9612f40
14 changed files with 753 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,18 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 1
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,18 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,24 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: default
  dynamo_use_dynamic: false
  dynamo_use_fullgraph: false
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,19 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 2
use_cpu: false

View File

@@ -0,0 +1,23 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
megatron_lm_config:
  megatron_lm_gradient_clipping: 1.0
  megatron_lm_num_micro_batches: 2
  megatron_lm_pp_degree: 2
  megatron_lm_recompute_activations: true
  megatron_lm_sequence_parallelism: true
  megatron_lm_tp_degree: 2
  megatron_lm_use_distributed_optimizer: true
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
use_cpu: false

View File

@@ -0,0 +1,61 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-08 22:44:44
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-08 23:15:57
FilePath: /Open-Llama/speed_test/accelerate/run.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import time
import torch
from deepspeed.ops.adam import FusedAdam
from accelerate import Accelerator, DistributedType
from transformers import LlamaForCausalLM, LlamaConfig
batch_size = 32
seq_length = 2048
vocab_size = 32000
total_step = 2
use_activation_ckpt = True
class FakeSet(torch.utils.data.Dataset):
def __getitem__(self, idx):
return torch.randint(0, vocab_size, (seq_length, ))
def __len__(self):
return 1000000000
accelerator = Accelerator()
raw_model = LlamaForCausalLM(
LlamaConfig(
vocab_size=vocab_size,
)
)
if use_activation_ckpt:
raw_model.gradient_checkpointing_enable()
optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
if accelerator.distributed_type == DistributedType.FSDP:
accelerator.print('FSDP')
model = accelerator.prepare(raw_model)
optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
else:
model, optimizer, train_loader = accelerator.prepare(raw_model, optimizer, train_loader)
def train(model, optimizer, train_loader):
start_time = time.time()
for i, batch in enumerate(train_loader):
if i == total_step:
break
optimizer.zero_grad()
out = model(input_ids=batch, labels=batch)
loss = out.loss
accelerator.backward(loss)
optimizer.step()
end_time = time.time()
return end_time - start_time
accelerator.print('total time: {}'.format(train(model, optimizer, train_loader)))
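Not part of the commit, but the printed wall-clock time can be turned into a per-process throughput figure using the constants defined in run.py; a minimal sketch (the total time below is a made-up example value):

# Illustrative post-processing of run.py's output; not part of run.py itself.
total_time = 12.3                # example value: seconds printed by run.py
tokens_per_gpu = 32 * 2048 * 2   # batch_size * seq_length * total_step per process
print('{:.1f} tokens/s per GPU'.format(tokens_per_gpu / total_time))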

View File

@@ -0,0 +1,12 @@
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-04-08 22:44:27
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-04-11 21:58:43
# @FilePath: /Open-Llama/speed_test/accelerate/run.sh
# @Description:
#
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
total_gpu=8
accelerate launch --config_file deepspeed_stage2.yaml --main_process_ip 127.0.0.1 --main_process_port 23335 --num_processes $total_gpu run.py

View File

@@ -0,0 +1,376 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-11 20:07:35
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-11 21:56:23
FilePath: /Open-Llama/speed_test/colossal-ai/run.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
from functools import partial
from time import time
import psutil
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaConfig
from utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__
def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default='CAI_Gemini',
        help="The distributed plan [CAI_ZeRO1, CAI_ZeRO2, CAI_Gemini, Pytorch_DDP, Pytorch_ZeRO].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action='store_true',
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="Llama-7B",
        help="model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help="training iterations for test",
    )
    args = parser.parse_args()
    return args

def model_builder(VOCAB_SIZE, checkpoint=False):
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=VOCAB_SIZE,
        )
    )
    if checkpoint:
        raw_model.gradient_checkpointing_enable()
    return raw_model


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=''):
    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'


def get_model_size(model: nn.Module):
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


def model_size_formatter(numel: int) -> str:
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f'{numel / GB_SIZE:.1f}B'
    elif numel >= MB_SIZE:
        return f'{numel / MB_SIZE:.1f}M'
    elif numel >= KB_SIZE:
        return f'{numel / KB_SIZE:.1f}K'
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split('\n')[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")

# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded

    NOTE: the module names below ('mlp.c_fc', 'c_attn', ...) come from the GPT-style
    demo this script was adapted from; LlamaForCausalLM uses different module names,
    so with tp_degree > 1 most parameters fall through to the ReplicaSpec branch.
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # NOTE() a param maybe shared by two modules
            if hasattr(param, 'visited'):
                continue

            # if shard init, then convert param to replica and use the dp-only ProcessGroup
            param: ColoParameter = param
            param.set_dist_spec(ReplicaSpec())
            param.set_process_group(pg)

            # shard it w.r.t tp pattern
            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif 'wte' in mn or 'wpe' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            else:
                param.set_dist_spec(ReplicaSpec())

            param.visited = True

def main():
    # version check
    # this example is supposed to work for versions greater than 0.2.0
    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")

    set_cpu_maximum_parallelism()
    args = parse_args()

    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
        raise TypeError(f"{args.distplan} is not a valid distplan")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 2048
    VOCAB_SIZE = 32000

    NUM_STEPS = args.train_step

    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False  # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        # all params must use the same process group.
        world_size = torch.distributed.get_world_size()
        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

        if args.shardinit and args.distplan != "CAI_Gemini":
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build the Llama model under ColoInitContext
        with ColoInitContext(device=get_current_device(),
                             dtype=torch.half,
                             default_dist_spec=default_dist_spec,
                             default_pg=shard_pg):
            model = model_builder(VOCAB_SIZE, checkpoint=True)

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
        # You should notice that v0.1.10 is not compatible with TP degree > 1
        if args.tp_degree > 1:
            tensor_parallelize(model, tp_pg)

        # assign running configurations
        gemini_config = None
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
                                 device=get_current_device(),
                                 placement_policy=args.placement,
                                 pin_memory=True,
                                 hidden_dim=model.model.config.hidden_size,
                                 search_range_mb=128)
            optim_config = dict(gpu_margin_mem_ratio=0.)
        else:
            raise RuntimeError

        # build a highly optimized gpu/cpu optimizer
        optimizer = HybridAdam(model.parameters(), lr=1e-3)

        if args.distplan == "CAI_ZeRO1":
            zero_stage = 1
        elif args.distplan == "CAI_ZeRO2":
            zero_stage = 2
        elif args.distplan == "CAI_Gemini":
            zero_stage = 3
        else:
            raise RuntimeError

        # wrap your model and optimizer
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
        model = DDP(model)
        if args.distplan.endswith("DDP"):
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer
            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
    else:
        raise RuntimeError

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of the testing model is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step():
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)[0]
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])

        if args.distplan.startswith("CAI"):
            optimizer.backward(loss)
        elif args.distplan.startswith("Pytorch"):
            loss.backward()
        else:
            raise RuntimeError

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(PROF_FLAG,
                                        WARMUP_STEPS,
                                        NUM_STEPS - WARMUP_STEPS,
                                        save_dir=f"profile/{get_time_stamp()}-demo")

    with demo_profiler as prof:
        start_time = time()
        for n in range(NUM_STEPS):
            train_step()
            prof.step()
        end_time = time()
        print('total time: {}'.format(end_time - start_time))

    tflops_list.sort()
    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()
if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,31 @@
set -x
# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}

# The following options are only valid for the CAI_* plans
export GPUNUM=${GPUNUM:-8}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"auto"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-True}
export BATCH_SIZE=${BATCH_SIZE:-40}
export MODEL_TYPE=${MODEL_TYPE:-"Llama-7B"}
export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH

if [ ${USE_SHARD_INIT} = "True" ]; then
  USE_SHARD_INIT="--shardinit"
else
  USE_SHARD_INIT=""
fi

mkdir -p gemini_logs
torchrun --nproc_per_node=${GPUNUM} --rdzv_endpoint=127.0.0.1:23335 run.py \
  --tp_degree=${TPDEGREE} \
  --model_type=${MODEL_TYPE} \
  --batch_size=${BATCH_SIZE} \
  --placement=${PLACEMENT} \
  ${USE_SHARD_INIT} \
  --distplan=${DISTPLAN} \
  --train_step=${TRAIN_STEP} \
  2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log

View File

@@ -0,0 +1,41 @@
import time
from contextlib import nullcontext

import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


class DummyProfiler:
    """Drop-in stand-in for torch.profiler.profile when profiling is disabled."""

    def __init__(self):
        self.step_number = 0

    def step(self):
        self.step_number += 1


# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask


def get_tflops(model_numel, batch_size, seq_len, step_time):
    # roughly 8 FLOPs per parameter per token: 2 forward + 4 backward + 2 activation recomputation
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    if enable_flag:
        return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                       schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
                       on_trace_ready=tensorboard_trace_handler(save_dir),
                       record_shapes=True,
                       profile_memory=True)
    else:
        return nullcontext(DummyProfiler())


def get_time_stamp():
    cur_time = time.strftime("%d-%H:%M", time.localtime())
    return cur_time
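Not part of the commit, but as a quick sanity check on get_tflops, a minimal usage sketch with assumed values (the parameter count and step time below are illustrative, not measured):

# Illustrative only: assumed ~6.7e9 parameters for Llama-7B and a made-up step time.
from utils import get_tflops

model_numel = 6.7e9   # assumed parameter count; run.py derives the real value via get_model_size()
batch_size = 8        # per-DP-group batch size (run.py default)
seq_len = 2048
step_time = 5.0       # hypothetical measured step time in seconds
print(f"{get_tflops(model_numel, batch_size, seq_len, step_time):.1f} TFLOPS")
# 6.7e9 * 8 * 2048 * 8 / 1e12 / 5.0 ≈ 175.6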

View File

@@ -0,0 +1,80 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-11 20:07:35
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-11 21:56:07
FilePath: /Open-Llama/speed_test/lightning/run.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import time
import torch
import lightning.pytorch as pl
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig
from lightning.pytorch.strategies import DeepSpeedStrategy
batch_size = 2
seq_length = 2048
vocab_size = 32000
total_step = 100
use_activation_ckpt = False
class FakeSet(torch.utils.data.Dataset):
def __getitem__(self, idx):
return torch.randint(0, vocab_size, (seq_length, ))
def __len__(self):
return 1000000000
class SpeedTest(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = LlamaForCausalLM(
LlamaConfig(
vocab_size=vocab_size,
)
)
if use_activation_ckpt:
self.model.gradient_checkpointing_enable()
self.start_time = None
def training_step(self, batch, batch_idx):
out = self.model(batch, labels=batch)
loss = out.loss
if self.start_time is None:
print('start')
self.start_time = time.time()
return loss
def configure_optimizers(self):
optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
return optimizer
model = SpeedTest()
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
strategy=DeepSpeedStrategy(
stage=2,
offload_optimizer=False,
offload_parameters=False,
process_group_backend="nccl"
)
trainer = pl.Trainer(
limit_train_batches=total_step,
max_epochs=1,
devices=8,
accelerator="gpu",
strategy=strategy,
precision=16,
enable_checkpointing=False)
def train(model, train_loader):
start_time = time.time()
trainer.fit(model=model, train_dataloaders=train_loader)
end_time = time.time()
return end_time - model.start_time
print('total time: {}'.format(train(model, train_loader)))
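Not part of the commit: for a run comparable to the ZeRO stage-3 offload config above, the same Lightning script could swap in a stage-3 strategy; a minimal sketch, with argument values chosen as assumptions rather than taken from the repo:

# Hypothetical variant of the strategy above: ZeRO stage 3 with CPU offload.
from lightning.pytorch.strategies import DeepSpeedStrategy

strategy = DeepSpeedStrategy(
    stage=3,
    offload_optimizer=True,    # move optimizer states to CPU
    offload_parameters=True,   # move stage-3 partitioned parameters to CPU
    process_group_backend="nccl",
)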