add speed test
This commit is contained in:
parent
4cb94d2687
commit
0ee9612f40
12
speed_test/accelerate/ddp.yaml
Normal file
12
speed_test/accelerate/ddp.yaml
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config: {}
|
||||||
|
distributed_type: MULTI_GPU
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
18
speed_test/accelerate/deepspeed_stage1.yaml
Normal file
18
speed_test/accelerate/deepspeed_stage1.yaml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero_stage: 1
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
18
speed_test/accelerate/deepspeed_stage2.yaml
Normal file
18
speed_test/accelerate/deepspeed_stage2.yaml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero_stage: 2
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
19
speed_test/accelerate/deepspeed_stage3.yaml
Normal file
19
speed_test/accelerate/deepspeed_stage3.yaml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
24
speed_test/accelerate/deepspeed_stage3_dynamo.yaml
Normal file
24
speed_test/accelerate/deepspeed_stage3_dynamo.yaml
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: none
|
||||||
|
offload_param_device: none
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
dynamo_config:
|
||||||
|
dynamo_backend: INDUCTOR
|
||||||
|
dynamo_mode: default
|
||||||
|
dynamo_use_dynamic: false
|
||||||
|
dynamo_use_fullgraph: false
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
19
speed_test/accelerate/deepspeed_stage3_offload.yaml
Normal file
19
speed_test/accelerate/deepspeed_stage3_offload.yaml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config:
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
gradient_clipping: 1.0
|
||||||
|
offload_optimizer_device: cpu
|
||||||
|
offload_param_device: cpu
|
||||||
|
zero3_init_flag: true
|
||||||
|
zero3_save_16bit_model: true
|
||||||
|
zero_stage: 3
|
||||||
|
distributed_type: DEEPSPEED
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
19
speed_test/accelerate/fsdp.yaml
Normal file
19
speed_test/accelerate/fsdp.yaml
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config: {}
|
||||||
|
distributed_type: FSDP
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_backward_prefetch_policy: BACKWARD_PRE
|
||||||
|
fsdp_offload_params: false
|
||||||
|
fsdp_sharding_strategy: 1
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
mixed_precision: 'bf16'
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 2
|
||||||
|
use_cpu: false
|
23
speed_test/accelerate/megatron.yaml
Normal file
23
speed_test/accelerate/megatron.yaml
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
compute_environment: LOCAL_MACHINE
|
||||||
|
deepspeed_config: {}
|
||||||
|
distributed_type: MEGATRON_LM
|
||||||
|
downcast_bf16: 'no'
|
||||||
|
fsdp_config: {}
|
||||||
|
machine_rank: 0
|
||||||
|
main_process_ip: null
|
||||||
|
main_process_port: null
|
||||||
|
main_training_function: main
|
||||||
|
megatron_lm_config:
|
||||||
|
megatron_lm_gradient_clipping: 1.0
|
||||||
|
megatron_lm_num_micro_batches: 2
|
||||||
|
megatron_lm_pp_degree: 2
|
||||||
|
megatron_lm_recompute_activations: true
|
||||||
|
megatron_lm_sequence_parallelism: true
|
||||||
|
megatron_lm_tp_degree: 2
|
||||||
|
megatron_lm_use_distributed_optimizer: true
|
||||||
|
mixed_precision: bf16
|
||||||
|
num_machines: 1
|
||||||
|
num_processes: 4
|
||||||
|
rdzv_backend: static
|
||||||
|
same_network: true
|
||||||
|
use_cpu: false
|
61
speed_test/accelerate/run.py
Normal file
61
speed_test/accelerate/run.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
"""
|
||||||
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
|
Date: 2023-04-08 22:44:44
|
||||||
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
|
LastEditTime: 2023-04-08 23:15:57
|
||||||
|
FilePath: /Open-Llama/speed_test/accelerate/run.py
|
||||||
|
Description:
|
||||||
|
|
||||||
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from deepspeed.ops.adam import FusedAdam
|
||||||
|
from accelerate import Accelerator, DistributedType
|
||||||
|
from transformers import LlamaForCausalLM, LlamaConfig
|
||||||
|
|
||||||
|
batch_size = 32
|
||||||
|
seq_length = 2048
|
||||||
|
vocab_size = 32000
|
||||||
|
total_step = 2
|
||||||
|
use_activation_ckpt = True
|
||||||
|
|
||||||
|
class FakeSet(torch.utils.data.Dataset):
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return torch.randint(0, vocab_size, (seq_length, ))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000000
|
||||||
|
|
||||||
|
accelerator = Accelerator()
|
||||||
|
raw_model = LlamaForCausalLM(
|
||||||
|
LlamaConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if use_activation_ckpt:
|
||||||
|
raw_model.gradient_checkpointing_enable()
|
||||||
|
optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
|
||||||
|
if accelerator.distributed_type == DistributedType.FSDP:
|
||||||
|
accelerator.print('FSDP')
|
||||||
|
model = accelerator.prepare(raw_model)
|
||||||
|
optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
|
||||||
|
else:
|
||||||
|
model, optimizer, train_loader = accelerator.prepare(raw_model, optimizer, train_loader)
|
||||||
|
|
||||||
|
def train(model, optimizer, train_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
for i, batch in enumerate(train_loader):
|
||||||
|
if i == total_step:
|
||||||
|
break
|
||||||
|
optimizer.zero_grad()
|
||||||
|
out = model(input_ids=batch, labels=batch)
|
||||||
|
loss = out.loss
|
||||||
|
accelerator.backward(loss)
|
||||||
|
optimizer.step()
|
||||||
|
end_time = time.time()
|
||||||
|
return end_time - start_time
|
||||||
|
|
||||||
|
accelerator.print('total time: {}'.format(train(model, optimizer, train_loader)))
|
12
speed_test/accelerate/run.sh
Normal file
12
speed_test/accelerate/run.sh
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
###
|
||||||
|
# @Author: LiangSong(sl12160010@gmail.com)
|
||||||
|
# @Date: 2023-04-08 22:44:27
|
||||||
|
# @LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
|
# @LastEditTime: 2023-04-11 21:58:43
|
||||||
|
# @FilePath: /Open-Llama/speed_test/accelerate/run.sh
|
||||||
|
# @Description:
|
||||||
|
#
|
||||||
|
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
|
###
|
||||||
|
total_gpu=8
|
||||||
|
accelerate launch --config_file deepspeed_stage2.yaml --main_process_ip 127.0.0.1 --main_process_port 23335 --num_processes $total_gpu run.py
|
376
speed_test/colossal-ai/run.py
Normal file
376
speed_test/colossal-ai/run.py
Normal file
|
@ -0,0 +1,376 @@
|
||||||
|
"""
|
||||||
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
|
Date: 2023-04-11 20:07:35
|
||||||
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
|
LastEditTime: 2023-04-11 21:56:23
|
||||||
|
FilePath: /Open-Llama/speed_test/colossal-ai/run.py
|
||||||
|
Description:
|
||||||
|
|
||||||
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import LlamaForCausalLM, LlamaConfig
|
||||||
|
from utils import get_data, get_profile_context, get_tflops, get_time_stamp
|
||||||
|
from packaging import version
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||||
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
||||||
|
|
||||||
|
CAI_VERSION = colossalai.__version__
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = colossalai.get_default_parser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--distplan",
|
||||||
|
type=str,
|
||||||
|
default='CAI_Gemini',
|
||||||
|
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp_degree",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--placement",
|
||||||
|
type=str,
|
||||||
|
default='cpu',
|
||||||
|
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--shardinit",
|
||||||
|
action='store_true',
|
||||||
|
help=
|
||||||
|
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="batch size per DP group of training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type",
|
||||||
|
type=str,
|
||||||
|
default="Llama-7B",
|
||||||
|
help="model model scale",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train_step",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="training iterations for test",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def model_builder(VOCAB_SIZE, checkpoint=False):
|
||||||
|
raw_model = LlamaForCausalLM(
|
||||||
|
LlamaConfig(
|
||||||
|
vocab_size=VOCAB_SIZE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if checkpoint:
|
||||||
|
raw_model.gradient_checkpointing_enable()
|
||||||
|
return raw_model
|
||||||
|
|
||||||
|
|
||||||
|
# Parameter Sharding Strategies for Tensor Parallelism
|
||||||
|
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||||
|
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||||
|
param.set_tensor_spec(*spec)
|
||||||
|
|
||||||
|
|
||||||
|
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||||
|
split_param_single_dim_tp1d(0, param, pg)
|
||||||
|
|
||||||
|
|
||||||
|
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||||
|
split_param_single_dim_tp1d(-1, param, pg)
|
||||||
|
|
||||||
|
|
||||||
|
class GPTLMLoss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.loss_fn = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
def forward(self, logits, labels):
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_mem():
|
||||||
|
return psutil.Process().memory_info().rss / 1024**2
|
||||||
|
|
||||||
|
|
||||||
|
def get_gpu_mem():
|
||||||
|
return torch.cuda.memory_allocated() / 1024**2
|
||||||
|
|
||||||
|
|
||||||
|
def get_mem_info(prefix=''):
|
||||||
|
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_size(model: nn.Module):
|
||||||
|
total_numel = 0
|
||||||
|
for module in model.modules():
|
||||||
|
for p in module.parameters(recurse=False):
|
||||||
|
total_numel += p.numel()
|
||||||
|
return total_numel
|
||||||
|
|
||||||
|
|
||||||
|
def model_size_formatter(numel: int) -> str:
|
||||||
|
GB_SIZE = 10**9
|
||||||
|
MB_SIZE = 10**6
|
||||||
|
KB_SIZE = 10**3
|
||||||
|
if numel >= GB_SIZE:
|
||||||
|
return f'{numel / GB_SIZE:.1f}B'
|
||||||
|
elif numel >= MB_SIZE:
|
||||||
|
return f'{numel / MB_SIZE:.1f}M'
|
||||||
|
elif numel >= KB_SIZE:
|
||||||
|
return f'{numel / KB_SIZE:.1f}K'
|
||||||
|
else:
|
||||||
|
return str(numel)
|
||||||
|
|
||||||
|
|
||||||
|
def set_cpu_maximum_parallelism():
|
||||||
|
conf_str = torch.__config__.parallel_info()
|
||||||
|
inter_str = conf_str.split("hardware_concurrency() : ")[1]
|
||||||
|
max_concurrency = inter_str.split('\n')[0]
|
||||||
|
os.environ["OMP_NUM_THREADS"] = max_concurrency
|
||||||
|
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||||||
|
|
||||||
|
|
||||||
|
# Tensor Parallel
|
||||||
|
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||||
|
"""tensor_parallelize
|
||||||
|
Sharding the Model Parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): a torch module to be sharded
|
||||||
|
"""
|
||||||
|
for mn, module in model.named_modules():
|
||||||
|
for pn, param in module.named_parameters(recurse=False):
|
||||||
|
# NOTE() a param maybe shared by two modules
|
||||||
|
if hasattr(param, 'visited'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if shard init, then convert param to replica and use the dp-only ProcessGroup
|
||||||
|
param: ColoParameter = param
|
||||||
|
param.set_dist_spec(ReplicaSpec())
|
||||||
|
param.set_process_group(pg)
|
||||||
|
|
||||||
|
# shard it w.r.t tp pattern
|
||||||
|
if 'mlp.c_fc' in mn:
|
||||||
|
if 'weight' in pn or 'bias' in pn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
# keep the shape of the output from c_fc
|
||||||
|
param.compute_spec.set_output_replicate(False)
|
||||||
|
else:
|
||||||
|
param.set_dist_spec(ReplicaSpec())
|
||||||
|
elif 'mlp.c_proj' in mn:
|
||||||
|
if 'weight' in pn:
|
||||||
|
split_param_row_tp1d(param, pg) # row slice
|
||||||
|
else:
|
||||||
|
param.set_dist_spec(ReplicaSpec())
|
||||||
|
elif 'wte' in mn or 'wpe' in mn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
else:
|
||||||
|
param.set_dist_spec(ReplicaSpec())
|
||||||
|
param.visited = True
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# version check
|
||||||
|
# this example is supposed to work for versions greater than 0.2.0
|
||||||
|
assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
|
||||||
|
|
||||||
|
set_cpu_maximum_parallelism()
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
|
||||||
|
if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
|
||||||
|
raise TypeError(f"{args.distplan} is error")
|
||||||
|
|
||||||
|
# batch size per DP degree
|
||||||
|
BATCH_SIZE = args.batch_size
|
||||||
|
SEQ_LEN = 2048
|
||||||
|
VOCAB_SIZE = 32000
|
||||||
|
|
||||||
|
NUM_STEPS = args.train_step
|
||||||
|
|
||||||
|
WARMUP_STEPS = 1
|
||||||
|
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
||||||
|
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
|
||||||
|
PROF_FLAG = False # The flag of profiling, False by default
|
||||||
|
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch_from_torch(config={})
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
|
||||||
|
|
||||||
|
# build criterion
|
||||||
|
criterion = GPTLMLoss()
|
||||||
|
|
||||||
|
torch.manual_seed(123)
|
||||||
|
if args.distplan.startswith("CAI"):
|
||||||
|
# all param must use the same process group.
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
|
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
|
||||||
|
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
|
||||||
|
|
||||||
|
if args.shardinit and args.distplan != "CAI_Gemini":
|
||||||
|
raise RuntimeError("You can only use shardinit with CAI_Gemini")
|
||||||
|
|
||||||
|
# build GPT model
|
||||||
|
with ColoInitContext(device=get_current_device(),
|
||||||
|
dtype=torch.half,
|
||||||
|
default_dist_spec=default_dist_spec,
|
||||||
|
default_pg=shard_pg):
|
||||||
|
model = model_builder(VOCAB_SIZE, checkpoint=True)
|
||||||
|
|
||||||
|
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||||
|
# Tensor Parallelism (TP)
|
||||||
|
# You should notice that v0.1.10 is not compatible with TP degree > 1
|
||||||
|
if args.tp_degree > 1:
|
||||||
|
tensor_parallelize(model, tp_pg)
|
||||||
|
|
||||||
|
# asign running configurations
|
||||||
|
gemini_config = None
|
||||||
|
if args.distplan.startswith("CAI_ZeRO"):
|
||||||
|
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
|
||||||
|
elif args.distplan == "CAI_Gemini":
|
||||||
|
gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
|
||||||
|
device=get_current_device(),
|
||||||
|
placement_policy=args.placement,
|
||||||
|
pin_memory=True,
|
||||||
|
hidden_dim=model.model.config.hidden_size,
|
||||||
|
search_range_mb=128)
|
||||||
|
optim_config = dict(gpu_margin_mem_ratio=0.)
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
# build a highly optimized gpu/cpu optimizer
|
||||||
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
|
|
||||||
|
if args.distplan == "CAI_ZeRO1":
|
||||||
|
zero_stage = 1
|
||||||
|
elif args.distplan == "CAI_ZeRO2":
|
||||||
|
zero_stage = 2
|
||||||
|
elif args.distplan == "CAI_Gemini":
|
||||||
|
zero_stage = 3
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
# wrap your model and optimizer
|
||||||
|
model = zero_model_wrapper(model, zero_stage, gemini_config)
|
||||||
|
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
|
||||||
|
|
||||||
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||||
|
elif args.distplan.startswith("Pytorch"):
|
||||||
|
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
|
||||||
|
model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
|
||||||
|
model = DDP(model)
|
||||||
|
if args.distplan.endswith("DDP"):
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
elif args.distplan.endswith("ZeRO"):
|
||||||
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
|
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
# model is shared after TP
|
||||||
|
numel = get_model_size(model)
|
||||||
|
logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
|
||||||
|
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||||
|
|
||||||
|
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
|
||||||
|
# = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
|
||||||
|
# = batch_per_DP_group * numel * seq_len * 8
|
||||||
|
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
model.train()
|
||||||
|
tflops_list = []
|
||||||
|
|
||||||
|
def train_step():
|
||||||
|
# we just use randomly generated data here
|
||||||
|
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
start = time()
|
||||||
|
outputs = model(input_ids, attn_mask)[0]
|
||||||
|
loss = criterion(outputs, input_ids)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
fwd_end = time()
|
||||||
|
fwd_time = fwd_end - start
|
||||||
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||||
|
|
||||||
|
if args.distplan.startswith("CAI"):
|
||||||
|
optimizer.backward(loss)
|
||||||
|
elif args.distplan.startswith("Pytorch"):
|
||||||
|
loss.backward()
|
||||||
|
else:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
bwd_end = time()
|
||||||
|
bwd_time = bwd_end - fwd_end
|
||||||
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
optim_time = time() - bwd_end
|
||||||
|
step_time = time() - start
|
||||||
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
||||||
|
|
||||||
|
step_tflops = get_tflops_func(step_time)
|
||||||
|
logger.info(
|
||||||
|
f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
|
||||||
|
ranks=[0],
|
||||||
|
)
|
||||||
|
if n >= WARMUP_STEPS:
|
||||||
|
tflops_list.append(step_tflops)
|
||||||
|
|
||||||
|
demo_profiler = get_profile_context(PROF_FLAG,
|
||||||
|
WARMUP_STEPS,
|
||||||
|
NUM_STEPS - WARMUP_STEPS,
|
||||||
|
save_dir=f"profile/{get_time_stamp()}-demo")
|
||||||
|
|
||||||
|
with demo_profiler as prof:
|
||||||
|
start_time = time()
|
||||||
|
for n in range(NUM_STEPS):
|
||||||
|
train_step()
|
||||||
|
prof.step()
|
||||||
|
end_time = time()
|
||||||
|
print('total time: {}'.format(end_time - start_time))
|
||||||
|
|
||||||
|
tflops_list.sort()
|
||||||
|
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
||||||
|
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
31
speed_test/colossal-ai/run.sh
Normal file
31
speed_test/colossal-ai/run.sh
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
set -x
|
||||||
|
# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
|
||||||
|
export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
|
||||||
|
|
||||||
|
# The following options only valid when DISTPLAN="colossalai"
|
||||||
|
export GPUNUM=${GPUNUM:-8}
|
||||||
|
export TPDEGREE=${TPDEGREE:-1}
|
||||||
|
export PLACEMENT=${PLACEMENT:-"auto"}
|
||||||
|
export USE_SHARD_INIT=${USE_SHARD_INIT:-True}
|
||||||
|
export BATCH_SIZE=${BATCH_SIZE:-40}
|
||||||
|
export MODEL_TYPE=${MODEL_TYPE:-"Llama-7B"}
|
||||||
|
export TRAIN_STEP=${TRAIN_STEP:-10}
|
||||||
|
# export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
|
||||||
|
if [ ${USE_SHARD_INIT} = "True" ]; then
|
||||||
|
USE_SHARD_INIT="--shardinit"
|
||||||
|
else
|
||||||
|
USE_SHARD_INIT=""
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p gemini_logs
|
||||||
|
|
||||||
|
torchrun --nproc_per_node=${GPUNUM} --rdzv_endpoint=127.0.0.1:23335 run.py \
|
||||||
|
--tp_degree=${TPDEGREE} \
|
||||||
|
--model_type=${MODEL_TYPE} \
|
||||||
|
--batch_size=${BATCH_SIZE} \
|
||||||
|
--placement=${PLACEMENT} \
|
||||||
|
${USE_SHARD_INIT} \
|
||||||
|
--distplan=${DISTPLAN} \
|
||||||
|
--train_step=${TRAIN_STEP} \
|
||||||
|
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
|
41
speed_test/colossal-ai/utils.py
Normal file
41
speed_test/colossal-ai/utils.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProfiler:
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.step_number = 0
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.step_number += 1
|
||||||
|
|
||||||
|
|
||||||
|
# Randomly Generated Data
|
||||||
|
def get_data(batch_size, seq_len, vocab_size):
|
||||||
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
return input_ids, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||||
|
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||||
|
if enable_flag:
|
||||||
|
return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||||
|
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||||
|
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||||
|
record_shapes=True,
|
||||||
|
profile_memory=True)
|
||||||
|
else:
|
||||||
|
return nullcontext(DummyProfiler())
|
||||||
|
|
||||||
|
|
||||||
|
def get_time_stamp():
|
||||||
|
cur_time = time.strftime("%d-%H:%M", time.localtime())
|
||||||
|
return cur_time
|
80
speed_test/lightning/run.py
Normal file
80
speed_test/lightning/run.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
"""
|
||||||
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
|
Date: 2023-04-11 20:07:35
|
||||||
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
|
LastEditTime: 2023-04-11 21:56:07
|
||||||
|
FilePath: /Open-Llama/speed_test/lightning/run.py
|
||||||
|
Description:
|
||||||
|
|
||||||
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import lightning.pytorch as pl
|
||||||
|
from deepspeed.ops.adam import FusedAdam
|
||||||
|
from transformers import LlamaForCausalLM, LlamaConfig
|
||||||
|
from lightning.pytorch.strategies import DeepSpeedStrategy
|
||||||
|
|
||||||
|
|
||||||
|
batch_size = 2
|
||||||
|
seq_length = 2048
|
||||||
|
vocab_size = 32000
|
||||||
|
total_step = 100
|
||||||
|
use_activation_ckpt = False
|
||||||
|
|
||||||
|
class FakeSet(torch.utils.data.Dataset):
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return torch.randint(0, vocab_size, (seq_length, ))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 1000000000
|
||||||
|
|
||||||
|
class SpeedTest(pl.LightningModule):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.model = LlamaForCausalLM(
|
||||||
|
LlamaConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if use_activation_ckpt:
|
||||||
|
self.model.gradient_checkpointing_enable()
|
||||||
|
self.start_time = None
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
out = self.model(batch, labels=batch)
|
||||||
|
loss = out.loss
|
||||||
|
if self.start_time is None:
|
||||||
|
print('start')
|
||||||
|
self.start_time = time.time()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
model = SpeedTest()
|
||||||
|
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
|
||||||
|
|
||||||
|
strategy=DeepSpeedStrategy(
|
||||||
|
stage=2,
|
||||||
|
offload_optimizer=False,
|
||||||
|
offload_parameters=False,
|
||||||
|
process_group_backend="nccl"
|
||||||
|
)
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
limit_train_batches=total_step,
|
||||||
|
max_epochs=1,
|
||||||
|
devices=8,
|
||||||
|
accelerator="gpu",
|
||||||
|
strategy=strategy,
|
||||||
|
precision=16,
|
||||||
|
enable_checkpointing=False)
|
||||||
|
|
||||||
|
def train(model, train_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||||
|
end_time = time.time()
|
||||||
|
return end_time - model.start_time
|
||||||
|
|
||||||
|
print('total time: {}'.format(train(model, train_loader)))
|
Loading…
Reference in New Issue
Block a user