update format
This commit is contained in:
parent
a4aa109dd3
commit
3f62a23ee2
91
pretrain.py
91
pretrain.py
|
@ -1,3 +1,13 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-04-12 19:12:42
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-12 21:01:32
|
||||
FilePath: /Open-Llama/pretrain.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import yaml
|
||||
import torch
|
||||
import random
|
||||
|
@ -20,55 +30,72 @@ from solver.trainer import Trainer
|
|||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("config", None, "Training config path")
|
||||
|
||||
|
||||
class FakeSet(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx):
|
||||
return {"input_ids": torch.randint(0, 32000, (2048,))}
|
||||
|
||||
def __len__(self):
|
||||
return 1000000000
|
||||
|
||||
|
||||
def main(argv):
|
||||
accelerator = Accelerator()
|
||||
|
||||
with open(FLAGS.config, 'r', encoding="utf-8") as fp:
|
||||
with open(FLAGS.config, "r", encoding="utf-8") as fp:
|
||||
config = yaml.load(fp, Loader=yaml.FullLoader)
|
||||
sp_model = spm.SentencePieceProcessor(model_file=config['data']['tokenizer_model_path'])
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file=config["data"]["tokenizer_model_path"]
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
|
||||
paths = create_shard_kwargs(config['data']['patterns'])
|
||||
random.shuffle(paths)
|
||||
transform_dict = {
|
||||
"wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
|
||||
"pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
|
||||
}
|
||||
data_set = DataIter(
|
||||
paths,
|
||||
transform_dict=transform_dict,
|
||||
concat_docs=True,
|
||||
max_length=config['model']['max_length'],
|
||||
process_index=accelerator.process_index,
|
||||
num_processes=accelerator.num_processes,
|
||||
)
|
||||
train_loader = DataLoader(
|
||||
data_set,
|
||||
batch_size=config['train']['train_batch_size'],
|
||||
# If num_workers is greater than 1, duplicate data may occur.
|
||||
num_workers=0,
|
||||
collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
|
||||
drop_last=True,
|
||||
# paths = create_shard_kwargs(config['data']['patterns'])
|
||||
# random.shuffle(paths)
|
||||
# transform_dict = {
|
||||
# "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
|
||||
# "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
|
||||
# }
|
||||
# data_set = DataIter(
|
||||
# paths,
|
||||
# transform_dict=transform_dict,
|
||||
# concat_docs=True,
|
||||
# max_length=config['model']['max_length'],
|
||||
# process_index=accelerator.process_index,
|
||||
# num_processes=accelerator.num_processes,
|
||||
# )
|
||||
# train_loader = DataLoader(
|
||||
# data_set,
|
||||
# batch_size=config['train']['train_batch_size'],
|
||||
# # If num_workers is greater than 1, duplicate data may occur.
|
||||
# num_workers=0,
|
||||
# collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
|
||||
# drop_last=True,
|
||||
# )
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
FakeSet(), batch_size=config["train"]["train_batch_size"]
|
||||
)
|
||||
# smaller initializer_range make training more stable
|
||||
# add stabel embedding to token embedding
|
||||
raw_model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=tokenizer.vocab_size,
|
||||
initializer_range=config['model']['initializer_range'],
|
||||
initializer_range=config["model"]["initializer_range"],
|
||||
pad_token_id=tokenizer.pad_id,
|
||||
rms_norm_eps=1e-5,
|
||||
hidden_dropout_prob=config['model']['hidden_dropout_prob'],
|
||||
attention_dropout_prob=config['model']['attention_dropout_prob'],
|
||||
use_stable_embedding=config['model']['use_stable_embedding'],
|
||||
shared_input_output_embedding=config['model']['shared_input_output_embedding'],
|
||||
hidden_dropout_prob=config["model"]["hidden_dropout_prob"],
|
||||
attention_dropout_prob=config["model"]["attention_dropout_prob"],
|
||||
use_stable_embedding=config["model"]["use_stable_embedding"],
|
||||
shared_input_output_embedding=config["model"][
|
||||
"shared_input_output_embedding"
|
||||
],
|
||||
)
|
||||
)
|
||||
if config['train']['ckpt'] is not None:
|
||||
ckpt = torch.load(config['train']['ckpt'])
|
||||
if config["train"]["ckpt"] is not None:
|
||||
ckpt = torch.load(config["train"]["ckpt"])
|
||||
raw_model.load_state_dict(ckpt)
|
||||
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
||||
trainer.train()
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
||||
|
|
|
@ -8,6 +8,7 @@ from transformers import get_cosine_schedule_with_warmup
|
|||
|
||||
from dataset.validation import val_set
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, config, raw_model, train_loader, tokenizer, accelerator):
|
||||
self.config = config
|
||||
|
@ -15,18 +16,29 @@ class Trainer:
|
|||
self.train_loader = train_loader
|
||||
self.tokenizer = tokenizer
|
||||
self.accelerator = accelerator
|
||||
self.lr_scheduler_factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
|
||||
self.log_interval = self.config['log_interval'] * accelerator.gradient_accumulation_steps
|
||||
self.eval_interval = self.config['eval_interval'] * accelerator.gradient_accumulation_steps
|
||||
self.save_interval = self.config['save_interval'] * accelerator.gradient_accumulation_steps
|
||||
self.work_dir = self.config['work_dir']
|
||||
self.lr_scheduler_factor = (
|
||||
accelerator.num_processes / accelerator.gradient_accumulation_steps
|
||||
)
|
||||
self.log_interval = (
|
||||
self.config["log_interval"] * accelerator.gradient_accumulation_steps
|
||||
)
|
||||
self.eval_interval = (
|
||||
self.config["eval_interval"] * accelerator.gradient_accumulation_steps
|
||||
)
|
||||
self.save_interval = (
|
||||
self.config["save_interval"] * accelerator.gradient_accumulation_steps
|
||||
)
|
||||
self.work_dir = self.config["work_dir"]
|
||||
self.get_model_info()
|
||||
if accelerator.is_main_process:
|
||||
wandb.init(project=self.config['project_name'])
|
||||
wandb.init(project=self.config["project_name"])
|
||||
|
||||
def get_model_info(self):
|
||||
with torch.no_grad():
|
||||
summary(self.raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
|
||||
summary(
|
||||
self.raw_model.cuda(),
|
||||
input_data=torch.ones(1, 64, dtype=torch.int64).cuda(),
|
||||
)
|
||||
|
||||
def get_optimizer(self):
|
||||
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
|
||||
|
@ -37,7 +49,7 @@ class Trainer:
|
|||
for n, p in self.raw_model.named_parameters()
|
||||
if not any(nd in n for nd in no_decay)
|
||||
],
|
||||
"weight_decay": self.config['train']['weight_decay'],
|
||||
"weight_decay": self.config["train"]["weight_decay"],
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
|
@ -48,15 +60,21 @@ class Trainer:
|
|||
"weight_decay": 0.0,
|
||||
},
|
||||
]
|
||||
self.optim = FusedAdam(optimizer_grouped_parameters, lr=self.config['train']['lr'], betas=(0.9, 0.95))
|
||||
self.optim = FusedAdam(
|
||||
optimizer_grouped_parameters,
|
||||
lr=self.config["train"]["lr"],
|
||||
betas=(0.9, 0.95),
|
||||
)
|
||||
|
||||
def get_lr_scheduler(self):
|
||||
self.scheduler = get_cosine_schedule_with_warmup(
|
||||
self.optim,
|
||||
num_warmup_steps=self.config['train']['num_warmup_steps'] * self.lr_scheduler_factor,
|
||||
num_training_steps=self.config['train']['num_training_steps'] * self.lr_scheduler_factor,
|
||||
num_warmup_steps=self.config["train"]["num_warmup_steps"]
|
||||
* self.lr_scheduler_factor,
|
||||
num_training_steps=self.config["train"]["num_training_steps"]
|
||||
* self.lr_scheduler_factor,
|
||||
)
|
||||
|
||||
|
||||
def prepare(self):
|
||||
_, self.model, self.optim, self.scheduler = self.accelerator.prepare(
|
||||
self.train_loader, self.raw_model, self.optim, self.scheduler
|
||||
|
@ -74,7 +92,7 @@ class Trainer:
|
|||
self.scheduler.step()
|
||||
self.optim.zero_grad()
|
||||
return losses
|
||||
|
||||
|
||||
def train(self):
|
||||
self.get_optimizer()
|
||||
self.get_lr_scheduler()
|
||||
|
@ -82,28 +100,45 @@ class Trainer:
|
|||
self.global_step = 0
|
||||
self.start_time = time.time()
|
||||
self.optim.zero_grad()
|
||||
for self.data_step in range(self.config['train']['num_training_steps']):
|
||||
for self.data_step in range(self.config["train"]["num_training_steps"]):
|
||||
self.model.train()
|
||||
with self.accelerator.accumulate(self.model):
|
||||
batch = next(self.train_loader_iter)
|
||||
losses = self.train_step(batch)
|
||||
if self.accelerator.sync_gradients:
|
||||
self.global_step += 1
|
||||
if self.data_step % self.log_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
|
||||
if (
|
||||
self.data_step % self.log_interval == 0
|
||||
and self.data_step > 0
|
||||
and self.accelerator.is_main_process
|
||||
):
|
||||
self.log(losses)
|
||||
if self.data_step % self.eval_interval == 0 and self.accelerator.is_main_process:
|
||||
if (
|
||||
self.data_step % self.eval_interval == 0
|
||||
and self.accelerator.is_main_process
|
||||
):
|
||||
self.eval()
|
||||
if self.data_step % self.save_interval == 0 and self.data_step > 0 and self.accelerator.is_main_process:
|
||||
if (
|
||||
self.data_step % self.save_interval == 0
|
||||
and self.data_step > 0
|
||||
and self.accelerator.is_main_process
|
||||
):
|
||||
if not os.path.isdir(self.work_dir):
|
||||
os.mkdir(self.work_dir)
|
||||
torch.save(self.raw_model.state_dict(), "{}/{}.pt".format(self.work_dir, self.global_step))
|
||||
torch.save(
|
||||
self.raw_model.state_dict(),
|
||||
"{}/{}.pt".format(self.work_dir, self.global_step),
|
||||
)
|
||||
wandb.finish()
|
||||
|
||||
def log(self, losses):
|
||||
cost_time = time.time() - self.start_time
|
||||
self.start_time = time.time()
|
||||
tokens = self.config['train']['train_batch_size'] * \
|
||||
self.log_interval * self.config['model']['max_length']
|
||||
tokens = (
|
||||
self.config["train"]["train_batch_size"]
|
||||
* self.log_interval
|
||||
* self.config["model"]["max_length"]
|
||||
)
|
||||
wandb.log({"Training/Token per second per gpu": tokens / cost_time})
|
||||
for k, v in losses.items():
|
||||
wandb.log({"Losses/{}".format(k): v})
|
||||
|
@ -115,7 +150,10 @@ class Trainer:
|
|||
wandb.log({"Training/Global Step": self.global_step})
|
||||
self.accelerator.print(
|
||||
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
|
||||
self.global_step, self.data_step, losses["total_loss"], tokens / cost_time
|
||||
self.global_step,
|
||||
self.data_step,
|
||||
losses["total_loss"],
|
||||
tokens / cost_time,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -137,4 +175,4 @@ class Trainer:
|
|||
pred = self.tokenizer.decode(pred.cpu())[0]
|
||||
pred = pred[inputs_len:]
|
||||
text_table.add_data(raw_inputs, pred)
|
||||
wandb.log({"Predictions on {}".format(self.global_step): text_table})
|
||||
wandb.log({"Predictions on {}".format(self.global_step): text_table})
|
||||
|
|
|
@ -6,16 +6,20 @@ sp_model = spm.SentencePieceProcessor(
|
|||
model_file="configs/llama_tokenizer_extended.model"
|
||||
)
|
||||
merged_vocab_size = sp_model.vocab_size()
|
||||
ckpt = torch.load('data/llama_raw_ckpt/7B/consolidated.00.pth')
|
||||
ckpt = torch.load("data/llama_raw_ckpt/7B/consolidated.00.pth")
|
||||
|
||||
raw_vocab_size, hidden_size = ckpt['tok_embeddings.weight'].shape
|
||||
raw_vocab_size, hidden_size = ckpt["tok_embeddings.weight"].shape
|
||||
extended_tok_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
|
||||
extended_tok_embeddings = extended_tok_embeddings * 0.001
|
||||
ckpt['tok_embeddings.weight'] = torch.cat([ckpt['tok_embeddings.weight'], extended_tok_embeddings], dim=0)
|
||||
ckpt["tok_embeddings.weight"] = torch.cat(
|
||||
[ckpt["tok_embeddings.weight"], extended_tok_embeddings], dim=0
|
||||
)
|
||||
|
||||
extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
|
||||
extended_out_embeddings = extended_out_embeddings * 0.001
|
||||
ckpt['output.weight'] = torch.cat([ckpt['output.weight'], extended_out_embeddings], dim=0)
|
||||
ckpt["output.weight"] = torch.cat(
|
||||
[ckpt["output.weight"], extended_out_embeddings], dim=0
|
||||
)
|
||||
|
||||
rename_map = {
|
||||
"tok_embeddings.weight": "model.embed_tokens.weight",
|
||||
|
@ -26,31 +30,31 @@ rename_map = {
|
|||
for f, t in rename_map.items():
|
||||
v = ckpt.pop(f)
|
||||
ckpt[t] = v
|
||||
|
||||
|
||||
from_names = [
|
||||
"layers.{}.attention.wq.weight",
|
||||
"layers.{}.attention.wk.weight",
|
||||
"layers.{}.attention.wv.weight",
|
||||
"layers.{}.attention.wo.weight",
|
||||
"layers.{}.feed_forward.w1.weight",
|
||||
"layers.{}.feed_forward.w2.weight",
|
||||
"layers.{}.feed_forward.w3.weight",
|
||||
"layers.{}.attention_norm.weight",
|
||||
"layers.{}.ffn_norm.weight",
|
||||
"layers.{}.attention.inner_attention.rope.freqs"
|
||||
"layers.{}.attention.wq.weight",
|
||||
"layers.{}.attention.wk.weight",
|
||||
"layers.{}.attention.wv.weight",
|
||||
"layers.{}.attention.wo.weight",
|
||||
"layers.{}.feed_forward.w1.weight",
|
||||
"layers.{}.feed_forward.w2.weight",
|
||||
"layers.{}.feed_forward.w3.weight",
|
||||
"layers.{}.attention_norm.weight",
|
||||
"layers.{}.ffn_norm.weight",
|
||||
"layers.{}.attention.inner_attention.rope.freqs",
|
||||
]
|
||||
|
||||
to_names = [
|
||||
"model.layers.{}.self_attn.q_proj.weight",
|
||||
"model.layers.{}.self_attn.k_proj.weight",
|
||||
"model.layers.{}.self_attn.v_proj.weight",
|
||||
"model.layers.{}.self_attn.o_proj.weight",
|
||||
"model.layers.{}.mlp.gate_proj.weight",
|
||||
"model.layers.{}.mlp.down_proj.weight",
|
||||
"model.layers.{}.mlp.up_proj.weight",
|
||||
"model.layers.{}.input_layernorm.weight",
|
||||
"model.layers.{}.self_attn.q_proj.weight",
|
||||
"model.layers.{}.self_attn.k_proj.weight",
|
||||
"model.layers.{}.self_attn.v_proj.weight",
|
||||
"model.layers.{}.self_attn.o_proj.weight",
|
||||
"model.layers.{}.mlp.gate_proj.weight",
|
||||
"model.layers.{}.mlp.down_proj.weight",
|
||||
"model.layers.{}.mlp.up_proj.weight",
|
||||
"model.layers.{}.input_layernorm.weight",
|
||||
"model.layers.{}.post_attention_layernorm.weight",
|
||||
"model.layers.{}.self_attn.rotary_emb.inv_freq",
|
||||
"model.layers.{}.self_attn.rotary_emb.inv_freq",
|
||||
]
|
||||
|
||||
for layer in range(32):
|
||||
|
@ -59,4 +63,4 @@ for layer in range(32):
|
|||
t = t.format(layer)
|
||||
v = ckpt.pop(f)
|
||||
ckpt[t] = v
|
||||
torch.save(ckpt, 'data/llama_raw_ckpt/7B/extended.pth')
|
||||
torch.save(ckpt, "data/llama_raw_ckpt/7B/extended.pth")
|
||||
|
|
|
@ -3,21 +3,21 @@ import sentencepiece as spm
|
|||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
raw_model = model.ModelProto()
|
||||
raw_model.ParseFromString(open('configs/llama_tokenizer.model', 'rb').read())
|
||||
raw_model.ParseFromString(open("configs/llama_tokenizer.model", "rb").read())
|
||||
|
||||
exist_pieces = set([p.piece for p in raw_model.pieces])
|
||||
cn_model = model.ModelProto()
|
||||
cn_model.ParseFromString(open('configs/4w_cn_vocab_wudao15.model', 'rb').read())
|
||||
cn_model.ParseFromString(open("configs/4w_cn_vocab_wudao15.model", "rb").read())
|
||||
|
||||
for p in tqdm(cn_model.pieces, total=len(cn_model.pieces)):
|
||||
if p.piece not in exist_pieces:
|
||||
raw_model.pieces.append(p)
|
||||
|
||||
with open('configs/llama_tokenizer_extended.model', 'wb') as f:
|
||||
with open("configs/llama_tokenizer_extended.model", "wb") as f:
|
||||
f.write(raw_model.SerializeToString())
|
||||
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/llama_tokenizer_extended.model"
|
||||
)
|
||||
|
||||
print('merged vocab size: {}'.format(sp_model.vocab_size()))
|
||||
print("merged vocab size: {}".format(sp_model.vocab_size()))
|
||||
|
|
|
@ -20,13 +20,15 @@ vocab_size = 32000
|
|||
total_step = 2
|
||||
use_activation_ckpt = True
|
||||
|
||||
|
||||
class FakeSet(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx):
|
||||
return torch.randint(0, vocab_size, (seq_length, ))
|
||||
|
||||
return torch.randint(0, vocab_size, (seq_length,))
|
||||
|
||||
def __len__(self):
|
||||
return 1000000000
|
||||
|
||||
|
||||
accelerator = Accelerator()
|
||||
raw_model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
|
@ -39,15 +41,18 @@ optimizer = FusedAdam(raw_model.parameters(), lr=1e-5)
|
|||
|
||||
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
|
||||
if accelerator.distributed_type == DistributedType.FSDP:
|
||||
accelerator.print('FSDP')
|
||||
accelerator.print("FSDP")
|
||||
model = accelerator.prepare(raw_model)
|
||||
optimizer, train_loader = accelerator.prepare(optimizer, train_loader)
|
||||
else:
|
||||
model, optimizer, train_loader = accelerator.prepare(raw_model, optimizer, train_loader)
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
raw_model, optimizer, train_loader
|
||||
)
|
||||
|
||||
|
||||
def train(model, optimizer, train_loader):
|
||||
start_time = time.time()
|
||||
for i, batch in enumerate(train_loader):
|
||||
for i, batch in enumerate(train_loader):
|
||||
if i == total_step:
|
||||
break
|
||||
optimizer.zero_grad()
|
||||
|
@ -58,4 +63,5 @@ def train(model, optimizer, train_loader):
|
|||
end_time = time.time()
|
||||
return end_time - start_time
|
||||
|
||||
accelerator.print('total time: {}'.format(train(model, optimizer, train_loader)))
|
||||
|
||||
accelerator.print("total time: {}".format(train(model, optimizer, train_loader)))
|
||||
|
|
|
@ -23,7 +23,14 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
||||
from colossalai.tensor import (
|
||||
ColoParameter,
|
||||
ComputePattern,
|
||||
ComputeSpec,
|
||||
ProcessGroup,
|
||||
ReplicaSpec,
|
||||
ShardSpec,
|
||||
)
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
|
||||
|
||||
|
@ -35,7 +42,7 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--distplan",
|
||||
type=str,
|
||||
default='CAI_Gemini',
|
||||
default="CAI_Gemini",
|
||||
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -47,14 +54,13 @@ def parse_args():
|
|||
parser.add_argument(
|
||||
"--placement",
|
||||
type=str,
|
||||
default='cpu',
|
||||
default="cpu",
|
||||
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shardinit",
|
||||
action='store_true',
|
||||
help=
|
||||
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||
action="store_true",
|
||||
help="Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
|
@ -105,7 +111,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
|||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
|
@ -114,7 +119,9 @@ class GPTLMLoss(nn.Module):
|
|||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
return self.loss_fn(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||
)
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
|
@ -125,8 +132,8 @@ def get_gpu_mem():
|
|||
return torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
|
||||
def get_mem_info(prefix=''):
|
||||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
def get_mem_info(prefix=""):
|
||||
return f"{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB"
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module):
|
||||
|
@ -142,11 +149,11 @@ def model_size_formatter(numel: int) -> str:
|
|||
MB_SIZE = 10**6
|
||||
KB_SIZE = 10**3
|
||||
if numel >= GB_SIZE:
|
||||
return f'{numel / GB_SIZE:.1f}B'
|
||||
return f"{numel / GB_SIZE:.1f}B"
|
||||
elif numel >= MB_SIZE:
|
||||
return f'{numel / MB_SIZE:.1f}M'
|
||||
return f"{numel / MB_SIZE:.1f}M"
|
||||
elif numel >= KB_SIZE:
|
||||
return f'{numel / KB_SIZE:.1f}K'
|
||||
return f"{numel / KB_SIZE:.1f}K"
|
||||
else:
|
||||
return str(numel)
|
||||
|
||||
|
@ -154,7 +161,7 @@ def model_size_formatter(numel: int) -> str:
|
|||
def set_cpu_maximum_parallelism():
|
||||
conf_str = torch.__config__.parallel_info()
|
||||
inter_str = conf_str.split("hardware_concurrency() : ")[1]
|
||||
max_concurrency = inter_str.split('\n')[0]
|
||||
max_concurrency = inter_str.split("\n")[0]
|
||||
os.environ["OMP_NUM_THREADS"] = max_concurrency
|
||||
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||||
|
||||
|
@ -170,7 +177,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
|||
for mn, module in model.named_modules():
|
||||
for pn, param in module.named_parameters(recurse=False):
|
||||
# NOTE() a param maybe shared by two modules
|
||||
if hasattr(param, 'visited'):
|
||||
if hasattr(param, "visited"):
|
||||
continue
|
||||
|
||||
# if shard init, then convert param to replica and use the dp-only ProcessGroup
|
||||
|
@ -179,22 +186,22 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
|||
param.set_process_group(pg)
|
||||
|
||||
# shard it w.r.t tp pattern
|
||||
if 'mlp.c_fc' in mn:
|
||||
if 'weight' in pn or 'bias' in pn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
if "mlp.c_fc" in mn:
|
||||
if "weight" in pn or "bias" in pn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
# keep the shape of the output from c_fc
|
||||
param.compute_spec.set_output_replicate(False)
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'mlp.c_proj' in mn:
|
||||
if 'weight' in pn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
elif "mlp.c_proj" in mn:
|
||||
if "weight" in pn:
|
||||
split_param_row_tp1d(param, pg) # row slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
elif 'wte' in mn or 'wpe' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
elif "wte" in mn or "wpe" in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
elif "c_attn" in mn or "c_proj" in mn:
|
||||
split_param_col_tp1d(param, pg) # colmn slice
|
||||
else:
|
||||
param.set_dist_spec(ReplicaSpec())
|
||||
param.visited = True
|
||||
|
@ -209,7 +216,13 @@ def main():
|
|||
args = parse_args()
|
||||
|
||||
# if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
|
||||
if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
|
||||
if args.distplan not in [
|
||||
"CAI_ZeRO1",
|
||||
"CAI_ZeRO2",
|
||||
"CAI_Gemini",
|
||||
"Pytorch_DDP",
|
||||
"Pytorch_ZeRO",
|
||||
]:
|
||||
raise TypeError(f"{args.distplan} is error")
|
||||
|
||||
# batch size per DP degree
|
||||
|
@ -221,14 +234,18 @@ def main():
|
|||
|
||||
WARMUP_STEPS = 1
|
||||
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
||||
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
|
||||
PROF_FLAG = False # The flag of profiling, False by default
|
||||
assert (
|
||||
NUM_STEPS - WARMUP_STEPS
|
||||
) % 2 == 1, "the number of valid steps should be odd to take the median"
|
||||
PROF_FLAG = False # The flag of profiling, False by default
|
||||
|
||||
disable_existing_loggers()
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
logger = get_dist_logger()
|
||||
logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
|
||||
logger.info(
|
||||
f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0]
|
||||
)
|
||||
|
||||
# build criterion
|
||||
criterion = GPTLMLoss()
|
||||
|
@ -244,10 +261,12 @@ def main():
|
|||
raise RuntimeError("You can only use shardinit with CAI_Gemini")
|
||||
|
||||
# build GPT model
|
||||
with ColoInitContext(device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg):
|
||||
with ColoInitContext(
|
||||
device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_dist_spec=default_dist_spec,
|
||||
default_pg=shard_pg,
|
||||
):
|
||||
model = model_builder(VOCAB_SIZE, checkpoint=True)
|
||||
|
||||
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||
|
@ -259,15 +278,21 @@ def main():
|
|||
# asign running configurations
|
||||
gemini_config = None
|
||||
if args.distplan.startswith("CAI_ZeRO"):
|
||||
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
|
||||
optim_config = dict(
|
||||
reduce_bucket_size=12 * 1024 * 1024,
|
||||
overlap_communication=True,
|
||||
verbose=True,
|
||||
)
|
||||
elif args.distplan == "CAI_Gemini":
|
||||
gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
|
||||
device=get_current_device(),
|
||||
placement_policy=args.placement,
|
||||
pin_memory=True,
|
||||
hidden_dim=model.model.config.hidden_size,
|
||||
search_range_mb=128)
|
||||
optim_config = dict(gpu_margin_mem_ratio=0.)
|
||||
gemini_config = dict(
|
||||
strict_ddp_mode=args.tp_degree == 1,
|
||||
device=get_current_device(),
|
||||
placement_policy=args.placement,
|
||||
pin_memory=True,
|
||||
hidden_dim=model.model.config.hidden_size,
|
||||
search_range_mb=128,
|
||||
)
|
||||
optim_config = dict(gpu_margin_mem_ratio=0.0)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
|
@ -287,7 +312,7 @@ def main():
|
|||
model = zero_model_wrapper(model, zero_stage, gemini_config)
|
||||
optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
|
||||
|
||||
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||
logger.info(get_mem_info(prefix="After init optim, "), ranks=[0])
|
||||
elif args.distplan.startswith("Pytorch"):
|
||||
assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
|
||||
model = model_builder(VOCAB_SIZE, checkpoint=True).cuda()
|
||||
|
@ -296,14 +321,17 @@ def main():
|
|||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
elif args.distplan.endswith("ZeRO"):
|
||||
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
|
||||
|
||||
optimizer = ZeroRedundancyOptimizer(
|
||||
model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
|
||||
)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
# model is shared after TP
|
||||
numel = get_model_size(model)
|
||||
logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
|
||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||
logger.info(get_mem_info(prefix="After init model, "), ranks=[0])
|
||||
|
||||
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
|
||||
# = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
|
||||
|
@ -325,7 +353,7 @@ def main():
|
|||
torch.cuda.synchronize()
|
||||
fwd_end = time()
|
||||
fwd_time = fwd_end - start
|
||||
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||
logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Forward "), ranks=[0])
|
||||
|
||||
if args.distplan.startswith("CAI"):
|
||||
optimizer.backward(loss)
|
||||
|
@ -337,13 +365,15 @@ def main():
|
|||
torch.cuda.synchronize()
|
||||
bwd_end = time()
|
||||
bwd_time = bwd_end - fwd_end
|
||||
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
|
||||
logger.info(get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Backward "), ranks=[0])
|
||||
|
||||
optimizer.step()
|
||||
torch.cuda.synchronize()
|
||||
optim_time = time() - bwd_end
|
||||
step_time = time() - start
|
||||
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
||||
logger.info(
|
||||
get_mem_info(prefix=f"[{n + 1}/{NUM_STEPS}] Optimizer step "), ranks=[0]
|
||||
)
|
||||
|
||||
step_tflops = get_tflops_func(step_time)
|
||||
logger.info(
|
||||
|
@ -353,10 +383,12 @@ def main():
|
|||
if n >= WARMUP_STEPS:
|
||||
tflops_list.append(step_tflops)
|
||||
|
||||
demo_profiler = get_profile_context(PROF_FLAG,
|
||||
WARMUP_STEPS,
|
||||
NUM_STEPS - WARMUP_STEPS,
|
||||
save_dir=f"profile/{get_time_stamp()}-demo")
|
||||
demo_profiler = get_profile_context(
|
||||
PROF_FLAG,
|
||||
WARMUP_STEPS,
|
||||
NUM_STEPS - WARMUP_STEPS,
|
||||
save_dir=f"profile/{get_time_stamp()}-demo",
|
||||
)
|
||||
|
||||
with demo_profiler as prof:
|
||||
start_time = time()
|
||||
|
@ -364,7 +396,7 @@ def main():
|
|||
train_step()
|
||||
prof.step()
|
||||
end_time = time()
|
||||
print('total time: {}'.format(end_time - start_time))
|
||||
print("total time: {}".format(end_time - start_time))
|
||||
|
||||
tflops_list.sort()
|
||||
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
||||
|
@ -372,5 +404,5 @@ def main():
|
|||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -2,11 +2,15 @@ import time
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
|
||||
from torch.profiler import (
|
||||
ProfilerActivity,
|
||||
profile,
|
||||
schedule,
|
||||
tensorboard_trace_handler,
|
||||
)
|
||||
|
||||
|
||||
class DummyProfiler:
|
||||
|
||||
def __init__(self):
|
||||
self.step_number = 0
|
||||
|
||||
|
@ -16,7 +20,9 @@ class DummyProfiler:
|
|||
|
||||
# Randomly Generated Data
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
input_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()
|
||||
)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
@ -27,15 +33,17 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
|
|||
|
||||
def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
||||
if enable_flag:
|
||||
return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||
record_shapes=True,
|
||||
profile_memory=True)
|
||||
return profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
)
|
||||
else:
|
||||
return nullcontext(DummyProfiler())
|
||||
|
||||
|
||||
def get_time_stamp():
|
||||
cur_time = time.strftime("%d-%H:%M", time.localtime())
|
||||
return cur_time
|
||||
return cur_time
|
||||
|
|
|
@ -22,13 +22,15 @@ vocab_size = 32000
|
|||
total_step = 100
|
||||
use_activation_ckpt = False
|
||||
|
||||
|
||||
class FakeSet(torch.utils.data.Dataset):
|
||||
def __getitem__(self, idx):
|
||||
return torch.randint(0, vocab_size, (seq_length, ))
|
||||
|
||||
return torch.randint(0, vocab_size, (seq_length,))
|
||||
|
||||
def __len__(self):
|
||||
return 1000000000
|
||||
|
||||
|
||||
class SpeedTest(pl.LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -45,7 +47,7 @@ class SpeedTest(pl.LightningModule):
|
|||
out = self.model(batch, labels=batch)
|
||||
loss = out.loss
|
||||
if self.start_time is None:
|
||||
print('start')
|
||||
print("start")
|
||||
self.start_time = time.time()
|
||||
return loss
|
||||
|
||||
|
@ -53,23 +55,26 @@ class SpeedTest(pl.LightningModule):
|
|||
optimizer = FusedAdam(self.trainer.model.parameters(), lr=1e-5)
|
||||
return optimizer
|
||||
|
||||
|
||||
model = SpeedTest()
|
||||
train_loader = torch.utils.data.DataLoader(FakeSet(), batch_size=batch_size)
|
||||
|
||||
strategy=DeepSpeedStrategy(
|
||||
strategy = DeepSpeedStrategy(
|
||||
stage=2,
|
||||
offload_optimizer=False,
|
||||
offload_parameters=False,
|
||||
process_group_backend="nccl"
|
||||
process_group_backend="nccl",
|
||||
)
|
||||
trainer = pl.Trainer(
|
||||
limit_train_batches=total_step,
|
||||
limit_train_batches=total_step,
|
||||
max_epochs=1,
|
||||
devices=8,
|
||||
accelerator="gpu",
|
||||
strategy=strategy,
|
||||
precision=16,
|
||||
enable_checkpointing=False)
|
||||
precision=16,
|
||||
enable_checkpointing=False,
|
||||
)
|
||||
|
||||
|
||||
def train(model, train_loader):
|
||||
start_time = time.time()
|
||||
|
@ -77,4 +82,5 @@ def train(model, train_loader):
|
|||
end_time = time.time()
|
||||
return end_time - model.start_time
|
||||
|
||||
print('total time: {}'.format(train(model, train_loader)))
|
||||
|
||||
print("total time: {}".format(train(model, train_loader)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user