add high-performance Llama pre-train code
parent 0fa15787b4
commit 73a81a4205

132  .gitignore  vendored  Normal file
@@ -0,0 +1,132 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.DS_Store
pretrain_data/
wandb/

BIN  configs/10w_vocab_wudao5_pile10.model  Normal file
Binary file not shown.

BIN  configs/6w_vocab_wudao5_pile10.model  Normal file
Binary file not shown.

30  configs/default_config.yaml  Normal file
@@ -0,0 +1,30 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 12
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'no'
# dynamo_config:
#   dynamo_backend: INDUCTOR
#   dynamo_mode: default
#   dynamo_use_dynamic: true
#   dynamo_use_fullgraph: false
fsdp_config: {}
machine_rank: 0
main_training_function: main
megatron_lm_config: {}
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
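
Note on this file: it is a Hugging Face Accelerate configuration that enables DeepSpeed ZeRO stage 1 with bf16 mixed precision, 8 processes on one machine, 12 gradient-accumulation steps, and gradient clipping at 1.0; TorchDynamo stays disabled, with the commented dynamo_config block kept for reference. A run would typically be started with something like: accelerate launch --config_file configs/default_config.yaml pretrain_llama.py (the exact launch command is not part of this commit).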

BIN  configs/llama_tokenizer.model  Normal file
Binary file not shown.

16  configs/train_config.py  Normal file
@@ -0,0 +1,16 @@
max_length = 1024
train_batch_size = 2
num_training_steps = 1000000
num_warmup_steps = 2000
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = 'configs/10w_vocab_wudao5_pile10.model'
patterns = [
    'data/pretrain_data/part-*.jsonl.zst'
]
# global step
log_interval = 5
eval_interval = 200
save_interval = 800
work_dir = 'data/saved_ckpt/'
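
For scale: combined with configs/default_config.yaml above (8 processes, 12 gradient-accumulation steps), these defaults give an effective global batch of 2 × 8 × 12 = 192 sequences, i.e. 192 × 1024 = 196,608 tokens per optimizer step. The log/eval/save intervals are counted in global (optimizer) steps; pretrain_llama.py below rescales them by the gradient-accumulation factor before comparing them against the raw data step.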

26  data/download_the_pile.sh  Normal file
@@ -0,0 +1,26 @@
#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-03-16 21:21:38
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-03-26 22:58:02
# @FilePath: /Open-Llama/data/download_the_pile.sh
# @Description:
# download the pile dataset and preprocess
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
start=0
end=29
mkdir data/the_pile
for (( i=$start; i<=$end; i++ ))
do
    url="https://the-eye.eu/public/AI/pile/train/$(printf "%02d" $i).jsonl.zst"
    echo "Downloading file: $url"
    curl -C - $url -o data/the_pile/"$(printf "%02d" $i).jsonl.zst"
done

wait

echo "All files downloaded successfully."
mkdir data/pretrain_data
python3 data/preprocess_the_pile.py

19  data/download_wudao.sh  Normal file
@@ -0,0 +1,19 @@
#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-03-16 21:21:56
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-03-26 22:58:11
# @FilePath: /Open-Llama/data/download_wudao.sh
# @Description:
# download wudao dataset and preprocess
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
apt install unrar
for i in {1..100}
do
    curl -C - --retry 100 'https://dorc.baai.ac.cn/resources/data/WuDaoCorpora2.0/WuDaoCorpus2.0_base_200G.rar?AccessKeyId=AKLTNasiLRBBTcOgPqzlkPzu1w&Expires=1679127659&Signature=7jh%2FpnJyC2hAeumm9EjaeE5HN9E%3D' -o data/WuDaoCorpus2.0_base_200G.rar
done
unrar x data/WuDaoCorpus2.0_base_200G.rar
mkdir data/pretrain_data
python3 data/preprocess_wudao.py

32  data/preprocess_the_pile.py  Normal file
@@ -0,0 +1,32 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-16 22:35:38
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 22:59:38
FilePath: /Open-Llama/data/preprocess_the_pile.py
Description:
Parse the dataset from the raw files and split them into different jsonl files based on the preset maximum number of lines,
making it easy for parallel training to perform streaming reads.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
from tqdm import tqdm
import zstandard as zstd

paths = glob('data/the_pile/*.jsonl.zst')
write_path = 'data/pretrain_data/part-pile-{}.jsonl.zst'
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
for path in tqdm(paths, total=len(paths)):
    with zstd.open(path, 'r', encoding='utf-8') as fp:
        for line in fp:
            if total_num % 16384 == 0 and total_num > 0:
                file_num += 1
                wfp.close()
                wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
            wfp.write(line.encode('utf-8'))
            total_num += 1
wfp.close()
print('total line: {}\ntotal files: {}'.format(total_num, file_num))

34  data/preprocess_wudao.py  Normal file
@@ -0,0 +1,34 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-16 22:10:44
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 22:59:55
FilePath: /Open-Llama/data/preprocess_wudao.py
Description:
Parse the dataset from the raw files and split them into different jsonl files based on the preset maximum number of lines,
making it easy for parallel training to perform streaming reads.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
from tqdm import tqdm
import zstandard as zstd

paths = glob('data/WuDaoCorpus2.0_base_200G/part*')
write_path = 'data/pretrain_data/part-wudao-{}.jsonl.zst'
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
for path in tqdm(paths, total=len(paths)):
    with open(path, 'r') as fp:
        data = json.load(fp)
    for line in data:
        if total_num % 16384 == 0 and total_num > 0:
            file_num += 1
            wfp.close()
            wfp = zstd.open(write_path.format(file_num), 'wb', encoding='utf-8')
        wfp.write(json.dumps(line).encode('utf-8'))
        wfp.write('\n'.encode('utf-8'))
        total_num += 1
wfp.close()
print('total line: {}\ntotal files: {}'.format(total_num, file_num))

92  dataset/data_iter.py  Normal file
@@ -0,0 +1,92 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
FilePath: /Open-Llama/dataset/data_iter.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import json
from glob import glob
import zstandard as zstd


def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
    '''
    Currently, the allowed storage formats are jsonl and jsonl.zst.
    Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
    '''
    past = None
    for i, path in paths:
        dataset_name = path.split('-')[-2]
        if past != dataset_name:
            print('Loading data from {}'.format(path))
            past = dataset_name
        if num_processes > 1 and i % num_processes != process_index:
            continue
        if path.endswith('jsonl.zst'):
            with zstd.open(path, 'r', encoding='utf-8') as fp:
                for line in fp:
                    if isinstance(line, bytes):
                        line = line.decode('utf-8')
                    line = json.loads(line)
                    line['dataset'] = dataset_name
                    if transform_dict:
                        line = transform_dict[dataset_name](line)
                        if isinstance(line, str):
                            yield line
                        elif isinstance(line, list):
                            for i in line:
                                yield i
                        else:
                            raise Exception('Unsupported type in Transformation: {}'.format(transform_dict[dataset_name]))
                    else:
                        yield line
        elif path.endswith('jsonl'):
            with open(path, 'r') as fp:
                for line in fp:
                    if isinstance(line, bytes):
                        line = line.decode('utf-8')
                    line = json.loads(line)
                    line['dataset'] = dataset_name
                    if transform_dict:
                        line = transform_dict[dataset_name](line)
                        if isinstance(line, str):
                            yield line
                        elif isinstance(line, list):
                            for i in line:
                                yield i
                        else:
                            raise Exception('Unsupported type in Transformation: {}'.format(transform_dict[dataset_name]))
                    else:
                        yield line
        else:
            raise Exception('File format of {} is not supported yet.'.format(path))


def create_shard_kwargs(patterns, repeat=1):
    '''
    Assign numbers to different shards of data to ensure that data is not duplicated
    when allocated to different nodes during distributed training.
    '''
    all_path = []
    for p in patterns:
        all_path.extend(glob(p))
    all_path *= repeat
    return [(i, p) for i, p in enumerate(all_path)]


if __name__ == '__main__':
    patterns = [
        'data/pretrain_data/part-wudao*.jsonl.zst'
    ]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        'wudao': lambda x: x['title'],
        'pile': lambda x: [x['text']]
    }
    data_iter = create_data_iter(paths, transform_dict=transform_dict)
    for i, data in enumerate(data_iter):
        print(i, data)
        if i == 20:
            break

82  dataset/pretrain_dataset.py  Normal file
@@ -0,0 +1,82 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 20:41:25
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:07:56
FilePath: /Open-Llama/dataset/pretrain_dataset.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import math
import torch


def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
    def preprocess_wudao(line):
        '''
        The format of the data is roughly as follows.
        {'id': 1, 'dataType': '百科', 'title': 'some title', 'content': 'some content'}
        Split the data based on the tokenized length according to the maximum length.
        '''
        total = line['title'] + '\n' + line['content']
        out = tokenizer(total)
        input_ids = out['input_ids']
        return [input_ids[i * segment_max_length: (i + 1) * segment_max_length]
                for i in range(math.ceil(len(input_ids) / segment_max_length))]
    return preprocess_wudao


def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
    def preprocess_the_pile(line):
        '''
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        '''
        total = line['text']
        out = tokenizer(total)
        input_ids = out['input_ids']
        return [input_ids[i * segment_max_length: (i + 1) * segment_max_length]
                for i in range(math.ceil(len(input_ids) / segment_max_length))]
    return preprocess_the_pile


def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024):
    '''
    Organize data into tensors by padding based on the preset maximum length.
    '''
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (segment_max_length - input_len))
        inputs = {
            'input_ids': torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs
    return pretrain_collate_fn


if __name__ == '__main__':
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
    tokenizer = Tokenizer(sp_model)
    patterns = [
        'data/pretrain_data/part-*.jsonl.zst'
    ]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        'wudao': preprocess_wudao_gen(tokenizer),
        'pile': preprocess_the_pile_gen(tokenizer)
    }
    data_set = IterableDataset.from_generator(create_data_iter, gen_kwargs={'paths': paths, 'transform_dict': transform_dict})
    train_loader = DataLoader(data_set, batch_size=8, num_workers=4,
                              collate_fn=pretrain_collate_fn_gen(tokenizer), drop_last=True)
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break

146  dataset/tokenizer.py  Normal file
@@ -0,0 +1,146 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:09:39
FilePath: /Open-Llama/dataset/tokenizer.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import torch


class Tokenizer:
    def __init__(self, sp_model):
        self.sp_model = sp_model
        self.bos_id = self.sp_model.bos_id()
        self.eos_id = self.sp_model.eos_id()
        self.pad_id = self.sp_model.pad_id()
        self.vocab_size = self.sp_model.vocab_size()

    def __call__(self, inputs, padding=None, max_length=256, return_tensors=False, truncation=False,
                 add_special_tokens=True, return_mask=False):
        if isinstance(inputs, str):
            return self.encode(inputs, padding=padding, max_length=max_length,
                               return_tensors=return_tensors, truncation=truncation,
                               add_special_tokens=add_special_tokens, return_mask=return_mask)
        else:
            return self.encode_batch(inputs, padding=padding, max_length=max_length,
                                     return_tensors=return_tensors, truncation=truncation,
                                     add_special_tokens=add_special_tokens, return_mask=return_mask)

    def encode(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
               add_special_tokens=True, return_mask=False):
        assert isinstance(inputs, str)
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [1] * len(input_ids)
        if truncation:
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L780
            # Following the implementation in Transformers: by default the last position is always pad or eos.
            input_ids = input_ids[: max_length - 1]
            if return_mask:
                attention_mask = attention_mask[: max_length - 1]
        if add_special_tokens:
            input_ids = input_ids + [self.eos_id]
            if return_mask:
                attention_mask = attention_mask + [0]
        if padding == 'max_length':
            input_ids = input_ids + [self.pad_id] * (max_length - len(input_ids))
            if return_mask:
                attention_mask = attention_mask + [0] * (max_length - len(attention_mask))
        if return_tensors:
            input_ids = torch.tensor([input_ids])
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor([attention_mask])
                out['attention_mask'] = attention_mask
        else:
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                out['attention_mask'] = attention_mask
        return out

    def encode_batch(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
                     add_special_tokens=True, return_mask=False):
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [[1] * len(i) for i in input_ids]
        if truncation:
            input_ids = [i[: max_length - 1] for i in input_ids]
            if return_mask:
                attention_mask = [i[: max_length - 1] for i in attention_mask]
        if add_special_tokens:
            input_ids = [i + [self.eos_id] for i in input_ids]
            if return_mask:
                attention_mask = [i + [0] for i in attention_mask]
        if padding == 'max_length':
            input_ids_pad = []
            if return_mask:
                attention_mask_pad = []
            for idx, i in enumerate(input_ids):
                input_ids_pad.append(i + [self.pad_id] * (max_length - len(i)))
                if return_mask:
                    j = attention_mask[idx]
                    attention_mask_pad.append(j + [0] * (max_length - len(j)))
            input_ids = input_ids_pad
            if return_mask:
                attention_mask = attention_mask_pad
        if return_tensors:
            input_ids = torch.tensor(input_ids)
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor(attention_mask)
                out['attention_mask'] = attention_mask
        else:
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                out['attention_mask'] = attention_mask
        return out

    def decode(self, inputs):
        inputs = inputs.tolist()
        out = []
        for i in inputs:
            if self.eos_id in i:
                eos_idx = i.index(self.eos_id)
                i = i[: eos_idx]
            out.append(i)
        out = self.sp_model.Decode(out)
        return out


if __name__ == '__main__':
    import sentencepiece as spm
    from unicodedata import normalize

    # Using sentencepiece may not be able to process some reserved keywords like '▁'.
    sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
    tokenizer = Tokenizer(sp_model)
    tmp = ['hello world',
           '这是开源项目的V1版本,this is the first version of a open-source project!',
           '# this is a python script\nfor i in range(10):\n    print(i)\n    for j in range(10):\n        print(j)']
    print(tmp)
    out = tokenizer(tmp, padding='max_length', return_tensors=True, max_length=64, truncation=True)
    for k, v in out.items():
        print(k, v.shape)
    print(out['input_ids'])
    out = tokenizer.decode(out['input_ids'])
    print(out)
    for i, j in zip(tmp, out):
        assert normalize('NFKC', i) == j

    from dataset.data_iter import create_shard_kwargs, create_data_iter
    patterns = [
        'data/pretrain_data/part-wudao*.jsonl.zst'
    ]
    paths = create_shard_kwargs(patterns)
    data_iter = create_data_iter(paths)
    for i, data in enumerate(data_iter):
        assert normalize('NFKC', data['content']) == sp_model.Decode(sp_model.Encode(data['content'])) or '▁' in data['content']
        if i == 1000:
            break

53  dataset/train_tokenizer.py  Normal file
@@ -0,0 +1,53 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-24 20:49:03
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:43:59
FilePath: /Open-Llama/dataset/train_tokenizer.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import random
from dataset.data_iter import create_data_iter, create_shard_kwargs

wudao_patterns = [
    'data/pretrain_data/part-wudao-*.jsonl.zst',
]
wudao_paths = create_shard_kwargs(wudao_patterns)
random.shuffle(wudao_paths)

pile_patterns = [
    'data/pretrain_data/part-pile-*.jsonl.zst',
]
pile_paths = create_shard_kwargs(pile_patterns)
random.shuffle(pile_paths)
paths = wudao_paths[: 5] + pile_paths[: 10]
transform_dict = {
    'wudao': lambda line: [(line['title'] + '\n' + line['content'])],
    'pile': lambda line: [line['text']]
}
data_iter = create_data_iter(paths, transform_dict)

import io
import sentencepiece as spm

# Train the tokenizer from the data iterator and store the resulting model in a BytesIO buffer.
model = io.BytesIO()
spm.SentencePieceTrainer.train(
    sentence_iterator=data_iter, model_writer=model, shuffle_input_sentence=False, train_extremely_large_corpus=True,
    # hyperparameters of tokenizer
    max_sentence_length=16384, pad_id=3, model_type='BPE', vocab_size=100000,
    # split digits and fall back to bytes, same as Llama.
    # set split_by_unicode_script to True to avoid grouping punctuation and characters together.
    split_digits=True, split_by_unicode_script=True, byte_fallback=True,
    # reserve whitespace and \n and \t etc. for code generation
    allow_whitespace_only_pieces=True, remove_extra_whitespaces=False, normalization_rule_name='nfkc')

# Serialize the model as a file.
with open('configs/10w_vocab_wudao5_pile10.model', 'wb') as f:
    f.write(model.getvalue())

# Directly load the model from the serialized proto.
sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
print(sp.decode(sp.encode('只因你太美🤗▃ \n 1')))
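
The trained model is written to configs/10w_vocab_wudao5_pile10.model, the same path set as tokenizer_model_path in configs/train_config.py; the file name reflects the 5 WuDao shards and 10 Pile shards sampled above, and the 6w_vocab variant checked in under configs/ was presumably produced the same way with a smaller vocab_size.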

17  dataset/validation.py  Normal file
@@ -0,0 +1,17 @@
val_set = [
    '白日依山尽,',
    '君不见,黄河之水天上来,奔流到海不复回。君不见,',
    '秦孝公据崤函之固,拥雍州之地,君臣固守以窥周室,有席卷天下,包举宇内,囊括四海之意,并吞八荒之心。',
    '古之学者必有师。师者,所以传道受业解惑也。人非生而知之者,孰能无惑?',
    '当我醒来时,我发现自己在一个完全陌生的地方。我看到周围没有人,只有一张纸条。',
    '这是一个斗气决定一切的大陆。在加玛帝国乌坦城,有个天才少年萧炎打破了所有族人的修炼纪录,一时间万人敬仰,众人艳羡。但不知为何,',
    '人工智能技术在图像识别领域取得了很大的进展,然而在复杂场景下仍然存在一些问题,例如',
    'In recent years, there has been increasing interest in the use of machine learning to',
    '已知三个数分别为1, 2, 3,求它们的平均数是',
    '小明总共有15个苹果,他分别给了3个人两个苹果,然后自己又吃了一个苹果,那么它还剩几个苹果?',
    '根据牛顿第二定律,物体的加速度等于',
    '碳纳米管是一种新型的材料,具有非常独特的电学和光学性质。在过去的几年中,我们对碳纳',
    '下面是一段用python写的快速排序的代码:',
    'The quantum many-body problem is a fundamental problem in condensed matter physics. Despite decades of research, there is still no exact solution to this problem for large systems. In this paper, we propose a novel approach based on',
    '下面是一个使用 PyTorch 和 Transformer 的示例代码,用于训练一个文本分类模型:import torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader, Dataset'
]

12  models/llama.py  Normal file
@@ -0,0 +1,12 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 13:21:33
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:13:57
FilePath: /Open-Llama/models/llama.py
Description:
Building the Llama model proposed by Meta. https://arxiv.org/pdf/2302.13971.pdf
Performance and effectiveness optimization based on the implementation in the Transformers library.
https://github.com/Bayes-Song/transformers
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
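
In this commit the file carries only this header: the architecture itself lives in the forked Transformers implementation linked above (github.com/Bayes-Song/transformers), and pretrain_llama.py below constructs it through LlamaForCausalLM and LlamaConfig, including the fork-specific use_stable_embedding and shared_input_output_embedding options.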

148  pretrain_llama.py  Normal file
@@ -0,0 +1,148 @@
'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:33:41
FilePath: /Open-Llama/pretrain_llama.py
Description:
pretrain Llama
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.pretrain_dataset import preprocess_the_pile_gen, preprocess_wudao_gen, pretrain_collate_fn_gen
from configs.train_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(
        project='LLAMA Pretrain'
    )

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns)
random.shuffle(paths)
transform_dict = {
    'wudao': preprocess_wudao_gen(tokenizer, max_length),
    'pile': preprocess_the_pile_gen(tokenizer, max_length)
}
data_set = IterableDataset.from_generator(create_data_iter, gen_kwargs={
    'paths': paths,
    'transform_dict': transform_dict,
    'process_index': accelerator.process_index,
    'num_processes': accelerator.num_processes
})
train_loader = DataLoader(data_set, batch_size=train_batch_size, num_workers=1,
                          collate_fn=pretrain_collate_fn_gen(tokenizer, max_length), drop_last=True)
# a smaller initializer_range makes training more stable
# add stable embedding to the token embedding
raw_model = LlamaForCausalLM(LlamaConfig(vocab_size=tokenizer.vocab_size,
                                         initializer_range=initializer_range,
                                         pad_token_id=tokenizer.pad_id,
                                         rms_norm_eps=1e-5,
                                         hidden_dropout_prob=0.1,
                                         attention_dropout_prob=0.1,
                                         use_stable_embedding=True,
                                         shared_input_output_embedding=True))
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in raw_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in raw_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(optim, num_warmup_steps=num_warmup_steps * factor,
                                            num_training_steps=num_training_steps * factor)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print('start training...')
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device)
        labels = batch['input_ids'].clone()
        labels[labels == tokenizer.pad_id] = -100
        out = model(**batch, labels=labels)
        total_loss = out.loss
        losses = {
            'total_loss': total_loss
        }
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
    if accelerator.sync_gradients:
        global_step += 1
    if data_step % log_interval == 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({'Training/Token per second per gpu': tokens / cost_time})
        for k, v in losses.items():
            wandb.log({'Losses/{}'.format(k): v})
        current_lr = optim.param_groups[0]['lr']
        wandb.log({'Training/LR': current_lr})
        if optim.scaler is not None:
            wandb.log({'Training/Loss Scale': optim.scaler.get_scale()})
        wandb.log({'Training/Data Step': data_step})
        wandb.log({'Training/Global Step': global_step})
        accelerator.print('Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}'.format(
            global_step, data_step, losses['total_loss'], tokens / cost_time))
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=['question', 'pred'])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(**inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0)
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({'Predictions on {}'.format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), '{}/{}.pt'.format(work_dir, global_step))
wandb.finish()
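
For reference, a minimal sketch of how a checkpoint written by this script could be reloaded for generation. The checkpoint path data/saved_ckpt/800.pt is a placeholder (files are named by global step), and the use_stable_embedding / shared_input_output_embedding arguments assume the forked Transformers build pinned in requirements.txt below, not stock transformers:

import torch
import sentencepiece as spm
from transformers import LlamaForCausalLM, LlamaConfig

from dataset.tokenizer import Tokenizer

# Rebuild the model with the same shape-determining settings used in pretrain_llama.py
# (the fork-specific config flags are assumed to be available).
sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
tokenizer = Tokenizer(sp_model)
model = LlamaForCausalLM(LlamaConfig(vocab_size=tokenizer.vocab_size,
                                     pad_token_id=tokenizer.pad_id,
                                     rms_norm_eps=1e-5,
                                     use_stable_embedding=True,
                                     shared_input_output_embedding=True))

# Checkpoints are plain state_dicts of the unwrapped model; the filename here is hypothetical.
model.load_state_dict(torch.load('data/saved_ckpt/800.pt', map_location='cpu'))
model.eval()

inputs = tokenizer('白日依山尽,', return_tensors=True, add_special_tokens=False)
with torch.no_grad():
    pred = model.generate(**inputs, max_new_tokens=64, do_sample=True, repetition_penalty=2.0)
print(tokenizer.decode(pred)[0])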

19  requirements.txt  Normal file
@@ -0,0 +1,19 @@
torch==1.13.1
torchvision
torchaudio
zstandard
accelerate
datasets
wandb
deepspeed
absl-py
torchinfo
scikit-learn
datasets==2.10.1
matplotlib
seaborn
sentencepiece
triton
functorch==1.13.1
xformers
git+https://github.com/Bayes-Song/transformers.git