update format

This commit is contained in:
LiangSong 2023-04-12 22:16:15 +08:00
parent a4aa109dd3
commit 3f62a23ee2
8 changed files with 281 additions and 160 deletions

View File

@ -1,3 +1,13 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-12 21:01:32
FilePath: /Open-Llama/pretrain.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import yaml
import torch
import random
@ -20,55 +30,72 @@ from solver.trainer import Trainer
FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")
class FakeSet(torch.utils.data.Dataset):
    """Synthetic dataset that yields random token ids.

    It stands in for the real data pipeline so the training loop can be
    exercised without any data on disk.

    Args:
        vocab_size: Exclusive upper bound of the sampled token ids.
        seq_length: Number of token ids in every sample.
        length: Value reported by ``len()``; samples are generated on demand,
            so this does not allocate anything.
    """

    def __init__(self, vocab_size=32000, seq_length=2048, length=1000000000):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.length = length

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a fresh random sequence.
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.seq_length,))
        }

    def __len__(self):
        return self.length
def main(argv):
    """Build tokenizer, data loader and model from the config, then train.

    The config path is read from ``FLAGS.config``. The loaded YAML supplies the
    tokenizer model path, the batch size, the model hyper-parameters and an
    optional checkpoint to restore before training starts.

    Args:
        argv: Arguments passed in by ``app.run``; not used here.
    """
    accelerator = Accelerator()

    # NOTE(review): ``yaml.FullLoader`` is not the restricted loader; consider
    # ``yaml.safe_load`` if the config file can come from an untrusted source.
    with open(FLAGS.config, "r", encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
    sp_model = spm.SentencePieceProcessor(
        model_file=config["data"]["tokenizer_model_path"]
    )
    tokenizer = Tokenizer(sp_model)

    # The real sharded data pipeline is disabled in this revision; training
    # reads random tokens from ``FakeSet`` instead.
    # paths = create_shard_kwargs(config['data']['patterns'])
    # random.shuffle(paths)
    # transform_dict = {
    #     "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
    #     "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
    # }
    # data_set = DataIter(
    #     paths,
    #     transform_dict=transform_dict,
    #     concat_docs=True,
    #     max_length=config['model']['max_length'],
    #     process_index=accelerator.process_index,
    #     num_processes=accelerator.num_processes,
    # )
    # train_loader = DataLoader(
    #     data_set,
    #     batch_size=config['train']['train_batch_size'],
    #     # If num_workers is greater than 1, duplicate data may occur.
    #     num_workers=0,
    #     collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
    #     drop_last=True,
    # )
    train_loader = torch.utils.data.DataLoader(
        FakeSet(), batch_size=config["train"]["train_batch_size"]
    )

    # A smaller initializer_range makes training more stable.
    # Add stable embedding to the token embedding.
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config["model"]["initializer_range"],
            pad_token_id=tokenizer.pad_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config["model"]["hidden_dropout_prob"],
            attention_dropout_prob=config["model"]["attention_dropout_prob"],
            use_stable_embedding=config["model"]["use_stable_embedding"],
            shared_input_output_embedding=config["model"][
                "shared_input_output_embedding"
            ],
        )
    )

    # Optionally resume from a saved state dict.
    if config["train"]["ckpt"] is not None:
        ckpt = torch.load(config["train"]["ckpt"])
        raw_model.load_state_dict(ckpt)

    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()
# Script entry point: ``app.run`` parses the command-line flags and calls ``main``.
if __name__ == "__main__":
    app.run(main)

View File

@ -8,6 +8,7 @@ from transformers import get_cosine_schedule_with_warmup
from dataset.validation import val_set
class Trainer:
def __init__(self, config, raw_model, train_loader, tokenizer, accelerator):
self.config = config
@ -15,18 +16,29 @@ class Trainer:
self.train_loader = train_loader
self.tokenizer = tokenizer
self.accelerator = accelerator
self.lr_scheduler_factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
self.log_interval = self.config['log_interval'] * accelerator.gradient_accumulation_steps
self.eval_interval = self.config['eval_interval'] * accelerator.gradient_accumulation_steps
self.save_interval = self.config['save_interval'] * accelerator.gradient_accumulation_steps
self.work_dir = self.config['work_dir']
self.lr_scheduler_factor = (
accelerator.num_processes / accelerator.gradient_accumulation_steps
)
self.log_interval = (
self.config["log_interval"] * accelerator.gradient_accumulation_steps
)
self.eval_interval = (
self.config["eval_interval"] * accelerator.gradient_accumulation_steps
)
self.save_interval = (
self.config["save_interval"] * accelerator.gradient_accumulation_steps
)
self.work_dir = self.config["work_dir"]
self.get_model_info()
if accelerator.is_main_process:
wandb.init(project=self.config['project_name'])
wandb.init(project=self.config["project_name"])
def get_model_info(self):
    """Run ``summary`` on the raw model with a dummy input.

    The model is moved to CUDA and fed a ``(1, 64)`` int64 tensor of ones.
    Gradients are disabled because the pass is only used for reporting.
    """
    with torch.no_grad():
        summary(
            self.raw_model.cuda(),
            input_data=torch.ones(1, 64, dtype=torch.int64).cuda(),
        )
