"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-11 20:07:35
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-11 21:56:23
FilePath: /Open-Llama/speed_test/colossal-ai/run.py
Description: Throughput (TFLOPS) benchmark of LlamaForCausalLM under the
ColossalAI (ZeRO-1/2, Gemini) and PyTorch (DDP, ZeroRedundancyOptimizer)
distributed plans.

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
from functools import partial
from time import time

import psutil
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaConfig
from utils import get_data, get_profile_context, get_tflops, get_time_stamp

from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import (
    ColoParameter,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

CAI_VERSION = colossalai.__version__

# Distributed plans accepted by ``--distplan``. The ``CAI_`` prefix selects the
# ColossalAI code path in main(), the ``Pytorch_`` prefix the native one.
DIST_PLANS = (
    "CAI_ZeRO1",
    "CAI_ZeRO2",
    "CAI_Gemini",
    "Pytorch_DDP",
    "Pytorch_ZeRO",
)


def parse_args():
    """Parse the benchmark's CLI arguments.

    Extends ``colossalai.get_default_parser()`` with the benchmark options.

    Returns:
        argparse.Namespace with ``distplan``, ``tp_degree``, ``placement``,
        ``shardinit``, ``batch_size``, ``model_type`` and ``train_step`` in
        addition to whatever the ColossalAI default parser defines.
    """
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--distplan",
        type=str,
        default="CAI_Gemini",
        help=f"The distributed plan [{', '.join(DIST_PLANS)}].",
    )
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree. Valid when using a CAI_* dist plan.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using CAI_Gemini as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action="store_true",
        help=(
            "Shard the tensors when init the model to shrink peak memory size "
            "on the assigned device. Valid when using CAI_Gemini as dist plan."
        ),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group of training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="Llama-7B",
        help="model scale",
    )
    parser.add_argument(
        "--train_step",
        type=int,
        default=10,
        help=(
            "training iterations for test, including 1 warmup step; "
            "train_step - 1 must be odd so the median is well defined"
        ),
    )
    args = parser.parse_args()
    return args


def model_builder(VOCAB_SIZE, checkpoint=False):
    """Build a LlamaForCausalLM with default LlamaConfig and the given vocab size.

    Args:
        VOCAB_SIZE: vocabulary size passed to LlamaConfig.
        checkpoint: if True, enable gradient checkpointing on the model.

    NOTE(review): only ``vocab_size`` is overridden here; ``--model_type`` is
    logged by main() but does not influence the configuration.
    """
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=VOCAB_SIZE,
        )
    )
    if checkpoint:
        raw_model.gradient_checkpointing_enable()
    return raw_model


# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along ``dim`` into ``pg.tp_world_size()`` pieces (TP1D)."""
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along its first dimension."""
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    """Shard ``param`` along its last dimension."""
    split_param_single_dim_tp1d(-1, param, pg)


class GPTLMLoss(nn.Module):
    """Causal LM loss: cross-entropy of logits at position t against token t+1."""

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        # Drop the last logit and the first label so position t predicts t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )


def get_cpu_mem():
    """Return this process's resident set size in MiB."""
    return psutil.Process().memory_info().rss / 1024**2


def get_gpu_mem():
    """Return ``torch.cuda.memory_allocated()`` in MiB."""
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=""):
    """Format current GPU/CPU memory usage (MiB) behind ``prefix``."""
    return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"


def get_model_size(model: nn.Module):
    """Count parameter elements module by module (non-recursive per module).

    A parameter registered directly on several modules is counted once per
    owning module.
    """
    total_numel = 0
    for module in model.modules():
        for p in module.parameters(recurse=False):
            total_numel += p.numel()
    return total_numel


def model_size_formatter(numel: int) -> str:
    """Format an element count with decimal K/M/B suffixes (B = billions)."""
    GB_SIZE = 10**9
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f"{numel / GB_SIZE:.1f}B"
    elif numel >= MB_SIZE:
        return f"{numel / MB_SIZE:.1f}M"
    elif numel >= KB_SIZE:
        return f"{numel / KB_SIZE:.1f}K"
    else:
        return str(numel)


