From 0ee9612f40dac73b9ce6c64de0492def9ba32983 Mon Sep 17 00:00:00 2001
From: LiangSong
Date: Tue, 11 Apr 2023 21:59:18 +0800
Subject: [PATCH] add speed test

---
 speed_test/accelerate/ddp.yaml                |  12 +
 speed_test/accelerate/deepspeed_stage1.yaml   |  18 +
 speed_test/accelerate/deepspeed_stage2.yaml   |  18 +
 speed_test/accelerate/deepspeed_stage3.yaml   |  19 +
 .../accelerate/deepspeed_stage3_dynamo.yaml   |  24 ++
 .../accelerate/deepspeed_stage3_offload.yaml  |  19 +
 speed_test/accelerate/fsdp.yaml               |  19 +
 speed_test/accelerate/megatron.yaml           |  23 ++
 speed_test/accelerate/run.py                  |  61 +++
 speed_test/accelerate/run.sh                  |  12 +
 speed_test/colossal-ai/run.py                 | 376 ++++++++++++++++++
 speed_test/colossal-ai/run.sh                 |  31 ++
 speed_test/colossal-ai/utils.py               |  41 ++
 speed_test/lightning/run.py                   |  80 ++++
 14 files changed, 753 insertions(+)
 create mode 100644 speed_test/accelerate/ddp.yaml
 create mode 100644 speed_test/accelerate/deepspeed_stage1.yaml
 create mode 100644 speed_test/accelerate/deepspeed_stage2.yaml
 create mode 100644 speed_test/accelerate/deepspeed_stage3.yaml
 create mode 100644 speed_test/accelerate/deepspeed_stage3_dynamo.yaml
 create mode 100644 speed_test/accelerate/deepspeed_stage3_offload.yaml
 create mode 100644 speed_test/accelerate/fsdp.yaml
 create mode 100644 speed_test/accelerate/megatron.yaml
 create mode 100644 speed_test/accelerate/run.py
 create mode 100644 speed_test/accelerate/run.sh
 create mode 100644 speed_test/colossal-ai/run.py
 create mode 100644 speed_test/colossal-ai/run.sh
 create mode 100644 speed_test/colossal-ai/utils.py
 create mode 100644 speed_test/lightning/run.py