def get_optimizer(self):
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
@ -37,7 +49,7 @@ class Trainer:
for n, p in self.raw_model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": self.config['train']['weight_decay'],
"weight_decay": self.config["train"]["weight_decay"],
},
{
"params": [
@ -48,15 +60,21 @@ class Trainer:
"weight_decay": 0.0,
},
]
self.optim = FusedAdam(optimizer_grouped_parameters, lr=self.config['train']['lr'], betas=(0.9, 0.95))
self.optim = FusedAdam(
optimizer_grouped_parameters,
lr=self.config["train"]["lr"],
betas=(0.9, 0.95),
)
def get_lr_scheduler(self):
    """Create the cosine learning-rate schedule with warmup on ``self.optim``.

    Both step counts come from ``config["train"]`` and are multiplied by
    ``self.lr_scheduler_factor`` before being passed to the scheduler.
    The result is stored in ``self.scheduler``.
    """
    train_cfg = self.config["train"]
    self.scheduler = get_cosine_schedule_with_warmup(
        self.optim,
        num_warmup_steps=train_cfg["num_warmup_steps"] * self.lr_scheduler_factor,
        num_training_steps=train_cfg["num_training_steps"]
        * self.lr_scheduler_factor,
    )
def prepare(self):
_, self.model, self.optim, self.scheduler = self.accelerator.prepare(
self.train_loader, self.raw_model, self.optim, self.scheduler
@ -74,7 +92,7 @@ class Trainer:
self.scheduler.step()
self.optim.zero_grad()
return losses
def train(self):
self.get_optimizer()
self.get_lr_scheduler()
@ -82,28 +100,45 @@ class Trainer:
self.global_step = 0
self.start_time = time.time()
self.optim.zero_grad()
for self.data_step in range(self.config['train']['num_training_steps']):
for self.data_step in range(self.config["train"]["num_training_steps"]):
self.model.train()
with self.accelerator.accumulate(self.model):
batch = next(self.train_loader_iter)
losses = self.train_step(batch)
if self.accelerator.sync_gradients:
self.global_step += 1
if self.data_step % self.log_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
if (
self.data_step % self.log_interval == 0
and self.data_step > 0
and self.accelerator.is_main_process
):
self.log(losses)
if self.data_step % self.eval_interval == 0 and self.accelerator.is_main_process:
if (
self.data_step % self.eval_interval == 0
and self.accelerator.is_main_process
):
self.eval()
if self.data_step % self.save_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
if (
self.data_step % self.save_interval == 0
and self.data_step > 0
and self.accelerator.is_main_process
):
if not os.path.isdir(self.work_dir):
os.mkdir(self.work_dir)
torch.save(self.raw_model.state_dict(), "{}/{}.pt".format(self.work_dir, self.global_step))
torch.save(
self.raw_model.state_dict(),
"{}/{}.pt".format(self.work_dir, self.global_step),
)
wandb.finish()
def log(self, losses):
cost_time = time.time() - self.start_time
self.start_time = time.time()
tokens = self.config['train']['train_batch_size'] * \
self.log_interval * self.config['model']['max_length']
tokens = (
self.config["train"]["train_batch_size"]
* self.log_interval
* self.config["model"]["max_length"]
)
wandb.log({"Training/Token per second per gpu": tokens / cost_time})
for k, v in losses.items():
wandb.log({"Losses/{}".format(k): v})
@ -115,7 +150,10 @@ class Trainer:
wandb.log({"Training/Global Step": self.global_step})
self.accelerator.print(
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
self.global_step, self.data_step, losses["total_loss"], tokens / cost_time
self.global_step,
self.data_step,
losses["total_loss"],
tokens / cost_time,
)
)
@ -137,4 +175,4 @@ class Trainer:
pred = self.tokenizer.decode(pred.cpu())[0]
pred = pred[inputs_len:]
text_table.add_data(raw_inputs, pred)
wandb.log({"Predictions on {}".format(self.global_step): text_table})
wandb.log({"Predictions on {}".format(self.global_step): text_table})

