update format

LiangSong 2023-04-12 22:16:15 +08:00
parent a4aa109dd3
commit 3f62a23ee2
8 changed files with 281 additions and 160 deletions


@@ -1,3 +1,13 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-12 21:01:32
FilePath: /Open-Llama/pretrain.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import yaml
import torch
import random
@@ -20,55 +30,72 @@ from solver.trainer import Trainer
FLAGS = flags.FLAGS
flags.DEFINE_string("config", None, "Training config path")


class FakeSet(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        return {"input_ids": torch.randint(0, 32000, (2048,))}

    def __len__(self):
        return 1000000000


def main(argv):
    accelerator = Accelerator()
    with open(FLAGS.config, "r", encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)
    sp_model = spm.SentencePieceProcessor(
        model_file=config["data"]["tokenizer_model_path"]
    )
    tokenizer = Tokenizer(sp_model)
    # paths = create_shard_kwargs(config['data']['patterns'])
    # random.shuffle(paths)
    # transform_dict = {
    #     "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
    #     "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
    # }
    # data_set = DataIter(
    #     paths,
    #     transform_dict=transform_dict,
    #     concat_docs=True,
    #     max_length=config['model']['max_length'],
    #     process_index=accelerator.process_index,
    #     num_processes=accelerator.num_processes,
    # )
    # train_loader = DataLoader(
    #     data_set,
    #     batch_size=config['train']['train_batch_size'],
    #     # If num_workers is greater than 1, duplicate data may occur.
    #     num_workers=0,
    #     collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
    #     drop_last=True,
    # )
    train_loader = torch.utils.data.DataLoader(
        FakeSet(), batch_size=config["train"]["train_batch_size"]
    )
    # smaller initializer_range make training more stable
    # add stabel embedding to token embedding
    raw_model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=tokenizer.vocab_size,
            initializer_range=config["model"]["initializer_range"],
            pad_token_id=tokenizer.pad_id,
            rms_norm_eps=1e-5,
            hidden_dropout_prob=config["model"]["hidden_dropout_prob"],
            attention_dropout_prob=config["model"]["attention_dropout_prob"],
            use_stable_embedding=config["model"]["use_stable_embedding"],
            shared_input_output_embedding=config["model"][
                "shared_input_output_embedding"
            ],
        )
    )
    if config["train"]["ckpt"] is not None:
        ckpt = torch.load(config["train"]["ckpt"])
        raw_model.load_state_dict(ckpt)
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
    trainer.train()


if __name__ == "__main__":
    app.run(main)
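
Note on the change above: the real data pipeline is commented out and replaced by the synthetic FakeSet, so the trainer is fed fixed-length random token ids. A minimal sketch of what that loader yields (the batch size here is a placeholder, not the value from the config file):

import torch

class FakeSet(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        # same synthetic sample as in the diff above: 2048 random token ids
        return {"input_ids": torch.randint(0, 32000, (2048,))}

    def __len__(self):
        return 1000000000

loader = torch.utils.data.DataLoader(FakeSet(), batch_size=4)  # placeholder batch size
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 2048])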


@@ -8,6 +8,7 @@ from transformers import get_cosine_schedule_with_warmup
from dataset.validation import val_set


class Trainer:
    def __init__(self, config, raw_model, train_loader, tokenizer, accelerator):
        self.config = config
@@ -15,18 +16,29 @@ class Trainer:
        self.train_loader = train_loader
        self.tokenizer = tokenizer
        self.accelerator = accelerator
        self.lr_scheduler_factor = (
            accelerator.num_processes / accelerator.gradient_accumulation_steps
        )
        self.log_interval = (
            self.config["log_interval"] * accelerator.gradient_accumulation_steps
        )
        self.eval_interval = (
            self.config["eval_interval"] * accelerator.gradient_accumulation_steps
        )
        self.save_interval = (
            self.config["save_interval"] * accelerator.gradient_accumulation_steps
        )
        self.work_dir = self.config["work_dir"]
        self.get_model_info()
        if accelerator.is_main_process:
            wandb.init(project=self.config["project_name"])

    def get_model_info(self):
        with torch.no_grad():
            summary(
                self.raw_model.cuda(),
                input_data=torch.ones(1, 64, dtype=torch.int64).cuda(),
            )

    def get_optimizer(self):
        no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
@@ -37,7 +49,7 @@ class Trainer:
                    for n, p in self.raw_model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.config["train"]["weight_decay"],
            },
            {
                "params": [
@@ -48,13 +60,19 @@ class Trainer:
                "weight_decay": 0.0,
            },
        ]
        self.optim = FusedAdam(
            optimizer_grouped_parameters,
            lr=self.config["train"]["lr"],
            betas=(0.9, 0.95),
        )

    def get_lr_scheduler(self):
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optim,
            num_warmup_steps=self.config["train"]["num_warmup_steps"]
            * self.lr_scheduler_factor,
            num_training_steps=self.config["train"]["num_training_steps"]
            * self.lr_scheduler_factor,
        )

    def prepare(self):
@@ -82,28 +100,45 @@ class Trainer:
        self.global_step = 0
        self.start_time = time.time()
        self.optim.zero_grad()
        for self.data_step in range(self.config["train"]["num_training_steps"]):
            self.model.train()
            with self.accelerator.accumulate(self.model):
                batch = next(self.train_loader_iter)
                losses = self.train_step(batch)
                if self.accelerator.sync_gradients:
                    self.global_step += 1
            if (
                self.data_step % self.log_interval == 0
                and self.data_step > 0
                and self.accelerator.is_main_process
            ):
                self.log(losses)
            if (
                self.data_step % self.eval_interval == 0
                and self.accelerator.is_main_process
            ):
                self.eval()
            if (
                self.data_step % self.save_interval == 0
                and self.data_step > 0
                and self.accelerator.is_main_process
            ):
                if not os.path.isdir(self.work_dir):
                    os.mkdir(self.work_dir)
                torch.save(
                    self.raw_model.state_dict(),
                    "{}/{}.pt".format(self.work_dir, self.global_step),
                )
        wandb.finish()

    def log(self, losses):
        cost_time = time.time() - self.start_time
        self.start_time = time.time()
        tokens = (
            self.config["train"]["train_batch_size"]
            * self.log_interval
            * self.config["model"]["max_length"]
        )
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
@@ -115,7 +150,10 @@ class Trainer:
        wandb.log({"Training/Global Step": self.global_step})
        self.accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                self.global_step,
                self.data_step,
                losses["total_loss"],
                tokens / cost_time,
            )
        )
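
The interval bookkeeping reformatted above has one subtlety: log_interval, eval_interval and save_interval are multiplied by accelerator.gradient_accumulation_steps because the training loop counts data steps (micro-batches), so the configured values effectively count optimizer steps. A small sketch with made-up numbers:

# Made-up numbers, purely to illustrate the scaling done in Trainer.__init__.
gradient_accumulation_steps = 8
log_interval_cfg = 100  # configured per optimizer step
log_interval = log_interval_cfg * gradient_accumulation_steps  # 800 data steps

global_step = 0
for data_step in range(1, 1601):
    if data_step % gradient_accumulation_steps == 0:  # one optimizer step
        global_step += 1
    if data_step % log_interval == 0:
        print(f"log at data_step={data_step}, global_step={global_step}")
# fires at data_step 800 (global_step 100) and 1600 (global_step 200)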


@@ -6,16 +6,20 @@ sp_model = spm.SentencePieceProcessor(
    model_file="configs/llama_tokenizer_extended.model"
)
merged_vocab_size = sp_model.vocab_size()
ckpt = torch.load("data/llama_raw_ckpt/7B/consolidated.00.pth")
raw_vocab_size, hidden_size = ckpt["tok_embeddings.weight"].shape
extended_tok_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_tok_embeddings = extended_tok_embeddings * 0.001
ckpt["tok_embeddings.weight"] = torch.cat(
    [ckpt["tok_embeddings.weight"], extended_tok_embeddings], dim=0
)
extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_out_embeddings = extended_out_embeddings * 0.001
ckpt["output.weight"] = torch.cat(
    [ckpt["output.weight"], extended_out_embeddings], dim=0
)

rename_map = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
@@ -37,7 +41,7 @@ from_names = [
    "layers.{}.feed_forward.w3.weight",
    "layers.{}.attention_norm.weight",
    "layers.{}.ffn_norm.weight",
    "layers.{}.attention.inner_attention.rope.freqs",
]

to_names = [
@@ -59,4 +63,4 @@ for layer in range(32):
        t = t.format(layer)
        v = ckpt.pop(f)
        ckpt[t] = v
torch.save(ckpt, "data/llama_raw_ckpt/7B/extended.pth")
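
The extension above appends rows for the newly added tokens and scales them by 0.001 so the fresh embeddings start near zero relative to the pretrained rows. A shape-only sanity check of the same operation (all sizes below are placeholders, not the real values read from the 7B checkpoint and the merged tokenizer):

import torch

raw_vocab_size, merged_vocab_size, hidden_size = 32000, 40000, 4096  # placeholders

tok_embeddings = torch.randn(raw_vocab_size, hidden_size)
new_rows = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size) * 0.001
extended = torch.cat([tok_embeddings, new_rows], dim=0)

assert extended.shape == (merged_vocab_size, hidden_size)
print(new_rows.std().item())  # ~0.001, i.e. much smaller than the pretrained rows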


@@ -3,21 +3,21 @@ import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as model

raw_model = model.ModelProto()
raw_model.ParseFromString(open("configs/llama_tokenizer.model", "rb").read())
exist_pieces = set([p.piece for p in raw_model.pieces])
cn_model = model.ModelProto()
cn_model.ParseFromString(open("configs/4w_cn_vocab_wudao15.model", "rb").read())

for p in tqdm(cn_model.pieces, total=len(cn_model.pieces)):
    if p.piece not in exist_pieces:
        raw_model.pieces.append(p)

with open("configs/llama_tokenizer_extended.model", "wb") as f:
    f.write(raw_model.SerializeToString())

sp_model = spm.SentencePieceProcessor(
    model_file="configs/llama_tokenizer_extended.model"
)
print("merged vocab size: {}".format(sp_model.vocab_size()))
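
Once the merged model file is written, the effect of the added Chinese pieces can be checked by comparing tokenizations. A usage sketch with the paths from the script above (the sample sentence is arbitrary and the counts are only illustrative):

import sentencepiece as spm

base = spm.SentencePieceProcessor(model_file="configs/llama_tokenizer.model")
merged = spm.SentencePieceProcessor(model_file="configs/llama_tokenizer_extended.model")

text = "Open-Llama 支持中英文混合预训练。"
print("base tokens:", len(base.encode(text)))
print("merged tokens:", len(merged.encode(text)))  # expected to be noticeably fewer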


@@ -20,13 +20,15 @@ vocab_size = 32000
total_step = 2
use_activation_ckpt = True


class FakeSet(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        return 1000000000


accelerator = Accelerator()
raw_model = LlamaForCausalLM(
    LlamaConfig(
@@ -39,15 +41,18 @@ optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
if accelerator.distributed_type == DistributedType.FSDP:
    accelerator.print("FSDP")
    model = accelerator.prepare(raw_model)
    optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
else:
    model, optimizer, train_loader = accelerator.prepare(
        raw_model, optimizer, train_loader
    )


def train(model, optimizer, train_loader):
    start_time = time.time()
    for i, batch in enumerate(train_loader):
        if i == total_step:
            break
        optimizer.zero_grad()
@@ -58,4 +63,5 @@ def train(model, optimizer, train_loader):
    end_time = time.time()
    return end_time - start_time


accelerator.print("total time: {}".format(train(model, optimizer, train_loader)))


@@ -23,7 +23,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import (
    ColoParameter,
    ComputePattern,
    ComputeSpec,
    ProcessGroup,
    ReplicaSpec,
    ShardSpec,
)
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
@@ -35,7 +42,7 @@ def parse_args():
    parser.add_argument(
        "--distplan",
        type=str,
        default="CAI_Gemini",
        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
    )
    parser.add_argument(
@@ -47,14 +54,13 @@ def parse_args():
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--shardinit",
        action="store_true",
        help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
@@ -105,7 +111,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
class GPTLMLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
@@ -114,7 +119,9 @@ class GPTLMLoss(nn.Module):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )


def get_cpu_mem():
@@ -125,8 +132,8 @@ def get_gpu_mem():
    return torch.cuda.memory_allocated() / 1024**2


def get_mem_info(prefix=""):
    return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"


def get_model_size(model: nn.Module):
@@ -142,11 +149,11 @@ def model_size_formatter(numel: int) -> str:
    MB_SIZE = 10**6
    KB_SIZE = 10**3
    if numel >= GB_SIZE:
        return f"{numel / GB_SIZE:.1f}B"
    elif numel >= MB_SIZE:
        return f"{numel / MB_SIZE:.1f}M"
    elif numel >= KB_SIZE:
        return f"{numel / KB_SIZE:.1f}K"
    else:
        return str(numel)
@@ -154,7 +161,7 @@ def model_size_formatter(numel: int) -> str:
def set_cpu_maximum_parallelism():
    conf_str = torch.__config__.parallel_info()
    inter_str = conf_str.split("hardware_concurrency() : ")[1]
    max_concurrency = inter_str.split("\n")[0]
    os.environ["OMP_NUM_THREADS"] = max_concurrency
    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
@@ -170,7 +177,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # NOTE() a param maybe shared by two modules
            if hasattr(param, "visited"):
                continue

            # if shard init, then convert param to replica and use the dp-only ProcessGroup
@@ -179,22 +186,22 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
            param.set_process_group(pg)

            # shard it w.r.t tp pattern
            if "mlp.c_fc" in mn:
                if "weight" in pn or "bias" in pn:
                    split_param_col_tp1d(param, pg)  # colmn slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif "mlp.c_proj" in mn:
                if "weight" in pn:
                    split_param_row_tp1d(param, pg)  # row slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            elif "wte" in mn or "wpe" in mn:
                split_param_col_tp1d(param, pg)  # colmn slice
            elif "c_attn" in mn or "c_proj" in mn:
                split_param_col_tp1d(param, pg)  # colmn slice
            else:
                param.set_dist_spec(ReplicaSpec())
            param.visited = True
@@ -209,7 +216,13 @@ def main():
    args = parse_args()
    # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
    if args.distplan not in [
        "CAI_ZeRO1",
        "CAI_ZeRO2",
        "CAI_Gemini",
        "Pytorch_DDP",
        "Pytorch_ZeRO",
    ]:
        raise TypeError(f"{args.distplan} is error")

    # batch size per DP degree
@@ -221,14 +234,18 @@ def main():
    WARMUP_STEPS = 1
    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
    assert (
        NUM_STEPS - WARMUP_STEPS
    ) % 2 == 1, "the number of valid steps should be odd to take the median"
    PROF_FLAG = False  # The flag of profiling, False by default

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})
    logger = get_dist_logger()
    logger.info(
        f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]
    )

    # build criterion
    criterion = GPTLMLoss()
@@ -244,10 +261,12 @@ def main():
            raise RuntimeError("You can only use shardinit with CAI_Gemini")

        # build GPT model
        with ColoInitContext(
            device=get_current_device(),
            dtype=torch.half,
            default_dist_spec=default_dist_spec,
            default_pg=shard_pg,
        ):
            model = model_builder(VOCAB_SIZE, checkpoint=True)

        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
@@ -259,15 +278,21 @@ def main():
        # asign running configurations
        gemini_config = None
        if args.distplan.startswith("CAI_ZeRO"):
            optim_config = dict(
                reduce_bucket_size=12 * 1024 * 1024,
                overlap_communication=True,
                verbose=True,
            )
        elif args.distplan == "CAI_Gemini":
            gemini_config = dict(
                strict_ddp_mode=args.tp_degree == 1,
                device=get_current_device(),
                placement_policy=args.placement,
                pin_memory=True,
                hidden_dim=model.model.config.hidden_size,
                search_range_mb=128,
            )
            optim_config = dict(gpu_margin_mem_ratio=0.0)
        else:
            raise RuntimeError
@@ -287,7 +312,7 @@ def main():
        model = zero_model_wrapper(model, zero_stage, gemini_config)
        optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
        logger.info(get_mem_info(prefix="After init optim, "), ranks=[0])
    elif args.distplan.startswith("Pytorch"):
        assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
        model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
@@ -296,14 +321,17 @@ def main():
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        elif args.distplan.endswith("ZeRO"):
            from torch.distributed.optim import ZeroRedundancyOptimizer

            optimizer = ZeroRedundancyOptimizer(
                model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
            )
    else:
        raise RuntimeError

    # model is shared after TP
    numel = get_model_size(model)
    logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
    logger.info(get_mem_info(prefix="After init model, "), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
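
The comment above spells out the throughput formula behind the benchmark's get_tflops_func. Assuming that helper follows the comment directly, the arithmetic looks like the sketch below (the function and numbers here are illustrative, not the repo's own code):

def tflops_per_gpu(numel, batch_size, seq_len, step_time):
    # 8 * numel * batch_size * seq_len floating point operations per step,
    # divided by the step time and expressed in TFLOPS
    return numel * batch_size * seq_len * 8 / (step_time * 1e12)

# Placeholder values: a 7B-parameter model, per-GPU batch 8, sequence length 2048, 4 s/step.
print(f"{tflops_per_gpu(7e9, 8, 2048, 4.0):.1f} TFLOPS per GPU")
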
@@ -325,7 +353,7 @@ def main():
        torch.cuda.synchronize()
        fwd_end = time()
        fwd_time = fwd_end - start
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0])

        if args.distplan.startswith("CAI"):
            optimizer.backward(loss)
@@ -337,13 +365,15 @@ def main():
        torch.cuda.synchronize()
        bwd_end = time()
        bwd_time = bwd_end - fwd_end
        logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0])

        optimizer.step()
        torch.cuda.synchronize()
        optim_time = time() - bwd_end
        step_time = time() - start
        logger.info(
            get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]
        )

        step_tflops = get_tflops_func(step_time)
        logger.info(
@@ -353,10 +383,12 @@ def main():
        if n >= WARMUP_STEPS:
            tflops_list.append(step_tflops)

    demo_profiler = get_profile_context(
        PROF_FLAG,
        WARMUP_STEPS,
        NUM_STEPS - WARMUP_STEPS,
        save_dir=f"profile/{get_time_stamp()}-demo",
    )

    with demo_profiler as prof:
        start_time = time()
@@ -364,7 +396,7 @@ def main():
            train_step()
            prof.step()
        end_time = time()
        print("total time: {}".format(end_time - start_time))

    tflops_list.sort()
    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
@@ -372,5 +404,5 @@ def main():
    torch.cuda.synchronize()


if __name__ == "__main__":
    main()


@@ -2,11 +2,15 @@ import time
from contextlib import nullcontext

import torch
from torch.profiler import (
    ProfilerActivity,
    profile,
    schedule,
    tensorboard_trace_handler,
)


class DummyProfiler:
    def __init__(self):
        self.step_number = 0
@@ -16,7 +20,9 @@ class DummyProfiler:
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(
        0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()
    )
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask
@@ -27,11 +33,13 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
    if enable_flag:
        return profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
            on_trace_ready=tensorboard_trace_handler(save_dir),
            record_shapes=True,
            profile_memory=True,
        )
    else:
        return nullcontext(DummyProfiler())
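
A usage sketch for the helper above (the step counts and the training-step stand-in are placeholders); with enable_flag set to False, the nullcontext(DummyProfiler()) branch is intended to let the same call sites run with profiling disabled:

def fake_training_step():
    # placeholder for the real forward/backward/optimizer step
    pass

with get_profile_context(True, 1, 3, save_dir="profile/demo") as prof:
    for _ in range(1 + 3):  # one warmup step plus three active steps
        fake_training_step()
        prof.step()  # advances the torch.profiler schedule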


@@ -22,13 +22,15 @@ vocab_size = 32000
total_step = 100
use_activation_ckpt = False


class FakeSet(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        return torch.randint(0, vocab_size, (seq_length,))

    def __len__(self):
        return 1000000000


class SpeedTest(pl.LightningModule):
    def __init__(self):
        super().__init__()
@@ -45,7 +47,7 @@ class SpeedTest(pl.LightningModule):
        out = self.model(batch, labels=batch)
        loss = out.loss
        if self.start_time is None:
            print("start")
            self.start_time = time.time()
        return loss
@@ -53,14 +55,15 @@ class SpeedTest(pl.LightningModule):
        optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
        return optimizer


model = SpeedTest()
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
strategy = DeepSpeedStrategy(
    stage=2,
    offload_optimizer=False,
    offload_parameters=False,
    process_group_backend="nccl",
)
trainer = pl.Trainer(
    limit_train_batches=total_step,
@@ -69,7 +72,9 @@ trainer = pl.Trainer(
    accelerator="gpu",
    strategy=strategy,
    precision=16,
    enable_checkpointing=False,
)


def train(model, train_loader):
    start_time = time.time()
@@ -77,4 +82,5 @@ def train(model, train_loader):
    end_time = time.time()
    return end_time - model.start_time


print("total time: {}".format(train(model, train_loader)))