diff --git a/speed_test/accelerate/ddp.yaml b/speed_test/accelerate/ddp.yaml
new file mode 100644
index 0000000..e96b991
--- /dev/null
+++ b/speed_test/accelerate/ddp.yaml
@@ -0,0 +1,12 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config: {}
+distributed_type: MULTI_GPU
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/deepspeed_stage1.yaml b/speed_test/accelerate/deepspeed_stage1.yaml
new file mode 100644
index 0000000..f03f6a3
--- /dev/null
+++ b/speed_test/accelerate/deepspeed_stage1.yaml
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero_stage: 1
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/deepspeed_stage2.yaml b/speed_test/accelerate/deepspeed_stage2.yaml
new file mode 100644
index 0000000..8c01eea
--- /dev/null
+++ b/speed_test/accelerate/deepspeed_stage2.yaml
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero_stage: 2
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/deepspeed_stage3.yaml b/speed_test/accelerate/deepspeed_stage3.yaml
new file mode 100644
index 0000000..5bd17c2
--- /dev/null
+++ b/speed_test/accelerate/deepspeed_stage3.yaml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/deepspeed_stage3_dynamo.yaml b/speed_test/accelerate/deepspeed_stage3_dynamo.yaml
new file mode 100644
index 0000000..68405ad
--- /dev/null
+++ b/speed_test/accelerate/deepspeed_stage3_dynamo.yaml
@@ -0,0 +1,24 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+fsdp_config: {}
+dynamo_config:
+  dynamo_backend: INDUCTOR
+  dynamo_mode: default
+  dynamo_use_dynamic: false
+  dynamo_use_fullgraph: false
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/deepspeed_stage3_offload.yaml b/speed_test/accelerate/deepspeed_stage3_offload.yaml
new file mode 100644
index 0000000..f02e951
--- /dev/null
+++ b/speed_test/accelerate/deepspeed_stage3_offload.yaml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: cpu
+  offload_param_device: cpu
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/fsdp.yaml b/speed_test/accelerate/fsdp.yaml
new file mode 100644
index 0000000..f4d6b5f
--- /dev/null
+++ b/speed_test/accelerate/fsdp.yaml
@@ -0,0 +1,19 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config: {}
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch_policy: BACKWARD_PRE
+  fsdp_offload_params: false
+  fsdp_sharding_strategy: 1
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 2
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/megatron.yaml b/speed_test/accelerate/megatron.yaml
new file mode 100644
index 0000000..7119b79
--- /dev/null
+++ b/speed_test/accelerate/megatron.yaml
@@ -0,0 +1,23 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config: {}
+distributed_type: MEGATRON_LM
+downcast_bf16: 'no'
+fsdp_config: {}
+machine_rank: 0
+main_process_ip: null
+main_process_port: null
+main_training_function: main
+megatron_lm_config:
+  megatron_lm_gradient_clipping: 1.0
+  megatron_lm_num_micro_batches: 2
+  megatron_lm_pp_degree: 2
+  megatron_lm_recompute_activations: true
+  megatron_lm_sequence_parallelism: true
+  megatron_lm_tp_degree: 2
+  megatron_lm_use_distributed_optimizer: true
+mixed_precision: bf16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+use_cpu: false
\ No newline at end of file
diff --git a/speed_test/accelerate/run.py b/speed_test/accelerate/run.py
new file mode 100644
index 0000000..9a3b8ac
--- /dev/null
+++ b/speed_test/accelerate/run.py
@@ -0,0 +1,61 @@
+"""
+Author: LiangSong(sl12160010@gmail.com)
+Date: 2023-04-08 22:44:44
+LastEditors: LiangSong(sl12160010@gmail.com)
+LastEditTime: 2023-04-08 23:15:57
+FilePath: /Open-Llama/speed_test/accelerate/run.py
+Description: 
+
+Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
+"""
+import time
+import torch
+from deepspeed.ops.adam import FusedAdam
+from accelerate import Accelerator, DistributedType
+from transformers import LlamaForCausalLM, LlamaConfig
+
+batch_size = 32
+seq_length = 2048
+vocab_size = 32000
+total_step = 2
+use_activation_ckpt = True
+
+class FakeSet(torch.utils.data.Dataset):
+    def __getitem__(self, idx):
+        return torch.randint(0, vocab_size, (seq_length, ))
+
+    def __len__(self):
+        return 1000000000
+
+accelerator = Accelerator()
+raw_model = LlamaForCausalLM(
+    LlamaConfig(
+        vocab_size=vocab_size,
+    )
+)
+if use_activation_ckpt:
+    raw_model.gradient_checkpointing_enable()
+optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
+
+train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
+if accelerator.distributed_type == DistributedType.FSDP:
+    accelerator.print('FSDP')
+    model = accelerator.prepare(raw_model)
+    optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
+else:
+    model, optimizer, train_loader = accelerator.prepare(raw_model, optimizer, train_loader)
+
+def train(model, optimizer, train_loader):
+    start_time = time.time()
+    for i, batch in enumerate(train_loader):
+        if i == total_step:
+            break
+        optimizer.zero_grad()
+        out = model(input_ids=batch, labels=batch)
+        loss = out.loss
+        accelerator.backward(loss)
+        optimizer.step()
+    end_time = time.time()
+    return end_time - start_time
+
+accelerator.print('total time: {}'.format(train(model, optimizer, train_loader)))
diff --git a/speed_test/accelerate/run.sh b/speed_test/accelerate/run.sh
new file mode 100644
index 0000000..2e60847
--- /dev/null
+++ b/speed_test/accelerate/run.sh
@@ -0,0 +1,12 @@
+###
+ # @Author: LiangSong(sl12160010@gmail.com)
+ # @Date: 2023-04-08 22:44:27
+ # @LastEditors: LiangSong(sl12160010@gmail.com)
+ # @LastEditTime: 2023-04-11 21:58:43
+ # @FilePath: /Open-Llama/speed_test/accelerate/run.sh
+ # @Description: 
+ # 
+ # Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
+### 
+total_gpu=8
+accelerate launch --config_file deepspeed_stage2.yaml --main_process_ip 127.0.0.1 --main_process_port 23335 --num_processes $total_gpu run.py
\ No newline at end of file
diff --git a/speed_test/colossal-ai/run.py b/speed_test/colossal-ai/run.py
new file mode 100644
index 0000000..92a638f
--- /dev/null
+++ b/speed_test/colossal-ai/run.py
@@ -0,0 +1,376 @@
+"""
+Author: LiangSong(sl12160010@gmail.com)
+Date: 2023-04-11 20:07:35
+LastEditors: LiangSong(sl12160010@gmail.com)
+LastEditTime: 2023-04-11 21:56:23
+FilePath: /Open-Llama/speed_test/colossal-ai/run.py
+Description: 
+
+Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
+""" +import os +from functools import partial +from time import time + +import psutil +import torch +import torch.nn as nn +from transformers import LlamaForCausalLM, LlamaConfig +from utils import get_data, get_profile_context, get_tflops, get_time_stamp +from packaging import version +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper + +CAI_VERSION = colossalai.__version__ + + +def parse_args(): + parser = colossalai.get_default_parser() + parser.add_argument( + "--distplan", + type=str, + default='CAI_Gemini', + help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", + ) + parser.add_argument( + "--tp_degree", + type=int, + default=1, + help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--placement", + type=str, + default='cpu', + help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--shardinit", + action='store_true', + help= + "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="batch size per DP group of training.", + ) + parser.add_argument( + "--model_type", + type=str, + default="Llama-7B", + help="model model scale", + ) + parser.add_argument( + "--train_step", + type=int, + default=10, + help="training iterations for test", + ) + + args = parser.parse_args() + return args + + +def model_builder(VOCAB_SIZE, checkpoint=False): + raw_model = LlamaForCausalLM( + LlamaConfig( + vocab_size=VOCAB_SIZE, + ) + ) + if checkpoint: + raw_model.gradient_checkpointing_enable() + return raw_model + + +# Parameter Sharding Strategies for Tensor Parallelism +def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): + spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + param.set_tensor_spec(*spec) + + +def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(0, param, pg) + + +def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): + split_param_single_dim_tp1d(-1, param, pg) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +def get_cpu_mem(): + return psutil.Process().memory_info().rss / 1024**2 + + +def get_gpu_mem(): + return torch.cuda.memory_allocated() / 1024**2 + + +def get_mem_info(prefix=''): + return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' + + +def get_model_size(model: nn.Module): + total_numel = 0 + for module in model.modules(): + for p in module.parameters(recurse=False): + total_numel += p.numel() + return total_numel + + +def model_size_formatter(numel: int) -> str: + GB_SIZE = 10**9 + MB_SIZE = 10**6 + KB_SIZE = 10**3 + if 
+    if numel >= GB_SIZE:
+        return f'{numel / GB_SIZE:.1f}B'
+    elif numel >= MB_SIZE:
+        return f'{numel / MB_SIZE:.1f}M'
+    elif numel >= KB_SIZE:
+        return f'{numel / KB_SIZE:.1f}K'
+    else:
+        return str(numel)
+
+
+def set_cpu_maximum_parallelism():
+    conf_str = torch.__config__.parallel_info()
+    inter_str = conf_str.split("hardware_concurrency() : ")[1]
+    max_concurrency = inter_str.split('\n')[0]
+    os.environ["OMP_NUM_THREADS"] = max_concurrency
+    print(f"environment variable OMP_NUM_THREADS is set to {max_concurrency}.")
+
+
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """tensor_parallelize
+    Sharding the Model Parameters.
+
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            # NOTE() a param may be shared by two modules
+            if hasattr(param, 'visited'):
+                continue
+
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
+            param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
+            # shard it w.r.t tp pattern
+            if 'mlp.c_fc' in mn:
+                if 'weight' in pn or 'bias' in pn:
+                    split_param_col_tp1d(param, pg)    # column slice
+                    # keep the shape of the output from c_fc
+                    param.compute_spec.set_output_replicate(False)
+                else:
+                    param.set_dist_spec(ReplicaSpec())
+            elif 'mlp.c_proj' in mn:
+                if 'weight' in pn:
+                    split_param_row_tp1d(param, pg)    # row slice
+                else:
+                    param.set_dist_spec(ReplicaSpec())
+            elif 'wte' in mn or 'wpe' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'c_attn' in mn or 'c_proj' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            else:
+                param.set_dist_spec(ReplicaSpec())
+            param.visited = True
+
+
+def main():
+    # version check
+    # this example is supposed to work for versions greater than 0.2.0
+    assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
+
+    set_cpu_maximum_parallelism()
+    args = parse_args()
+
+    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+    if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
+        raise TypeError(f"{args.distplan} is not a valid distplan")
+
+    # batch size per DP degree
+    BATCH_SIZE = args.batch_size
+    SEQ_LEN = 2048
+    VOCAB_SIZE = 32000
+
+    NUM_STEPS = args.train_step
+
+    WARMUP_STEPS = 1
+    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should be smaller than the total steps"
+    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
+    PROF_FLAG = False    # The flag of profiling, False by default
+
+    disable_existing_loggers()
+    colossalai.launch_from_torch(config={})
+
+    logger = get_dist_logger()
+    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
+
+    # build criterion
+    criterion = GPTLMLoss()
+
+    torch.manual_seed(123)
+    if args.distplan.startswith("CAI"):
+        # all param must use the same process group.
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+
+        if args.shardinit and args.distplan != "CAI_Gemini":
+            raise RuntimeError("You can only use shardinit with CAI_Gemini")
+
+        # build the LLaMA model
+        with ColoInitContext(device=get_current_device(),
+                             dtype=torch.half,
+                             default_dist_spec=default_dist_spec,
+                             default_pg=shard_pg):
+            model = model_builder(VOCAB_SIZE, checkpoint=True)
+
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
+        # Tensor Parallelism (TP)
+        # You should notice that v0.1.10 is not compatible with TP degree > 1
+        if args.tp_degree > 1:
+            tensor_parallelize(model, tp_pg)
+
+        # assign running configurations
+        gemini_config = None
+        if args.distplan.startswith("CAI_ZeRO"):
+            optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+        elif args.distplan == "CAI_Gemini":
+            gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
+                                 device=get_current_device(),
+                                 placement_policy=args.placement,
+                                 pin_memory=True,
+                                 hidden_dim=model.model.config.hidden_size,
+                                 search_range_mb=128)
+            optim_config = dict(gpu_margin_mem_ratio=0.)
+        else:
+            raise RuntimeError
+
+        # build a highly optimized gpu/cpu optimizer
+        optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+        if args.distplan == "CAI_ZeRO1":
+            zero_stage = 1
+        elif args.distplan == "CAI_ZeRO2":
+            zero_stage = 2
+        elif args.distplan == "CAI_Gemini":
+            zero_stage = 3
+        else:
+            raise RuntimeError
+
+        # wrap your model and optimizer
+        model = zero_model_wrapper(model, zero_stage, gemini_config)
+        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+
+        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+    elif args.distplan.startswith("Pytorch"):
+        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
+        model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
+        model = DDP(model)
+        if args.distplan.endswith("DDP"):
+            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        elif args.distplan.endswith("ZeRO"):
+            from torch.distributed.optim import ZeroRedundancyOptimizer
+            optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+        else:
+            raise RuntimeError
+
+    # model is sharded after TP
+    numel = get_model_size(model)
+    logger.info(f"the size of the testing model is {model_size_formatter(numel)}.")
+    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+
+    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
+    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
+    # = batch_per_DP_group * numel * seq_len * 8
+    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+
+    torch.cuda.synchronize()
+    model.train()
+    tflops_list = []
+
+    def train_step():
+        # we just use randomly generated data here
+        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
+        optimizer.zero_grad()
+
+        start = time()
+        outputs = model(input_ids, attn_mask)[0]
+        loss = criterion(outputs, input_ids)
+        torch.cuda.synchronize()
+        fwd_end = time()
+        fwd_time = fwd_end - start
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
+
+        if args.distplan.startswith("CAI"):
+            optimizer.backward(loss)
+        elif args.distplan.startswith("Pytorch"):
+            loss.backward()
+        else:
+            raise RuntimeError
+
+        torch.cuda.synchronize()
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
+
+        optimizer.step()
+        torch.cuda.synchronize()
+        optim_time = time() - bwd_end
+        step_time = time() - start
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if n >= WARMUP_STEPS:
+            tflops_list.append(step_tflops)
+
+    demo_profiler = get_profile_context(PROF_FLAG,
+                                        WARMUP_STEPS,
+                                        NUM_STEPS - WARMUP_STEPS,
+                                        save_dir=f"profile/{get_time_stamp()}-demo")
+
+    with demo_profiler as prof:
+        start_time = time()
+        for n in range(NUM_STEPS):
+            train_step()
+            prof.step()
+        end_time = time()
+        print('total time: {}'.format(end_time - start_time))
+
+    tflops_list.sort()
+    median_index = (NUM_STEPS - WARMUP_STEPS) >> 1
+    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+    torch.cuda.synchronize()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/speed_test/colossal-ai/run.sh b/speed_test/colossal-ai/run.sh
new file mode 100644
index 0000000..593a287
--- /dev/null
+++ b/speed_test/colossal-ai/run.sh
@@ -0,0 +1,31 @@
+set -x
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
+
+# The following options are only valid for the CAI_* dist plans
+export GPUNUM=${GPUNUM:-8}
+export TPDEGREE=${TPDEGREE:-1}
+export PLACEMENT=${PLACEMENT:-"auto"}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-True}
+export BATCH_SIZE=${BATCH_SIZE:-40}
+export MODEL_TYPE=${MODEL_TYPE:-"Llama-7B"}
+export TRAIN_STEP=${TRAIN_STEP:-10}
+# export PYTHONPATH=$PWD:$PYTHONPATH
+
+if [ ${USE_SHARD_INIT} = "True" ]; then
USE_SHARD_INIT="--shardinit" +else + USE_SHARD_INIT="" +fi + +mkdir -p gemini_logs + +torchrun --nproc_per_node=${GPUNUM} --rdzv_endpoint=127.0.0.1:23335 run.py \ +--tp_degree=${TPDEGREE} \ +--model_type=${MODEL_TYPE} \ +--batch_size=${BATCH_SIZE} \ +--placement=${PLACEMENT} \ +${USE_SHARD_INIT} \ +--distplan=${DISTPLAN} \ +--train_step=${TRAIN_STEP} \ +2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log diff --git a/speed_test/colossal-ai/utils.py b/speed_test/colossal-ai/utils.py new file mode 100644 index 0000000..1749a21 --- /dev/null +++ b/speed_test/colossal-ai/utils.py @@ -0,0 +1,41 @@ +import time +from contextlib import nullcontext + +import torch +from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler + + +class DummyProfiler: + + def __init__(self): + self.step_number = 0 + + def step(self): + self.step_number += 1 + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) + + +def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir): + if enable_flag: + return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps), + on_trace_ready=tensorboard_trace_handler(save_dir), + record_shapes=True, + profile_memory=True) + else: + return nullcontext(DummyProfiler()) + + +def get_time_stamp(): + cur_time = time.strftime("%d-%H:%M", time.localtime()) + return cur_time \ No newline at end of file diff --git a/speed_test/lightning/run.py b/speed_test/lightning/run.py new file mode 100644 index 0000000..b8ed1fa --- /dev/null +++ b/speed_test/lightning/run.py @@ -0,0 +1,80 @@ +""" +Author: LiangSong(sl12160010@gmail.com) +Date: 2023-04-11 20:07:35 +LastEditors: LiangSong(sl12160010@gmail.com) +LastEditTime: 2023-04-11 21:56:07 +FilePath: /Open-Llama/speed_test/lightning/run.py +Description: + +Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
+""" +import time +import torch +import lightning.pytorch as pl +from deepspeed.ops.adam import FusedAdam +from transformers import LlamaForCausalLM, LlamaConfig +from lightning.pytorch.strategies import DeepSpeedStrategy + + +batch_size = 2 +seq_length = 2048 +vocab_size = 32000 +total_step = 100 +use_activation_ckpt = False + +class FakeSet(torch.utils.data.Dataset): + def __getitem__(self, idx): + return torch.randint(0, vocab_size, (seq_length, )) + + def __len__(self): + return 1000000000 + +class SpeedTest(pl.LightningModule): + def __init__(self): + super().__init__() + self.model = LlamaForCausalLM( + LlamaConfig( + vocab_size=vocab_size, + ) + ) + if use_activation_ckpt: + self.model.gradient_checkpointing_enable() + self.start_time = None + + def training_step(self, batch, batch_idx): + out = self.model(batch, labels=batch) + loss = out.loss + if self.start_time is None: + print('start') + self.start_time = time.time() + return loss + + def configure_optimizers(self): + optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5) + return optimizer + +model = SpeedTest() +train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size) + +strategy=DeepSpeedStrategy( + stage=2, + offload_optimizer=False, + offload_parameters=False, + process_group_backend="nccl" +) +trainer = pl.Trainer( + limit_train_batches=total_step, + max_epochs=1, + devices=8, + accelerator="gpu", + strategy=strategy, + precision=16, + enable_checkpointing=False) + +def train(model, train_loader): + start_time = time.time() + trainer.fit(model=model, train_dataloaders=train_loader) + end_time = time.time() + return end_time - model.start_time + +print('total time: {}'.format(train(model, train_loader)))