add high-performance Llama pre-train code

This commit is contained in:
LiangSong 2023-03-26 23:59:53 +08:00
parent 0fa15787b4
commit 73a81a4205
18 changed files with 858 additions and 0 deletions

132
.gitignore vendored Normal file

@@ -0,0 +1,132 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.DS_Store
pretrain_data/
wandb/

Binary file not shown.

Binary file not shown.


@@ -0,0 +1,30 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 12
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'no'
# dynamo_config:
# dynamo_backend: INDUCTOR
# dynamo_mode: default
# dynamo_use_dynamic: true
# dynamo_use_fullgraph: false
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
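
This Accelerate config sets up single-node training on 8 GPUs with DeepSpeed ZeRO stage 1, bf16 mixed precision, and 12 gradient-accumulation steps; training is typically launched with `accelerate launch --config_file <path to this config> pretrain_llama.py`. As a quick sanity check (not part of this commit), the resolved values can be inspected from the Accelerator object:

# Illustrative sketch only: inspect what the Accelerate/DeepSpeed config resolves to at runtime.
# Run with `accelerate launch --config_file <path to this config> inspect_env.py` (script name hypothetical).
from accelerate import Accelerator

accelerator = Accelerator()
if accelerator.is_main_process:
    # These should reflect num_processes (8), gradient_accumulation_steps (12) and mixed_precision ('bf16') above.
    print("num_processes:", accelerator.num_processes)
    print("gradient_accumulation_steps:", accelerator.gradient_accumulation_steps)
    print("mixed_precision:", accelerator.mixed_precision)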

Binary file not shown.

16
configs/train_config.py Normal file

@@ -0,0 +1,16 @@
max_length = 1024
train_batch_size = 2
num_training_steps = 1000000
num_warmup_steps = 2000
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = 'configs/10w_vocab_wudao5_pile10.model'
patterns = [
'data/pretrain_data/part-*.jsonl.zst'
]
# the intervals below are counted in global (optimizer) steps
log_interval = 5
eval_interval = 200
save_interval = 800
work_dir = 'data/saved_ckpt/'
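
Together with the Accelerate config above (8 processes, 12 gradient-accumulation steps; values taken from this commit), these settings determine the effective batch per optimizer step. A rough, illustrative calculation:

# Back-of-the-envelope arithmetic; the parallelism values come from the Accelerate/DeepSpeed
# config in this commit, not from train_config.py itself.
max_length = 1024
train_batch_size = 2              # per-GPU micro batch
gradient_accumulation_steps = 12
num_processes = 8
sequences_per_update = train_batch_size * gradient_accumulation_steps * num_processes
tokens_per_update = sequences_per_update * max_length
print(sequences_per_update)       # 192 sequences per optimizer step
print(tokens_per_update)          # 196608 tokens per optimizer step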

26
data/download_the_pile.sh Normal file

@@ -0,0 +1,26 @@
#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-03-16 21:21:38
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-03-26 22:58:02
# @FilePath: /Open-Llama/data/download_the_pile.sh
# @Description:
# Download the Pile dataset and preprocess it.
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
start=0
end=29
mkdir -p data/the_pile
for (( i=$start; i<=$end; i++ ))
do
    url="https://the-eye.eu/public/AI/pile/train/$(printf "%02d" $i).jsonl.zst"
    echo "Downloading file: $url"
    curl -C - "$url" -o data/the_pile/"$(printf "%02d" $i).jsonl.zst"
done
wait
echo "All files downloaded successfully."
mkdir -p data/pretrain_data
python3 data/preprocess_the_pile.py

19
data/download_wudao.sh Normal file

@@ -0,0 +1,19 @@
#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-03-16 21:21:56
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-03-26 22:58:11
# @FilePath: /Open-Llama/data/download_wudao.sh
# @Description:
# Download the WuDao dataset and preprocess it.
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
apt install unrar
for i in {1..100}
do
    curl -C - --retry 100 'https://dorc.baai.ac.cn/resources/data/WuDaoCorpora2.0/WuDaoCorpus2.0_base_200G.rar?AccessKeyId=AKLTNasiLRBBTcOgPqzlkPzu1w&Expires=1679127659&Signature=7jh%2FpnJyC2hAeumm9EjaeE5HN9E%3D' -o data/WuDaoCorpus2.0_base_200G.rar
done
unrar x data/WuDaoCorpus2.0_base_200G.rar
mkdir -p data/pretrain_data
python3 data/preprocess_wudao.py

32
data/preprocess_the_pile.py Normal file

@@ -0,0 +1,32 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-16 22:35:38
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 22:59:38
FilePath: /Open-Llama/data/preprocess_the_pile.py
Description:
Parse the dataset from the raw files and split them into different jsonl files based on the preset maximum number of lines,
making it easy for parallel training to perform streaming reads.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
from tqdm import tqdm
import zstandard as zstd
paths = glob('data/the_pile/*.jsonl.zst')
write_path = 'data/pretrain_data/part-pile-{}.jsonl.zst'
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
for path in tqdm(paths, total=len(paths)):
with zstd.open(path, 'r', encoding='utf-8') as fp:
for line in fp:
if total_num % 16384 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
wfp.write(line.encode('utf-8'))
total_num += 1
wfp.close()
print('total lines: {}\ntotal files: {}'.format(total_num, file_num + 1))
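
Each output shard is a zstd-compressed jsonl file holding at most 16384 lines, which downstream code streams line by line. A minimal reading sketch (the shard name assumes the script above has already run):

# Minimal sketch: stream one preprocessed shard and parse each line as JSON.
# Assumes data/pretrain_data/part-pile-0.jsonl.zst was produced by the script above.
import json
import zstandard as zstd

with zstd.open('data/pretrain_data/part-pile-0.jsonl.zst', 'r', encoding='utf-8') as fp:
    for i, line in enumerate(fp):
        record = json.loads(line)
        print(record.get('meta', {}).get('pile_set_name'), len(record['text']))
        if i == 2:
            break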

34
data/preprocess_wudao.py Normal file

@@ -0,0 +1,34 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-16 22:10:44
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 22:59:55
FilePath: /Open-Llama/data/preprocess_wudao.py
Description:
Parse the dataset from the raw files and split them into different jsonl files based on the preset maximum number of lines,
making it easy for parallel training to perform streaming reads.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
from tqdm import tqdm
import zstandard as zstd
paths = glob('data/WuDaoCorpus2.0_base_200G/part*')
write_path = 'data/pretrain_data/part-wudao-{}.jsonl.zst'
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
for path in tqdm(paths, total=len(paths)):
with open(path, 'r') as fp:
data = json.load(fp)
for line in data:
if total_num % 16384 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
wfp.write(json.dumps(line).encode('utf-8'))
wfp.write('\n'.encode('utf-8'))
total_num += 1
wfp.close()
print('total lines: {}\ntotal files: {}'.format(total_num, file_num + 1))

92
dataset/data_iter.py Normal file

@@ -0,0 +1,92 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
import zstandard as zstd
def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
'''
Currently, the allowed storage formats are jsonl and jsonl.zst.
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
'''
past = None
for i, path in paths:
dataset_name = path.split('-')[-2]
        if past != path:
print('Loading data from {}'.format(path))
past = path
if num_processes > 1 and i % num_processes != process_index:
continue
if path.endswith('jsonl.zst'):
with zstd.open(path, 'r', encoding='utf-8') as fp:
for line in fp:
if isinstance(line, bytes):
line = line.decode('utf-8')
line = json.loads(line)
line['dataset'] = dataset_name
if transform_dict:
line = transform_dict[dataset_name](line)
if isinstance(line, str):
yield line
elif isinstance(line, list):
for i in line:
yield i
else:
raise Exception('Unsupported type in Transformation: {}'.format(transform_dict[dataset_name]))
else:
yield line
elif path.endswith('jsonl'):
with open(path, 'r') as fp:
for line in fp:
if isinstance(line, bytes):
line = line.decode('utf-8')
line = json.loads(line)
line['dataset'] = dataset_name
if transform_dict:
line = transform_dict[dataset_name](line)
if isinstance(line, str):
yield line
elif isinstance(line, list):
for i in line:
yield i
else:
raise Exception('Unsupported type in Transformation: {}'.format(transform_dict[dataset_name]))
else:
yield line
else:
raise Exception('File format of {} is not supported yet.'.format(path))
def create_shard_kwargs(patterns, repeat=1):
'''
Assign numbers to different shards of data to ensure that data is not duplicated
when allocated to different nodes during distributed training.
'''
all_path = []
for p in patterns:
all_path.extend(glob(p))
all_path *= repeat
return [(i, p) for i, p in enumerate(all_path)]
if __name__ == '__main__':
patterns = [
'data/pretrain_data/part-wudao*.jsonl.zst'
]
paths = create_shard_kwargs(patterns)
transform_dict = {
'wudao': lambda x: x['title'],
'pile': lambda x: [x['text']]
}
data_iter = create_data_iter(paths, transform_dict=transform_dict)
for i, data in enumerate(data_iter):
print(i, data)
if i == 20:
break
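
create_shard_kwargs assigns an index to every shard file, and create_data_iter keeps only the files whose index satisfies i % num_processes == process_index, so each rank streams a disjoint, round-robin slice of the shards. A small illustration of that assignment (file names made up):

# Illustration only: how indexed shards are divided across ranks; file names are made up.
paths = [(i, 'part-demo-{}.jsonl.zst'.format(i)) for i in range(8)]
num_processes = 4
for process_index in range(num_processes):
    assigned = [p for i, p in paths if i % num_processes == process_index]
    print(process_index, assigned)
# rank 0 -> files 0 and 4, rank 1 -> files 1 and 5, and so on.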

82
dataset/pretrain_dataset.py Normal file

@@ -0,0 +1,82 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 20:41:25
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:07:56
FilePath: /Open-Llama/dataset/pretrain_dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import math
import torch
def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
def preprocess_wudao(line):
'''
The format of the data is roughly as follows.
{'id': 1, 'dataType': '百科', 'title': 'some title', 'content': 'some content'}
Split the data based on the tokenized length according to the maximum length.
'''
total = line['title'] + '\n' + line['content']
out = tokenizer(total)
input_ids = out['input_ids']
return [input_ids[i*segment_max_length: (i+1)*segment_max_length]
for i in range(math.ceil(len(input_ids)/segment_max_length))]
return preprocess_wudao
def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
def preprocess_the_pile(line):
'''
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
'''
total = line['text']
out = tokenizer(total)
input_ids = out['input_ids']
return [input_ids[i*segment_max_length: (i+1)*segment_max_length]
for i in range(math.ceil(len(input_ids)/segment_max_length))]
return preprocess_the_pile
def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
'''
Organize data into tensors by padding based on the preset maximum length.
'''
pad_id = tokenizer.pad_id
def pretrain_collate_fn(batch):
input_ids = []
for i in batch:
input_len = len(i)
input_ids.append(i+[pad_id]*(segment_max_length-input_len))
inputs = {
'input_ids': torch.tensor(input_ids, dtype=torch.int64),
}
return inputs
return pretrain_collate_fn
if __name__ == '__main__':
import sentencepiece as spm
from datasets import IterableDataset
from torch.utils.data import DataLoader
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
tokenizer = Tokenizer(sp_model)
patterns = [
'data/pretrain_data/part-*.jsonl.zst'
]
paths = create_shard_kwargs(patterns)
transform_dict = {
'wudao': preprocess_wudao_gen(tokenizer),
'pile': preprocess_the_pile_gen(tokenizer)
}
data_set = IterableDataset.from_generator(create_data_iter, gen_kwargs={'paths': paths, 'transform_dict': transform_dict})
train_loader = DataLoader(data_set, batch_size=8, num_workers=4,
collate_fn=pretrain_collate_fn_gen(tokenizer), drop_last=True)
for batch in train_loader:
for k, v in batch.items():
print(k, v.shape)
break
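
pretrain_collate_fn_gen right-pads every tokenized segment with pad_id up to segment_max_length and stacks the batch into a single input_ids tensor. A quick sketch of that behaviour with a stand-in tokenizer object (FakeTokenizer is purely illustrative; only pad_id is needed):

# Illustrative sketch of the padding behaviour; FakeTokenizer only provides pad_id.
from dataset.pretrain_dataset import pretrain_collate_fn_gen

class FakeTokenizer:
    pad_id = 0

collate = pretrain_collate_fn_gen(FakeTokenizer(), segment_max_length=8)
batch = [[5, 6, 7], [1, 2, 3, 4, 5]]
out = collate(batch)
print(out['input_ids'].shape)   # torch.Size([2, 8])
print(out['input_ids'])         # each row right-padded with 0 up to length 8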

146
dataset/tokenizer.py Normal file

@@ -0,0 +1,146 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:09:39
FilePath: /Open-Llama/dataset/tokenizer.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import torch
class Tokenizer:
def __init__(self, sp_model):
self.sp_model = sp_model
self.bos_id = self.sp_model.bos_id()
self.eos_id = self.sp_model.eos_id()
self.pad_id = self.sp_model.pad_id()
self.vocab_size = self.sp_model.vocab_size()
def __call__(self, inputs, padding=None, max_length=256, return_tensors=False, truncation=False,
add_special_tokens=True, return_mask=False):
if isinstance(inputs, str):
return self.encode(inputs, padding=padding, max_length=max_length,
return_tensors=return_tensors, truncation=truncation, add_special_tokens=add_special_tokens, return_mask=return_mask)
else:
return self.encode_batch(inputs, padding=padding, max_length=max_length,
return_tensors=return_tensors, truncation=truncation, add_special_tokens=add_special_tokens, return_mask=return_mask)
def encode(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
add_special_tokens=True, return_mask=False):
assert(isinstance(inputs, str))
input_ids = self.sp_model.Encode(inputs)
if return_mask:
attention_mask = [1] * len(input_ids)
if truncation:
            # Following the implementation in Transformers, the last position is always reserved for pad or eos.
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L780
input_ids = input_ids[: max_length-1]
if return_mask:
attention_mask = attention_mask[: max_length-1]
if add_special_tokens:
input_ids = input_ids + [self.eos_id]
if return_mask:
attention_mask = attention_mask + [0]
if padding == 'max_length':
input_ids = input_ids + [self.pad_id] * (max_length-len(input_ids))
if return_mask:
attention_mask = attention_mask + [0] * (max_length-len(attention_mask))
if return_tensors:
input_ids = torch.tensor([input_ids])
out = {
'input_ids': input_ids,
}
if return_mask:
attention_mask = torch.tensor([attention_mask])
out['attention_mask'] = attention_mask
else:
out = {
'input_ids': input_ids,
}
if return_mask:
out['attention_mask'] = attention_mask
return out
def encode_batch(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
add_special_tokens=True, return_mask=False):
input_ids = self.sp_model.Encode(inputs)
if return_mask:
attention_mask = [[1] * len(i) for i in input_ids]
if truncation:
input_ids = [i[: max_length-1] for i in input_ids]
if return_mask:
attention_mask = [i[: max_length-1] for i in attention_mask]
if add_special_tokens:
input_ids = [i+[self.eos_id] for i in input_ids]
if return_mask:
attention_mask = [i+[0] for i in attention_mask]
if padding == 'max_length':
input_ids_pad = []
if return_mask:
attention_mask_pad = []
for idx, i in enumerate(input_ids):
input_ids_pad.append(i + [self.pad_id] * (max_length-len(i)))
if return_mask:
j = attention_mask[idx]
attention_mask_pad.append(j + [0] * (max_length-len(j)))
input_ids = input_ids_pad
if return_mask:
attention_mask = attention_mask_pad
if return_tensors:
input_ids = torch.tensor(input_ids)
out = {
'input_ids': input_ids,
}
if return_mask:
attention_mask = torch.tensor(attention_mask)
out['attention_mask'] = attention_mask
else:
out = {
'input_ids': input_ids,
}
if return_mask:
out['attention_mask'] = attention_mask
return out
def decode(self, inputs):
inputs = inputs.tolist()
out = []
for i in inputs:
if self.eos_id in i:
eos_idx = i.index(self.eos_id)
i = i[: eos_idx]
out.append(i)
out = self.sp_model.Decode(out)
return out
if __name__ == '__main__':
import sentencepiece as spm
from unicodedata import normalize
    # Note: SentencePiece may not handle some reserved symbols such as '▁' correctly.
sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
tokenizer = Tokenizer(sp_model)
tmp = ['hello world',
'这是开源项目的V1版本this is the first version of a open-source project!',
'# this is a python script\nfor i in range(10):\n print(i)\n for j in range(10):\n print(j)']
print(tmp)
out = tokenizer(tmp, padding='max_length', return_tensors=True, max_length=64, truncation=True)
for k, v in out.items():
print(k, v.shape)
print(out['input_ids'])
out = tokenizer.decode(out['input_ids'])
print(out)
for i, j in zip(tmp, out):
assert(normalize('NFKC', i) == j)
from dataset.data_iter import create_shard_kwargs, create_data_iter
patterns = [
'data/pretrain_data/part-wudao*.jsonl.zst'
]
paths = create_shard_kwargs(patterns)
data_iter = create_data_iter(paths)
for i, data in enumerate(data_iter):
assert(normalize('NFKC', data['content']) == sp_model.Decode(sp_model.Encode(data['content'])) or '' in data['content'])
if i == 1000:
break

53
dataset/train_tokenizer.py Normal file

@@ -0,0 +1,53 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-24 20:49:03
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:43:59
FilePath: /Open-Llama/dataset/train_tokenizer.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import random
from dataset.data_iter import create_data_iter, create_shard_kwargs
wudao_patterns = [
'data/pretrain_data/part-wudao-*.jsonl.zst',
]
wudao_paths = create_shard_kwargs(wudao_patterns)
random.shuffle(wudao_paths)
pile_patterns = [
'data/pretrain_data/part-pile-*.jsonl.zst',
]
pile_paths = create_shard_kwargs(pile_patterns)
random.shuffle(pile_paths)
paths = wudao_paths[: 5] + pile_paths[: 10]
transform_dict = {
'wudao': lambda line: [(line['title'] + '\n' + line['content'])],
'pile': lambda line: [line['text']]
}
data_iter = create_data_iter(paths, transform_dict)
import io
import sentencepiece as spm
# Train the tokenizer from the sentence iterator and store the resulting model in a BytesIO buffer.
model = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=data_iter, model_writer=model, shuffle_input_sentence=False, train_extremely_large_corpus=True,
# hyperparameters of tokenizer
max_sentence_length=16384, pad_id=3, model_type='BPE', vocab_size=100000,
    # Split digits and use byte fallback, the same as Llama.
    # Set split_by_unicode_script to True to avoid grouping punctuation and characters together.
split_digits=True, split_by_unicode_script=True, byte_fallback=True,
    # Preserve whitespace, including \n and \t, for code generation.
allow_whitespace_only_pieces=True, remove_extra_whitespaces=False, normalization_rule_name='nfkc')
# Serialize the model to a file.
with open('configs/10w_vocab_wudao5_pile10.model', 'wb') as f:
f.write(model.getvalue())
# Load the tokenizer directly from the serialized model proto.
sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
print(sp.decode(sp.encode('只因你太美🤗▃ \n 1')))

17
dataset/validation.py Normal file

@@ -0,0 +1,17 @@
val_set = [
'白日依山尽,',
'君不见,黄河之水天上来,奔流到海不复回。君不见,',
'秦孝公据崤函之固,拥雍州之地,君臣固守以窥周室,有席卷天下,包举宇内,囊括四海之意,并吞八荒之心。',
'古之学者必有师。师者,所以传道受业解惑也。人非生而知之者,孰能无惑?',
'当我醒来时,我发现自己在一个完全陌生的地方。我看到周围没有人,只有一张纸条。',
'这是一个斗气决定一切的大陆。在加玛帝国乌坦城,有个天才少年萧炎打破了所有族人的修炼纪录,一时间万人敬仰,众人艳羡。但不知为何,',
'人工智能技术在图像识别领域取得了很大的进展,然而在复杂场景下仍然存在一些问题,例如',
'In recent years, there has been increasing interest in the use of machine learning to',
'已知三个数分别为1, 2, 3求它们的平均数是',
'小明总共有15个苹果他分别给了3个人两个苹果然后自己又吃了一个苹果那么它还剩几个苹果',
'根据牛顿第二定律,物体的加速度等于',
'碳纳米管是一种新型的材料,具有非常独特的电学和光学性质。在过去的几年中,我们对碳纳',
'下面是一段用python写的快速排序的代码:',
'The quantum many-body problem is a fundamental problem in condensed matter physics. Despite decades of research, there is still no exact solution to this problem for large systems. In this paper, we propose a novel approach based on',
'下面是一个使用 PyTorch 和 Transformer 的示例代码用于训练一个文本分类模型import torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, Dataset'
]

12
models/llama.py Normal file

@@ -0,0 +1,12 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 13:21:33
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:13:57
FilePath: /Open-Llama/models/llama.py
Description:
Building the Llama model proposed by Meta. https://arxiv.org/pdf/2302.13971.pdf
Performance and efficiency optimizations based on the implementation in the Transformers library.
https://github.com/Bayes-Song/transformers
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''

148
pretrain_llama.py Normal file

@@ -0,0 +1,148 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:33:41
FilePath: /Open-Llama/pretrain_llama.py
Description:
Pretrain the Llama model.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.pretrain_dataset import preprocess_the_pile_gen, preprocess_wudao_gen, pretrain_collate_fn_gen
from configs.train_config import *
accelerator = Accelerator()
if accelerator.is_main_process:
wandb.init(
project='LLAMA Pretrain'
)
log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps
sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)
paths = create_shard_kwargs(patterns)
random.shuffle(paths)
transform_dict = {
'wudao': preprocess_wudao_gen(tokenizer, max_length),
'pile': preprocess_the_pile_gen(tokenizer, max_length)
}
data_set = IterableDataset.from_generator(create_data_iter, gen_kwargs={
'paths': paths,
'transform_dict': transform_dict,
'process_index': accelerator.process_index,
'num_processes': accelerator.num_processes
})
train_loader = DataLoader(data_set, batch_size=train_batch_size, num_workers=1,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length), drop_last=True)
# A smaller initializer_range makes training more stable.
# Stable embedding is added on top of the token embedding.
raw_model = LlamaForCausalLM(LlamaConfig(vocab_size=tokenizer.vocab_size,
initializer_range=initializer_range,
pad_token_id=tokenizer.pad_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True))
raw_model.eval()
with torch.no_grad():
summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in raw_model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": weight_decay,
},
{
"params": [p for n, p in raw_model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=num_warmup_steps * factor,
num_training_steps=num_training_steps * factor)
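# Note: the prepared dataloader returned by accelerator.prepare is discarded below;
# batches are drawn from the original train_loader, since create_data_iter already
# shards files across processes via process_index.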
_, model, optim, scheduler = accelerator.prepare(
train_loader, raw_model, optim, scheduler
)
print('start training...')
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
model.train()
with accelerator.accumulate(model):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device)
labels = batch['input_ids'].clone()
labels[labels==tokenizer.pad_id] = -100
out = model(**batch, labels=labels)
total_loss = out.loss
losses = {
'total_loss': total_loss
}
accelerator.backward(total_loss)
optim.step()
scheduler.step()
optim.zero_grad()
if accelerator.sync_gradients:
global_step += 1
if data_step % log_interval == 0 and accelerator.is_main_process:
cost_time = time.time() - start_time
start_time = time.time()
tokens = train_batch_size * log_interval * max_length
wandb.log({'Training/Token per second per gpu': tokens/cost_time})
for k, v in losses.items():
wandb.log({'Losses/{}'.format(k): v})
current_lr = optim.param_groups[0]['lr']
wandb.log({'Training/LR': current_lr})
if optim.scaler is not None:
wandb.log({'Training/Loss Scale': optim.scaler.get_scale()})
wandb.log({'Training/Data Step': data_step})
wandb.log({'Training/Global Step': global_step})
accelerator.print('Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}'.format(
global_step, data_step, losses['total_loss'], tokens/cost_time))
if data_step % eval_interval == 0 and accelerator.is_main_process:
text_table = wandb.Table(columns=['question', 'pred'])
model.eval()
with torch.no_grad():
for data in val_set:
raw_inputs = data
inputs_len = len(raw_inputs)
inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
for k, v in inputs.items():
inputs[k] = v.to(accelerator.device)
pred = model.generate(**inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0)
pred = tokenizer.decode(pred.cpu())[0]
pred = pred[inputs_len:]
text_table.add_data(raw_inputs, pred)
wandb.log({'Predictions on {}'.format(global_step) : text_table})
if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
if not os.path.isdir(work_dir):
os.mkdir(work_dir)
torch.save(raw_model.state_dict(), '{}/{}.pt'.format(work_dir, global_step))
wandb.finish()
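
Checkpoints are plain state_dict files written to work_dir as <global_step>.pt. A minimal loading sketch for offline generation (the step number is a placeholder, and the extra LlamaConfig fields assume the patched transformers fork pinned in requirements.txt):

# Minimal, illustrative loading sketch; assumes the patched transformers fork from
# requirements.txt, which adds use_stable_embedding / shared_input_output_embedding.
import torch
import sentencepiece as spm
from transformers import LlamaForCausalLM, LlamaConfig
from dataset.tokenizer import Tokenizer

sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
tokenizer = Tokenizer(sp_model)
config = LlamaConfig(vocab_size=tokenizer.vocab_size,
                     pad_token_id=tokenizer.pad_id,
                     rms_norm_eps=1e-5,
                     use_stable_embedding=True,
                     shared_input_output_embedding=True)
model = LlamaForCausalLM(config)
# '800.pt' is a placeholder for whichever global step was saved.
model.load_state_dict(torch.load('data/saved_ckpt/800.pt', map_location='cpu'))
model.eval()
inputs = tokenizer('白日依山尽,', return_tensors=True, add_special_tokens=False)
with torch.no_grad():
    pred = model.generate(**inputs, max_new_tokens=64, do_sample=True)
print(tokenizer.decode(pred)[0])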

19
requirements.txt Normal file

@@ -0,0 +1,19 @@
torch==1.13.1
torchvision
torchaudio
zstandard
accelerate
wandb
deepspeed
absl-py
torchinfo
scikit-learn
datasets==2.10.1
matplotlib
seaborn
sentencepiece
triton
functorch==1.13.1
xformers
git+https://github.com/Bayes-Song/transformers.git