View File

@ -6,16 +6,20 @@ sp_model = spm.SentencePieceProcessor(
model_file="configs/llama_tokenizer_extended.model"
)
merged_vocab_size = sp_model.vocab_size()
ckpt = torch.load('data/llama_raw_ckpt/7B/consolidated.00.pth')
ckpt = torch.load("data/llama_raw_ckpt/7B/consolidated.00.pth")
raw_vocab_size, hidden_size = ckpt['tok_embeddings.weight'].shape
raw_vocab_size, hidden_size = ckpt["tok_embeddings.weight"].shape
extended_tok_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_tok_embeddings = extended_tok_embeddings * 0.001
ckpt['tok_embeddings.weight'] = torch.cat([ckpt['tok_embeddings.weight'], extended_tok_embeddings], dim=0)
ckpt["tok_embeddings.weight"] = torch.cat(
[ckpt["tok_embeddings.weight"], extended_tok_embeddings], dim=0
)
extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_out_embeddings = extended_out_embeddings * 0.001
ckpt['output.weight'] = torch.cat([ckpt['output.weight'], extended_out_embeddings], dim=0)
ckpt["output.weight"] = torch.cat(
[ckpt["output.weight"], extended_out_embeddings], dim=0
)
rename_map = {
"tok_embeddings.weight": "model.embed_tokens.weight",
@ -26,31 +30,31 @@ rename_map = {
for f, t in rename_map.items():
v = ckpt.pop(f)
ckpt[t] = v
from_names = [
"layers.{}.attention.wq.weight",
"layers.{}.attention.wk.weight",
"layers.{}.attention.wv.weight",
"layers.{}.attention.wo.weight",
"layers.{}.feed_forward.w1.weight",
"layers.{}.feed_forward.w2.weight",
"layers.{}.feed_forward.w3.weight",
"layers.{}.attention_norm.weight",
"layers.{}.ffn_norm.weight",
"layers.{}.attention.inner_attention.rope.freqs"
"layers.{}.attention.wq.weight",
"layers.{}.attention.wk.weight",
"layers.{}.attention.wv.weight",
"layers.{}.attention.wo.weight",
"layers.{}.feed_forward.w1.weight",
"layers.{}.feed_forward.w2.weight",
"layers.{}.feed_forward.w3.weight",
"layers.{}.attention_norm.weight",
"layers.{}.ffn_norm.weight",
"layers.{}.attention.inner_attention.rope.freqs",
]
to_names = [
"model.layers.{}.self_attn.q_proj.weight",
"model.layers.{}.self_attn.k_proj.weight",
"model.layers.{}.self_attn.v_proj.weight",
"model.layers.{}.self_attn.o_proj.weight",
"model.layers.{}.mlp.gate_proj.weight",
"model.layers.{}.mlp.down_proj.weight",
"model.layers.{}.mlp.up_proj.weight",
"model.layers.{}.input_layernorm.weight",
"model.layers.{}.self_attn.q_proj.weight",
"model.layers.{}.self_attn.k_proj.weight",
"model.layers.{}.self_attn.v_proj.weight",
"model.layers.{}.self_attn.o_proj.weight",
"model.layers.{}.mlp.gate_proj.weight",
"model.layers.{}.mlp.down_proj.weight",
"model.layers.{}.mlp.up_proj.weight",
"model.layers.{}.input_layernorm.weight",
"model.layers.{}.post_attention_layernorm.weight",
"model.layers.{}.self_attn.rotary_emb.inv_freq",
"model.layers.{}.self_attn.rotary_emb.inv_freq",
]
for layer in range(32):
@ -59,4 +63,4 @@ for layer in range(32):
t = t.format(layer)
v = ckpt.pop(f)
ckpt[t] = v
torch.save(ckpt, 'data/llama_raw_ckpt/7B/extended.pth')
torch.save(ckpt, "data/llama_raw_ckpt/7B/extended.pth")

View File

@ -3,21 +3,21 @@ import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as model
# Load the original LLaMA tokenizer; the file is closed as soon as it is read.
raw_model = model.ModelProto()
with open("configs/llama_tokenizer.model", "rb") as f:
    raw_model.ParseFromString(f.read())

# Pieces already present in the original vocabulary.
exist_pieces = {p.piece for p in raw_model.pieces}

# Load the Chinese vocabulary that is merged into the original one.
cn_model = model.ModelProto()
with open("configs/4w_cn_vocab_wudao15.model", "rb") as f:
    cn_model.ParseFromString(f.read())

# Append only the pieces the original tokenizer does not know yet.
for p in tqdm(cn_model.pieces, total=len(cn_model.pieces)):
    if p.piece not in exist_pieces:
        raw_model.pieces.append(p)

with open("configs/llama_tokenizer_extended.model", "wb") as f:
    f.write(raw_model.SerializeToString())

# Reload the merged model to report the resulting vocabulary size.
sp_model = spm.SentencePieceProcessor(
    model_file="configs/llama_tokenizer_extended.model"
)

print("merged vocab size: {}".format(sp_model.vocab_size()))

View File

@ -20,13 +20,15 @@ vocab_size = 32000
total_step = 2
use_activation_ckpt = True
class FakeSet(torch.utils.data.Dataset):
    """Synthetic dataset of random token ids used for the speed test.

    Sample size and token range come from the module-level ``seq_length`` and
    ``vocab_size`` settings.
    """

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a fresh random sequence.
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        # Large fixed length; the caller decides how many batches to consume.
        return 1000000000
accelerator = Accelerator()
raw_model = LlamaForCausalLM(
LlamaConfig(
@ -39,15 +41,18 @@ optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
if accelerator.distributed_type == DistributedType.FSDP:
accelerator.print('FSDP')
accelerator.print("FSDP")
model = accelerator.prepare(raw_model)
optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
else:
model, optimizer, train_loader = accelerator.prepare(raw_model, optimizer, train_loader)
model, optimizer, train_loader = accelerator.prepare(
raw_model, optimizer, train_loader
)
def train(model, optimizer, train_loader):
start_time = time.time()
for i, batch in enumerate(train_loader):
for i, batch in enumerate(train_loader):
if i == total_step:
break
optimizer.zero_grad()
@ -58,4 +63,5 @@ def train(model, optimizer, train_loader):
end_time = time.time()
return end_time - start_time
accelerator.print('total time: {}'.format(train(model, optimizer, train_loader)))
accelerator.print("total time: {}".format(train(model, optimizer, train_loader)))

View File

@ -23,7 +23,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.tensor import (
ColoParameter,
ComputePattern,
ComputeSpec,
ProcessGroup,
ReplicaSpec,
ShardSpec,
)
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
@ -35,7 +42,7 @@ def parse_args():
parser.add_argument(
"--distplan",
type=str,
default='CAI_Gemini',
default="CAI_Gemini",
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
@ -47,14 +54,13 @@ def parse_args():
parser.add_argument(
"--placement",
type=str,
default='cpu',
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
action="store_true",
help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--batch_size",
@ -105,7 +111,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
@ -114,7 +119,9 @@ class GPTLMLoss(nn.Module):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return self.loss_fn(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
def get_cpu_mem():
@ -125,8 +132,8 @@ def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=""):
    """Return a one-line report of current GPU and CPU memory usage.

    Args:
        prefix: Text placed in front of the report.

    Returns:
        A string with the values from ``get_gpu_mem()`` and ``get_cpu_mem()``
        formatted to two decimals and labelled as MB.
    """
    return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"
def get_model_size(model: nn.Module):
@ -142,11 +149,11 @@ def model_size_formatter(numel: int) -> str:
MB_SIZE = 10**6
KB_SIZE = 10**3
if numel >= GB_SIZE:
return f'{numel / GB_SIZE:.1f}B'
return f"{numel / GB_SIZE:.1f}B"
elif numel >= MB_SIZE:
return f'{numel / MB_SIZE:.1f}M'
return f"{numel / MB_SIZE:.1f}M"
elif numel >= KB_SIZE:
return f'{numel / KB_SIZE:.1f}K'
return f"{numel / KB_SIZE:.1f}K"
else:
return str(numel)
@ -154,7 +161,7 @@ def model_size_formatter(numel: int) -> str:
def set_cpu_maximum_parallelism():
    """Set ``OMP_NUM_THREADS`` to the hardware concurrency PyTorch reports.

    The value is parsed from ``torch.__config__.parallel_info()``. If that
    report lacks the ``hardware_concurrency()`` entry, ``os.cpu_count()`` is
    used instead of failing with an ``IndexError``.
    """
    marker = "hardware_concurrency() : "
    conf_str = torch.__config__.parallel_info()
    _, found, inter_str = conf_str.partition(marker)
    if found:
        # The value is the rest of the line that follows the marker.
        max_concurrency = inter_str.split("\n")[0]
    else:
        max_concurrency = str(os.cpu_count() or 1)
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
@ -170,7 +177,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# NOTE() a param may be shared by two modules
if hasattr(param, 'visited'):
if hasattr(param, "visited"):
continue
# if shard init, then convert param to replica and use the dp-only ProcessGroup
@ -179,22 +186,22 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
param.set_process_group(pg)
# shard it w.r.t tp pattern
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # colmn slice
if "mlp.c_fc" in mn:
if "weight" in pn or "bias" in pn:
split_param_col_tp1d(param, pg) # colmn slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
elif "mlp.c_proj" in mn:
if "weight" in pn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # colmn slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # colmn slice
elif "wte" in mn or "wpe" in mn:
split_param_col_tp1d(param, pg) # colmn slice
elif "c_attn" in mn or "c_proj" in mn:
split_param_col_tp1d(param, pg) # colmn slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
@ -209,7 +216,13 @@ def main():
args = parse_args()
# if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
if args.distplan not in [
"CAI_ZeRO1",
"CAI_ZeRO2",
"CAI_Gemini",
"Pytorch_DDP",
"Pytorch_ZeRO",
]:
raise TypeError(f"{args.distplan} is error")
# batch size per DP degree
@ -221,14 +234,18 @@ def main():
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
PROF_FLAG = False # The flag of profiling, False by default
assert (
NUM_STEPS - WARMUP_STEPS
) % 2 == 1, "the number of valid steps should be odd to take the median"
PROF_FLAG = False # The flag of profiling, False by default
disable_existing_loggers()
colossalai.launch_from_torch(config={})
logger = get_dist_logger()
logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
logger.info(
f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]
)
# build criterion
criterion = GPTLMLoss()
@ -244,10 +261,12 @@ def main():
raise RuntimeError("You can only use shardinit with CAI_Gemini")
# build GPT model
with ColoInitContext(device=get_current_device(),
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
with ColoInitContext(
device=get_current_device(),
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg,
):
model = model_builder(VOCAB_SIZE, checkpoint=True)
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
@ -259,15 +278,21 @@ def main():
# assign running configurations
gemini_config = None
if args.distplan.startswith("CAI_ZeRO"):
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
optim_config = dict(
reduce_bucket_size=12 * 1024 * 1024,
overlap_communication=True,
verbose=True,
)
elif args.distplan == "CAI_Gemini":
gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
device=get_current_device(),
placement_policy=args.placement,
pin_memory=True,
hidden_dim=model.model.config.hidden_size,
search_range_mb=128)
optim_config = dict(gpu_margin_mem_ratio=0.)
gemini_config = dict(
strict_ddp_mode=args.tp_degree == 1,
device=get_current_device(),
placement_policy=args.placement,
pin_memory=True,
hidden_dim=model.model.config.hidden_size,
search_range_mb=128,
)
optim_config = dict(gpu_margin_mem_ratio=0.0)
else:
raise RuntimeError
@ -287,7 +312,7 @@ def main():
model = zero_model_wrapper(model, zero_stage, gemini_config)
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
logger.info(get_mem_info(prefix="After init optim, "), ranks=[0])
elif args.distplan.startswith("Pytorch"):
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
@ -296,14 +321,17 @@ def main():
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
elif args.distplan.endswith("ZeRO"):
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
optimizer = ZeroRedundancyOptimizer(
model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
)
else:
raise RuntimeError
# model is shared after TP
numel = get_model_size(model)
logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
logger.info(get_mem_info(prefix="After init model, "), ranks=[0])
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
# = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
@ -325,7 +353,7 @@ def main():
torch.cuda.synchronize()
fwd_end = time()
fwd_time = fwd_end - start
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0])
if args.distplan.startswith("CAI"):
optimizer.backward(loss)
@ -337,13 +365,15 @@ def main():
torch.cuda.synchronize()
bwd_end = time()
bwd_time = bwd_end - fwd_end
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0])
optimizer.step()
torch.cuda.synchronize()
optim_time = time() - bwd_end
step_time = time() - start
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
logger.info(
get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]
)
step_tflops = get_tflops_func(step_time)
logger.info(
@ -353,10 +383,12 @@ def main():
if n >= WARMUP_STEPS:
tflops_list.append(step_tflops)
demo_profiler = get_profile_context(PROF_FLAG,
WARMUP_STEPS,
NUM_STEPS - WARMUP_STEPS,
save_dir=f"profile/{get_time_stamp()}-demo")
demo_profiler = get_profile_context(
PROF_FLAG,
WARMUP_STEPS,
NUM_STEPS - WARMUP_STEPS,
save_dir=f"profile/{get_time_stamp()}-demo",
)
with demo_profiler as prof:
start_time = time()
@ -364,7 +396,7 @@ def main():
train_step()
prof.step()
end_time = time()
print('total time: {}'.format(end_time - start_time))
print("total time: {}".format(end_time - start_time))
tflops_list.sort()
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
@ -372,5 +404,5 @@ def main():
torch.cuda.synchronize()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -2,11 +2,15 @@ import time
from contextlib import nullcontext
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
from torch.profiler import (
ProfilerActivity,
profile,
schedule,
tensorboard_trace_handler,
)
class DummyProfiler:
def __init__(self):
self.step_number = 0
@ -16,7 +20,9 @@ class DummyProfiler:
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size, device=None):
    """Create a random batch of token ids and an all-ones attention mask.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of tokens per sequence.
        vocab_size: Exclusive upper bound of the sampled token ids.
        device: Device to allocate the tensors on. Defaults to the current
            CUDA device, which is what callers got before this parameter
            existed.

    Returns:
        Tuple ``(input_ids, attention_mask)``, both of shape
        ``(batch_size, seq_len)`` and on the same device.
    """
    if device is None:
        device = torch.cuda.current_device()
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
@ -27,15 +33,17 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    """Return a context manager that profiles the training loop, or a no-op.

    Args:
        enable_flag: When true, build a real ``torch.profiler.profile``.
        warmup_steps: Number of warmup steps in the profiler schedule.
        active_steps: Number of recorded steps in the profiler schedule.
        save_dir: Directory passed to ``tensorboard_trace_handler``.

    Returns:
        A profiler recording CPU and CUDA activity with shapes and memory.
        When disabled, a ``nullcontext`` whose ``as`` target is a
        ``DummyProfiler``, so callers can use one code path for both cases.
    """
    if not enable_flag:
        return nullcontext(DummyProfiler())

    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
        on_trace_ready=tensorboard_trace_handler(save_dir),
        record_shapes=True,
        profile_memory=True,
    )