def set_cpu_maximum_parallelism():
    """Set OMP_NUM_THREADS to torch's reported ``hardware_concurrency()``."""
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split("\n")[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")


# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """Shard the model parameters for 1D tensor parallelism.

    Every parameter is first reset to a replica in ``pg``, then sharded by
    matching substrings of the owning module's name.

    NOTE(review): the matched substrings (``mlp.c_fc``, ``mlp.c_proj``,
    ``c_attn``, ``c_proj``, ``wte``, ``wpe``) follow GPT-2 module naming. HF
    Llama modules appear to be named differently (e.g. ``q_proj``,
    ``gate_proj``, ``embed_tokens``). With the LlamaForCausalLM built by
    model_builder these branches may never match, leaving every parameter
    replicated. Confirm before relying on TP > 1 results.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): process group whose TP size determines the shards
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # NOTE: a param may be shared by two modules; the ``visited`` flag
            # makes sure it is processed only once.
            if hasattr(param, "visited"):
                continue

            # if shard init, then convert param to replica and use the dp-only ProcessGroup
            param: ColoParameter = param
            param.set_dist_spec(ReplicaSpec())
            param.set_process_group(pg)

            # shard it w.r.t tp pattern
            if "mlp.c_fc" in mn:
                if "weight" in pn or "bias" in pn:
                    split_param_col_tp1d(param, pg)  # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif "mlp.c_proj" in mn:
                if "weight" in pn:
                    split_param_row_tp1d(param, pg)  # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif "wte" in mn or "wpe" in mn:
                split_param_col_tp1d(param, pg)  # column slice
            elif "c_attn" in mn or "c_proj" in mn:
                split_param_col_tp1d(param, pg)  # column slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True


def _build_colossalai_plan(args, vocab_size, logger):
    """Build the model and optimizer for a ``CAI_*`` distributed plan.

    Returns:
        (model, optimizer) wrapped by zero_model_wrapper / zero_optim_wrapper.

    Raises:
        RuntimeError: if ``--shardinit`` is combined with a plan other than
            CAI_Gemini, or ``args.distplan`` is not a known CAI plan.
    """
    # all param must use the same process group.
    world_size = torch.distributed.get_world_size()
    shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
    default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

    if args.shardinit and args.distplan != "CAI_Gemini":
        raise RuntimeError("You can only use shardinit with CAI_Gemini")

    # build Llama model
    with ColoInitContext(
        device=get_current_device(),
        dtype=torch.half,
        default_dist_spec=default_dist_spec,
        default_pg=shard_pg,
    ):
        model = model_builder(vocab_size, checkpoint=True)

    tp_pg = ProcessGroup(tp_degree=args.tp_degree)
    # Tensor Parallelism (TP)
    # You should notice that v0.1.10 is not compatible with TP degree > 1
    if args.tp_degree > 1:
        tensor_parallelize(model, tp_pg)

    # assign running configurations
    gemini_config = None
    if args.distplan.startswith("CAI_ZeRO"):
        optim_config = dict(
            reduce_bucket_size=12 * 1024 * 1024,
            overlap_communication=True,
            verbose=True,
        )
    elif args.distplan == "CAI_Gemini":
        gemini_config = dict(
            strict_ddp_mode=args.tp_degree == 1,
            device=get_current_device(),
            placement_policy=args.placement,
            pin_memory=True,
            hidden_dim=model.model.config.hidden_size,
            search_range_mb=128,
        )
        optim_config = dict(gpu_margin_mem_ratio=0.0)
    else:
        raise RuntimeError

    # build a highly optimized gpu/cpu optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)

    zero_stage = {"CAI_ZeRO1": 1, "CAI_ZeRO2": 2, "CAI_Gemini": 3}.get(args.distplan)
    if zero_stage is None:
        raise RuntimeError

    # wrap your model and optimizer
    model = zero_model_wrapper(model, zero_stage, gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)

    logger.info(get_mem_info(prefix="After init optim, "), ranks=[0])
    return model, optimizer