def get_time_stamp():
    """Return the current local time formatted as ``DD-HH:MM``."""
    return time.strftime("%d-%H:%M", time.localtime())

View File

@ -22,13 +22,15 @@ vocab_size = 32000
total_step = 100
use_activation_ckpt = False
class FakeSet(torch.utils.data.Dataset):
    """Synthetic dataset of random token ids used for the speed test.

    Sample size and token range come from the module-level ``seq_length`` and
    ``vocab_size`` settings.
    """

    def __getitem__(self, idx):
        # ``idx`` is ignored: every access draws a fresh random sequence.
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        # Large fixed length; the trainer limits how many batches are consumed.
        return 1000000000
class SpeedTest(pl.LightningModule):
def __init__(self):
super().__init__()
@ -45,7 +47,7 @@ class SpeedTest(pl.LightningModule):
out = self.model(batch, labels=batch)
loss = out.loss
if self.start_time is None:
print('start')
print("start")
self.start_time = time.time()
return loss
@ -53,23 +55,26 @@ class SpeedTest(pl.LightningModule):
optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
return optimizer
model = SpeedTest()
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
strategy=DeepSpeedStrategy(
strategy = DeepSpeedStrategy(
stage=2,
offload_optimizer=False,
offload_parameters=False,
process_group_backend="nccl"
process_group_backend="nccl",
)
trainer = pl.Trainer(
limit_train_batches=total_step,
limit_train_batches=total_step,
max_epochs=1,
devices=8,
accelerator="gpu",
strategy=strategy,
precision=16,
enable_checkpointing=False)
precision=16,
enable_checkpointing=False,
)
def train(model, train_loader):
start_time = time.time()
@ -77,4 +82,5 @@ def train(model, train_loader):
end_time = time.time()
return end_time - model.start_time
print('total time: {}'.format(train(model, train_loader)))
print("total time: {}".format(train(model, train_loader)))