def _build_pytorch_plan(args, vocab_size):
    """Build a DDP-wrapped model and an Adam / ZeroRedundancyOptimizer optimizer.

    Raises:
        ValueError: if ``--tp_degree`` is not 1.
        RuntimeError: if ``args.distplan`` is not Pytorch_DDP / Pytorch_ZeRO.
    """
    if args.tp_degree != 1:
        raise ValueError("The degree of TP should be 1 for DDP examples.")
    model = model_builder(vocab_size, checkpoint=True).cuda()
    model = DDP(model)

    if args.distplan.endswith("DDP"):
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    elif args.distplan.endswith("ZeRO"):
        from torch.distributed.optim import ZeroRedundancyOptimizer

        optimizer = ZeroRedundancyOptimizer(
            model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
        )
    else:
        raise RuntimeError
    return model, optimizer


def main():
    """Run the throughput benchmark for the plan selected by ``--distplan``.

    Steps:
        1. Validate the arguments.
        2. Launch the distributed environment via ``colossalai.launch_from_torch``.
        3. Build the model and optimizer for the chosen plan.
        4. Run ``--train_step`` iterations on data from ``get_data``.
        5. Log per-step timings and the median TFLOPS of the post-warmup steps.

    Raises:
        RuntimeError: if ColossalAI is older than 0.2.0.
        TypeError: if ``--distplan`` is not one of DIST_PLANS.
        ValueError: if ``--train_step`` does not leave an odd, positive number
            of measured (post-warmup) steps.
    """
    # this example is supposed to work for versions greater than 0.2.0
    if version.parse(CAI_VERSION) < version.parse("0.2.0"):
        raise RuntimeError(f"ColossalAI >= 0.2.0 is required, found {CAI_VERSION}")

    set_cpu_maximum_parallelism()
    args = parse_args()

    if args.distplan not in DIST_PLANS:
        raise TypeError(f"{args.distplan} is error")

    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 2048
    VOCAB_SIZE = 32000

    NUM_STEPS = args.train_step
    WARMUP_STEPS = 1
    # Explicit raises instead of ``assert``: asserts vanish under ``python -O``.
    if WARMUP_STEPS >= NUM_STEPS:
        raise ValueError("warmup steps should smaller than the total steps")
    if (NUM_STEPS - WARMUP_STEPS) % 2 != 1:
        raise ValueError("the number of valid steps should be odd to take the median")
    PROF_FLAG = False  # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(
        f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]
    )

    # build criterion
    criterion = GPTLMLoss()

    torch.manual_seed(123)
    if args.distplan.startswith("CAI"):
        model, optimizer = _build_colossalai_plan(args, VOCAB_SIZE, logger)
    elif args.distplan.startswith("Pytorch"):
        model, optimizer = _build_pytorch_plan(args, VOCAB_SIZE)
    else:
        raise RuntimeError

    # model is sharded after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix="After init model, "), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    # = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
    model.train()
    tflops_list = []

    def train_step(n):
        """Run and time one forward/backward/optimizer step (``n`` is 0-based)."""
        # we just use randomly generated data here
        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
        optimizer.zero_grad()

        start = time()
        outputs = model(input_ids, attn_mask)[0]
        loss = criterion(outputs, input_ids)
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0])

        if args.distplan.startswith("CAI"):
            # The ColossalAI-wrapped optimizer provides backward(); it is
            # used instead of loss.backward() on this path.
            optimizer.backward(loss)
        elif args.distplan.startswith("Pytorch"):
            loss.backward()
        else:
            raise RuntimeError

        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(
            get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]
        )

        step_tflops = get_tflops_func(step_time)
        logger.info(
            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, "
            f"TFLOPS: {step_tflops:.3f}, FWD time: {fwd_time:.3f}s, "
            f"BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
            ranks=[0],
        )
        # Warmup steps are excluded from the TFLOPS statistics.
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(
        PROF_FLAG,
        WARMUP_STEPS,
        NUM_STEPS - WARMUP_STEPS,
        save_dir=f"profile/{get_time_stamp()}-demo",
    )

    with demo_profiler as prof:
        start_time = time()
        for n in range(NUM_STEPS):
            train_step(n)
            prof.step()
        end_time = time()
    print("total time: {}".format(end_time - start_time))

    # tflops_list holds only the NUM_STEPS - WARMUP_STEPS measured steps (an odd
    # count, validated above), so its middle element is the median. Indexing
    # with an extra WARMUP_STEPS offset would skip past it (or out of range).
    tflops_list.sort()
    median_index = len(tflops_list) // 2
    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
    torch.cuda.synchronize()


if __name__ == "__main__":
    main()