Use Hugging Face datasets to accelerate training; use OpenLlama for pretraining
This commit is contained in:
parent
3f62a23ee2
commit
f8f4cde228
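For context, the sketch below shows how the new Hugging Face datasets pipeline added in this commit is meant to be driven end to end: an OpenLlamaTokenizer plus the streaming construct_dataset helper from dataset/dataset.py, wrapped in a standard PyTorch DataLoader. It mirrors the __main__ demo added in dataset/dataset.py; the glob pattern, tokenizer path, batch size and worker count are illustrative values taken from this diff, not a prescribed launch recipe.

# Minimal sketch only (assumes the repo root is on PYTHONPATH and the data files exist).
from torch.utils.data import DataLoader
from transformers import OpenLlamaTokenizer

from dataset.dataset import construct_dataset

data_config = {
    "mode": "pretrain",
    "data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
    "concat_multiple_sequence": True,
    "num_sequences": 10,
    "seq_length": 2048,
}
tokenizer = OpenLlamaTokenizer(
    "configs/llama_tokenizer_extended.model",
    pad_token="<pad>",
    add_bos_token=False,
    add_eos_token=True,
)
# construct_dataset returns a streaming dataset whose examples carry
# "input_ids" and "labels" tensors of length seq_length.
pretrain_dataset = construct_dataset(data_config, tokenizer)
train_loader = DataLoader(pretrain_dataset, batch_size=2, num_workers=16)
for batch in train_loader:
    print({k: v.shape for k, v in batch.items()})
    break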
@@ -12,15 +12,15 @@ import torch
 import gradio as gr
 import sentencepiece as spm
 from dataset.tokenizer import Tokenizer
-from transformers import LlamaForCausalLM, LlamaConfig
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig


 sp_model = spm.SentencePieceProcessor(
     model_file="configs/10w_vocab_wudao5_pile10.model"
 )
 tokenizer = Tokenizer(sp_model)
-raw_model = LlamaForCausalLM(
-    LlamaConfig(
+raw_model = OpenLlamaForCausalLM(
+    OpenLlamaConfig(
         vocab_size=tokenizer.vocab_size,
         initializer_range=0.01,
         pad_token_id=tokenizer.pad_id,
BIN configs/llama_tokenizer_extended.model Normal file
Binary file not shown.
@@ -1,9 +1,13 @@
-data:
-  patterns: ["data/pretrain_data/part-*.jsonl.zst"]
-  tokenizer_model_path: "configs/10w_vocab_wudao5_pile10.model"
+mode: "pretrain"
+data:
+  mixed: "data/pretrain_data/part-*.jsonl.zst"
+  concat_multiple_sequence: True
+  num_sequences: 10
+  seq_length: 2048
+  tokenizer_model_path: "configs/llama_tokenizer_extended.model"
 model:
   initializer_range: 1.0e-2
   max_length: 1024
   hidden_dropout_prob: 0.1
   attention_dropout_prob: 0.1
   use_stable_embedding: True
@@ -16,6 +20,7 @@ train:
   lr: 2.0e-4
   weight_decay: 1.0e-1
   ckpt: null
+  train_num_workers: 16
   # global step
   log_interval: 5
   eval_interval: 200
@ -18,7 +18,7 @@ root_dir = "data"
|
|||
dataset = load_dataset("yizhongw/self_instruct")
|
||||
write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -39,7 +39,7 @@ print(
|
|||
dataset = load_dataset("BelleGroup/train_0.5M_CN")
|
||||
write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -60,7 +60,7 @@ print(
|
|||
dataset = load_dataset("BelleGroup/train_1M_CN")
|
||||
write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -81,7 +81,7 @@ print(
|
|||
dataset = load_dataset("BelleGroup/school_math_0.25M")
|
||||
write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -102,7 +102,7 @@ print(
|
|||
dataset = load_dataset("BelleGroup/multiturn_chat_0.8M")
|
||||
write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -123,7 +123,7 @@ print(
|
|||
dataset = load_dataset("Graverman/Instruct-to-Code")
|
||||
write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for line in dataset["train"]:
|
||||
line = json.dumps(line)
|
||||
|
@ -143,7 +143,7 @@ print(
|
|||
|
||||
write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
with open("data/sg_90k_part1.json", "r") as fp:
|
||||
data1 = json.load(fp)
|
||||
|
|
|
@ -17,7 +17,7 @@ import zstandard as zstd
|
|||
paths = glob("data/the_pile/*.jsonl.zst")
|
||||
write_path = "data/pretrain_data/part-pile-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for path in tqdm(paths, total=len(paths)):
|
||||
with zstd.open(path, "r", encoding="utf-8") as fp:
|
||||
|
|
|
@ -17,7 +17,7 @@ import zstandard as zstd
|
|||
paths = glob("data/WuDaoCorpus2.0_base_200G/part*")
|
||||
write_path = "data/pretrain_data/part-wudao-{}.jsonl.zst"
|
||||
total_num = 0
|
||||
file_num = 0
|
||||
file_num = 1
|
||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||
for path in tqdm(paths, total=len(paths)):
|
||||
with open(path, "r") as fp:
|
||||
|
|
|
@ -1,69 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-30 20:58:16
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-05 22:11:03
|
||||
FilePath: /Open-Llama/dataset/collate_fn.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
|
||||
"""
|
||||
Organize data into tensors by padding based on the preset maximum length.
|
||||
"""
|
||||
pad_id = tokenizer.pad_id
|
||||
|
||||
def collate_fn(batch):
|
||||
if padding == "longest":
|
||||
max_length = max([len(i) for i in batch])
|
||||
elif padding == "max_length":
|
||||
max_length = segment_max_length
|
||||
else:
|
||||
raise Exception("Invalid argument for padding: {}".format(padding))
|
||||
input_ids = []
|
||||
for i in batch:
|
||||
input_len = len(i)
|
||||
input_ids.append(i + [pad_id] * (max_length - input_len))
|
||||
inputs = {
|
||||
"input_ids": torch.tensor(input_ids, dtype=torch.int64),
|
||||
}
|
||||
return inputs
|
||||
|
||||
return collate_fn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sentencepiece as spm
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
|
||||
|
||||
from dataset.tokenizer import Tokenizer
|
||||
from dataset.data_iter import create_shard_kwargs, DataIter
|
||||
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/10w_vocab_wudao5_pile10.model"
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
|
||||
paths = create_shard_kwargs(patterns)
|
||||
transform_dict = {
|
||||
"wudao": preprocess_wudao_gen(tokenizer),
|
||||
"pile": preprocess_the_pile_gen(tokenizer),
|
||||
}
|
||||
data_set = DataIter(paths, transform_dict=transform_dict)
|
||||
train_loader = DataLoader(
|
||||
data_set,
|
||||
batch_size=8,
|
||||
num_workers=4,
|
||||
collate_fn=collate_fn_gen(tokenizer),
|
||||
drop_last=True,
|
||||
)
|
||||
for batch in train_loader:
|
||||
for k, v in batch.items():
|
||||
print(k, v.shape)
|
||||
break
|
|
@ -1,116 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-17 19:32:20
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-06 03:37:55
|
||||
FilePath: /Open-Llama/dataset/data_iter.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import json
|
||||
from glob import glob
|
||||
import zstandard as zstd
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class DataIter(IterableDataset):
|
||||
"""
|
||||
Currently, the allowed storage formats are jsonl.zst.
|
||||
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
|
||||
Currently, only single worker is supported.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
paths_with_index,
|
||||
transform_dict=None,
|
||||
max_length=None,
|
||||
concat_docs=False,
|
||||
process_index=0,
|
||||
num_processes=1,
|
||||
):
|
||||
super().__init__()
|
||||
self.paths_with_index = paths_with_index
|
||||
self.max_length = max_length
|
||||
self.transform_dict = transform_dict
|
||||
self.concat_docs = concat_docs
|
||||
self.process_index = process_index
|
||||
self.num_processes = num_processes
|
||||
if self.concat_docs:
|
||||
self.cache = []
|
||||
|
||||
def __iter__(self):
|
||||
past = None
|
||||
for i, path in self.paths_with_index:
|
||||
# part-dataset_name-01.jsonl.zst
|
||||
dataset_name = path.split("-")[-2]
|
||||
# shard to multiple device
|
||||
if self.num_processes > 1 and i % self.num_processes != self.process_index:
|
||||
continue
|
||||
# Log the file name when encountering a new file.
|
||||
if past != dataset_name:
|
||||
print("Loading data from {}".format(path))
|
||||
past = path
|
||||
# Currently, the allowed storage formats are jsonl.zst.
|
||||
assert path.endswith("jsonl.zst")
|
||||
with zstd.open(path, "r", encoding="utf-8") as fp:
|
||||
for line in fp:
|
||||
# If the length of the cache is greater than max_length.
|
||||
if self.concat_docs and len(self.cache) >= self.max_length:
|
||||
seq = self.cache[: self.max_length]
|
||||
self.cache = self.cache[self.max_length :]
|
||||
yield seq
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8")
|
||||
line = json.loads(line)
|
||||
line["dataset"] = dataset_name
|
||||
# Transformation, including sample, tokenize, etc.
|
||||
if self.transform_dict:
|
||||
line = self.transform_dict[dataset_name](line)
|
||||
# skip bad doc
|
||||
if line is None:
|
||||
continue
|
||||
elif isinstance(line, str):
|
||||
yield line
|
||||
# must be list of list
|
||||
elif isinstance(line, list) and isinstance(line[0], list):
|
||||
for seq in line:
|
||||
if self.concat_docs:
|
||||
# concat seq from multiple docs
|
||||
self.cache += seq
|
||||
else:
|
||||
yield seq
|
||||
else:
|
||||
raise Exception(
|
||||
"Unsupported type in Transformation: {}".format(
|
||||
self.transform_dict[dataset_name]
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield line
|
||||
|
||||
|
||||
def create_shard_kwargs(patterns, repeat=1):
|
||||
"""
|
||||
Assign numbers to different shards of data to ensure that data is not duplicated
|
||||
when allocated to different nodes during distributed training.
|
||||
"""
|
||||
all_path = []
|
||||
for p in patterns:
|
||||
all_path.extend(glob(p))
|
||||
all_path *= repeat
|
||||
return [(i, p) for i, p in enumerate(all_path)]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
|
||||
paths = create_shard_kwargs(patterns)
|
||||
transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
|
||||
data_iter = DataIter(
|
||||
paths, transform_dict=transform_dict, max_length=16, concat_docs=True
|
||||
)
|
||||
for i, data in enumerate(data_iter):
|
||||
print(i, data)
|
||||
if i == 20:
|
||||
break
|
159 dataset/dataset.py Normal file
@@ -0,0 +1,159 @@
+import torch
+import random
+from glob import glob
+from datasets import load_dataset, interleave_datasets
+
+
+def pretrain_transform(line):
+    if "title" in line and "text" not in line:
+        line["text"] = line["title"] + "\n" + line["content"]
+    return line
+
+
+def sample_sequence_gen(seq_length, eos_token_id):
+    def sample_sequence(line):
+        doc_length = line["input_ids"].shape[1]
+        if doc_length <= seq_length:
+            start = 0
+        else:
+            if random.random() < 1 / 4:
+                start = 0
+            else:
+                start = random.randint(0, doc_length - seq_length)
+        input_ids = line["input_ids"][0, start : start + seq_length]
+        if input_ids[-1] != eos_token_id:
+            input_ids[-1] = eos_token_id
+        return {"input_ids": input_ids}
+
+    return sample_sequence
+
+
+def concat_multiple_sequence_gen(seq_length):
+    def concat_multiple_sequence(batch):
+        concat_input_ids = torch.cat(batch["input_ids"], dim=0)
+        input_ids = []
+        while len(concat_input_ids) > (1 + len(input_ids)) * seq_length:
+            input_ids.append(
+                concat_input_ids[
+                    len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
+                ]
+            )
+        out = {"input_ids": input_ids}
+        return out
+
+    return concat_multiple_sequence
+
+
+def get_labels_gen(pad_token_id):
+    def get_labels(line):
+        input_ids = line["input_ids"]
+        labels = input_ids.clone()
+        labels[labels == pad_token_id] = -100
+        return {"labels": labels}
+
+    return get_labels
+
+
+def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
+    datasets = []
+    probabilities = []
+    # For now only one data source is used; multiple sources prevent multi-process reading and make loading much slower.
+    assert len(dataset_config["data"]) == 1
+    for name, pattern in dataset_config["data"].items():
+        data_files = glob(pattern)
+        assert len(data_files) > 0
+        dataset = load_dataset(
+            "json", data_files=data_files, split="train", streaming=True
+        )
+        if dataset_config["mode"] == "pretrain":
+            dataset = dataset.map(pretrain_transform)
+        else:
+            raise Exception(
+                "Dataset mode: {} not found.".format(dataset_config["mode"])
+            )
+        datasets.append(dataset)
+        probabilities.append(dataset.n_shards)
+    probabilities_sum = sum(probabilities)
+    probabilities = [p / probabilities_sum for p in probabilities]
+    if len(datasets) > 1:
+        full_dataset = interleave_datasets(
+            datasets, probabilities=probabilities, seed=42
+        )
+    else:
+        full_dataset = datasets[0]
+    if return_raw_text:
+        return full_dataset
+    seq_length = dataset_config["seq_length"]
+    if dataset_config.get("concat_multiple_sequence", False):
+        num_sequences = dataset_config["num_sequences"]
+        full_dataset = full_dataset.map(
+            lambda x: tokenizer(
+                x["text"], return_tensors="pt", return_attention_mask=False
+            )
+        )
+        full_dataset = full_dataset.map(
+            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
+        )
+        full_dataset = full_dataset.select_columns("input_ids")
+        full_dataset = full_dataset.map(
+            concat_multiple_sequence_gen(seq_length),
+            batched=True,
+            batch_size=num_sequences,
+        )
+    else:
+        full_dataset = full_dataset.map(
+            lambda x: tokenizer(
+                x["text"],
+                return_tensors="pt",
+                return_attention_mask=False,
+                padding="max_length",
+                max_length=seq_length,
+                truncation=True,
+            )
+        )
+        full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
+        full_dataset = full_dataset.select_columns("input_ids")
+    full_dataset = full_dataset.map(get_labels_gen(tokenizer.pad_token_id))
+    return full_dataset
+
+
+if __name__ == "__main__":
+    import time
+    from unicodedata import normalize
+    from torch.utils.data import DataLoader
+    from transformers import OpenLlamaTokenizer
+
+    data_config = {
+        "mode": "pretrain",
+        "data": {"wudao": "data/pretrain_data/part-wudao*.jsonl.zst"},
+        "concat_multiple_sequence": True,
+        "num_sequences": 10,
+        "seq_length": 2048,
+    }
+    tokenizer = OpenLlamaTokenizer(
+        "configs/llama_tokenizer_extended.model",
+        pad_token="<pad>",
+        add_bos_token=False,
+        add_eos_token=True,
+    )
+    pretrain_dataset = construct_dataset(data_config, tokenizer, True)
+    start = time.time()
+    for i, line in enumerate(pretrain_dataset):
+        raw_text = line["text"]
+        # raw_text = normalize("NFKC", raw_text)
+        input_ids = tokenizer(
+            line["text"], return_tensors="pt", return_attention_mask=False
+        )["input_ids"][0, :-1]
+        decode_text = tokenizer.decode(input_ids)
+        if raw_text != decode_text and "▁" not in raw_text:
+            print(raw_text, "\n", decode_text)
+        if i == 3000:
+            break
+    print("all checked in {} seconds.".format(time.time() - start))
+    pretrain_dataset = construct_dataset(data_config, tokenizer)
+    print(pretrain_dataset.n_shards)
+    pretrain_loader = DataLoader(pretrain_dataset, batch_size=2, num_workers=16)
+    for batch in pretrain_loader:
+        for k, v in batch.items():
+            print(k, v.shape, "\n", v)
+        break
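As a quick sanity check of the concat-and-rechunk step in dataset/dataset.py above, the toy example below calls concat_multiple_sequence_gen directly on made-up tensors (seq_length=8 instead of the configured 2048; assumes the repo root is importable).

# Toy illustration only; the tensor lengths and seq_length are hypothetical.
import torch

from dataset.dataset import concat_multiple_sequence_gen

batch = {"input_ids": [torch.arange(5), torch.arange(7), torch.arange(11)]}  # 23 tokens in total
out = concat_multiple_sequence_gen(seq_length=8)(batch)
# The 23 concatenated tokens are re-cut into 2 full sequences of length 8;
# the trailing 7-token remainder is dropped.
for seq in out["input_ids"]:
    print(seq.shape)  # torch.Size([8])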
@ -1,71 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-17 20:41:25
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-05 22:32:39
|
||||
FilePath: /Open-Llama/dataset/pretrain_dataset.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import math
|
||||
|
||||
|
||||
def preprocess_wudao_gen(tokenizer, segment_max_length=1024):
|
||||
def preprocess_wudao(line):
|
||||
"""
|
||||
The format of the data is roughly as follows.
|
||||
{'id': 1, 'dataType': '百科', 'title': 'some title', 'content': 'some content'}
|
||||
Split the data based on the tokenized length according to the maximum length.
|
||||
"""
|
||||
total = line["title"] + "\n" + line["content"]
|
||||
out = tokenizer(total)
|
||||
input_ids = out["input_ids"]
|
||||
return [
|
||||
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
|
||||
for i in range(math.ceil(len(input_ids) / segment_max_length))
|
||||
]
|
||||
|
||||
return preprocess_wudao
|
||||
|
||||
|
||||
def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
|
||||
def preprocess_the_pile(line):
|
||||
"""
|
||||
The format of the data is roughly as follows.
|
||||
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
|
||||
Split the data based on the tokenized length according to the maximum length.
|
||||
"""
|
||||
total = line["text"]
|
||||
out = tokenizer(total)
|
||||
input_ids = out["input_ids"]
|
||||
return [
|
||||
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
|
||||
for i in range(math.ceil(len(input_ids) / segment_max_length))
|
||||
]
|
||||
|
||||
return preprocess_the_pile
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sentencepiece as spm
|
||||
|
||||
from dataset.tokenizer import Tokenizer
|
||||
from dataset.data_iter import create_shard_kwargs, DataIter
|
||||
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/10w_vocab_wudao5_pile10.model"
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
|
||||
paths = create_shard_kwargs(patterns)
|
||||
transform_dict = {
|
||||
"wudao": preprocess_wudao_gen(tokenizer),
|
||||
"pile": preprocess_the_pile_gen(tokenizer),
|
||||
}
|
||||
data_set = DataIter(
|
||||
paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
|
||||
)
|
||||
for sample in data_set:
|
||||
print(sample)
|
||||
break
|
|
@ -1,218 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-20 21:39:47
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-06 23:01:50
|
||||
FilePath: /Open-Llama/dataset/tokenizer.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
def __init__(self, sp_model):
|
||||
self.sp_model = sp_model
|
||||
self.bos_id = self.sp_model.bos_id()
|
||||
self.eos_id = self.sp_model.eos_id()
|
||||
self.pad_id = self.sp_model.pad_id()
|
||||
self.vocab_size = self.sp_model.vocab_size()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs,
|
||||
padding=None,
|
||||
max_length=256,
|
||||
return_tensors=False,
|
||||
truncation=False,
|
||||
add_special_tokens=True,
|
||||
return_mask=False,
|
||||
):
|
||||
if isinstance(inputs, str):
|
||||
return self.encode(
|
||||
inputs,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
return_tensors=return_tensors,
|
||||
truncation=truncation,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_mask=return_mask,
|
||||
)
|
||||
else:
|
||||
return self.encode_batch(
|
||||
inputs,
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
return_tensors=return_tensors,
|
||||
truncation=truncation,
|
||||
add_special_tokens=add_special_tokens,
|
||||
return_mask=return_mask,
|
||||
)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
inputs,
|
||||
padding=None,
|
||||
max_length=8192,
|
||||
return_tensors=False,
|
||||
truncation=False,
|
||||
add_special_tokens=True,
|
||||
return_mask=False,
|
||||
):
|
||||
assert isinstance(inputs, str)
|
||||
input_ids = self.sp_model.Encode(inputs)
|
||||
if return_mask:
|
||||
attention_mask = [1] * len(input_ids)
|
||||
if truncation:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L780
|
||||
# Following the Transformers implementation referenced above: assume the last position is always pad or eos.
|
||||
input_ids = input_ids[: max_length - 1]
|
||||
if return_mask:
|
||||
attention_mask = attention_mask[: max_length - 1]
|
||||
if add_special_tokens:
|
||||
input_ids = input_ids + [self.eos_id]
|
||||
if return_mask:
|
||||
attention_mask = attention_mask + [0]
|
||||
if padding == "max_length":
|
||||
input_ids = input_ids + [self.pad_id] * (max_length - len(input_ids))
|
||||
if return_mask:
|
||||
attention_mask = attention_mask + [0] * (
|
||||
max_length - len(attention_mask)
|
||||
)
|
||||
if return_tensors:
|
||||
input_ids = torch.tensor([input_ids])
|
||||
out = {
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
if return_mask:
|
||||
attention_mask = torch.tensor([attention_mask])
|
||||
out["attention_mask"] = attention_mask
|
||||
else:
|
||||
out = {
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
if return_mask:
|
||||
out["attention_mask"] = attention_mask
|
||||
return out
|
||||
|
||||
def encode_batch(
|
||||
self,
|
||||
inputs,
|
||||
padding=None,
|
||||
max_length=8192,
|
||||
return_tensors=False,
|
||||
truncation=False,
|
||||
add_special_tokens=True,
|
||||
return_mask=False,
|
||||
):
|
||||
input_ids = self.sp_model.Encode(inputs)
|
||||
if return_mask:
|
||||
attention_mask = [[1] * len(i) for i in input_ids]
|
||||
if truncation:
|
||||
input_ids = [i[: max_length - 1] for i in input_ids]
|
||||
if return_mask:
|
||||
attention_mask = [i[: max_length - 1] for i in attention_mask]
|
||||
if add_special_tokens:
|
||||
input_ids = [i + [self.eos_id] for i in input_ids]
|
||||
if return_mask:
|
||||
attention_mask = [i + [0] for i in attention_mask]
|
||||
if padding == "max_length":
|
||||
input_ids_pad = []
|
||||
if return_mask:
|
||||
attention_mask_pad = []
|
||||
for idx, i in enumerate(input_ids):
|
||||
input_ids_pad.append(i + [self.pad_id] * (max_length - len(i)))
|
||||
if return_mask:
|
||||
j = attention_mask[idx]
|
||||
attention_mask_pad.append(j + [0] * (max_length - len(j)))
|
||||
input_ids = input_ids_pad
|
||||
if return_mask:
|
||||
attention_mask = attention_mask_pad
|
||||
if return_tensors:
|
||||
input_ids = torch.tensor(input_ids)
|
||||
out = {
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
if return_mask:
|
||||
attention_mask = torch.tensor(attention_mask)
|
||||
out["attention_mask"] = attention_mask
|
||||
else:
|
||||
out = {
|
||||
"input_ids": input_ids,
|
||||
}
|
||||
if return_mask:
|
||||
out["attention_mask"] = attention_mask
|
||||
return out
|
||||
|
||||
def decode(self, inputs, max_rounds=None):
|
||||
inputs = inputs.tolist()
|
||||
out = []
|
||||
for i, ids in enumerate(inputs):
|
||||
count = 0
|
||||
flag = False
|
||||
for j, token in enumerate(ids):
|
||||
if token == self.eos_id:
|
||||
if max_rounds is None:
|
||||
flag = True
|
||||
break
|
||||
elif isinstance(max_rounds, int):
|
||||
if count < max_rounds:
|
||||
count += 1
|
||||
else:
|
||||
flag = True
|
||||
break
|
||||
elif isinstance(max_rounds, list):
|
||||
if count < max_rounds[i]:
|
||||
count += 1
|
||||
else:
|
||||
flag = True
|
||||
break
|
||||
if flag:
|
||||
ids = ids[:j]
|
||||
else:
|
||||
ids = ids
|
||||
out.append(ids)
|
||||
out = self.sp_model.Decode(out)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sentencepiece as spm
|
||||
from unicodedata import normalize
|
||||
|
||||
# Using sentencepiece may not be able to process some reserved keywords like '▁'.
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/10w_vocab_wudao5_pile10.model"
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
tmp = [
|
||||
"hello world",
|
||||
"这是开源项目的V1版本,this is the first version of a open-source project!",
|
||||
"# this is a python script\nfor i in range(10):\n print(i)\n for j in range(10):\n print(j)",
|
||||
]
|
||||
print(tmp)
|
||||
out = tokenizer(
|
||||
tmp, padding="max_length", return_tensors=True, max_length=64, truncation=True
|
||||
)
|
||||
for k, v in out.items():
|
||||
print(k, v.shape)
|
||||
print(out["input_ids"])
|
||||
out = tokenizer.decode(out["input_ids"])
|
||||
print(out)
|
||||
for i, j in zip(tmp, out):
|
||||
assert normalize("NFKC", i) == j
|
||||
|
||||
from dataset.data_iter import create_shard_kwargs, DataIter
|
||||
|
||||
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
|
||||
paths = create_shard_kwargs(patterns)
|
||||
data_iter = DataIter(paths)
|
||||
for i, data in enumerate(data_iter):
|
||||
assert (
|
||||
normalize("NFKC", data["content"])
|
||||
== sp_model.Decode(sp_model.Encode(data["content"]))
|
||||
or "▁" in data["content"]
|
||||
)
|
||||
if i == 1000:
|
||||
break
|
|
@@ -18,7 +18,11 @@ from torchinfo import summary
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
 from deepspeed.ops.adam import FusedAdam
-from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
+from transformers import (
+    OpenLlamaForCausalLM,
+    OpenLlamaConfig,
+    get_cosine_schedule_with_warmup,
+)

 from dataset.validation import val_set
 from dataset.tokenizer import Tokenizer
@@ -74,8 +78,8 @@ train_loader = DataLoader(
 )
 # smaller initializer_range makes training more stable
 # add stable embedding to token embedding
-raw_model = LlamaForCausalLM(
-    LlamaConfig(
+raw_model = OpenLlamaForCausalLM(
+    OpenLlamaConfig(
         vocab_size=tokenizer.vocab_size,
         initializer_range=initializer_range,
         pad_token_id=tokenizer.pad_id,
@ -1,12 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-17 13:21:33
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-03-26 23:13:57
|
||||
FilePath: /Open-Llama/models/llama.py
|
||||
Description:
|
||||
Building the Llama model proposed by Meta. https://arxiv.org/pdf/2302.13971.pdf
|
||||
Performance and effectiveness optimization based on the implementation in the Transformer library.
|
||||
https://github.com/Bayes-Song/transformers
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
65 pretrain.py
@@ -10,77 +10,44 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import yaml
 import torch
-import random
 from absl import app
 from absl import flags
-import sentencepiece as spm
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
-from transformers import LlamaForCausalLM, LlamaConfig
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, OpenLlamaTokenizer

-from dataset.tokenizer import Tokenizer
-from dataset.data_iter import create_shard_kwargs, DataIter
-from dataset.collate_fn import collate_fn_gen
-from dataset.pretrain_dataset import (
-    preprocess_the_pile_gen,
-    preprocess_wudao_gen,
-)
+from dataset.dataset import construct_dataset
 from solver.trainer import Trainer

 FLAGS = flags.FLAGS
 flags.DEFINE_string("config", None, "Training config path")


-class FakeSet(torch.utils.data.Dataset):
-    def __getitem__(self, idx):
-        return {"input_ids": torch.randint(0, 32000, (2048,))}
-
-    def __len__(self):
-        return 1000000000
-
-
 def main(argv):
     accelerator = Accelerator()

     with open(FLAGS.config, "r", encoding="utf-8") as fp:
         config = yaml.load(fp, Loader=yaml.FullLoader)
-    sp_model = spm.SentencePieceProcessor(
-        model_file=config["data"]["tokenizer_model_path"]
+    tokenizer = OpenLlamaTokenizer(
+        config["data"]["tokenizer_model_path"],
+        pad_token="<pad>",
+        add_bos_token=False,
+        add_eos_token=True,
     )
-    tokenizer = Tokenizer(sp_model)
-
-    # paths = create_shard_kwargs(config['data']['patterns'])
-    # random.shuffle(paths)
-    # transform_dict = {
-    # "wudao": preprocess_wudao_gen(tokenizer, config['model']['max_length']),
-    # "pile": preprocess_the_pile_gen(tokenizer, config['model']['max_length']),
-    # }
-    # data_set = DataIter(
-    # paths,
-    # transform_dict=transform_dict,
-    # concat_docs=True,
-    # max_length=config['model']['max_length'],
-    # process_index=accelerator.process_index,
-    # num_processes=accelerator.num_processes,
-    # )
-    # train_loader = DataLoader(
-    # data_set,
-    # batch_size=config['train']['train_batch_size'],
-    # # If num_workers is greater than 1, duplicate data may occur.
-    # num_workers=0,
-    # collate_fn=collate_fn_gen(tokenizer, config['model']['max_length']),
-    # drop_last=True,
-    # )
-    train_loader = torch.utils.data.DataLoader(
-        FakeSet(), batch_size=config["train"]["train_batch_size"]
+    data_config = config["data"]
+    pretrain_dataset = construct_dataset(data_config, tokenizer)
+    train_loader = DataLoader(
+        pretrain_dataset,
+        batch_size=config["train"]["train_batch_size"],
+        num_workers=config["train"]["train_num_workers"],
     )
     # smaller initializer_range makes training more stable
     # add stable embedding to token embedding
-    raw_model = LlamaForCausalLM(
-        LlamaConfig(
+    raw_model = OpenLlamaForCausalLM(
+        OpenLlamaConfig(
             vocab_size=tokenizer.vocab_size,
             initializer_range=config["model"]["initializer_range"],
-            pad_token_id=tokenizer.pad_id,
+            pad_token_id=tokenizer.pad_token_id,
             rms_norm_eps=1e-5,
             hidden_dropout_prob=config["model"]["hidden_dropout_prob"],
             attention_dropout_prob=config["model"]["attention_dropout_prob"],
@@ -17,4 +17,4 @@ triton
 functorch==1.13.1
 xformers
 gradio
-git+https://github.com/Bayes-Song/transformers.git
+git+https://github.com/s-JoL/transformers.git@dev
75 server.py
|
@ -1,75 +0,0 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-31 13:26:15
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-06 03:45:44
|
||||
FilePath: /Open-Llama/server.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import torch
|
||||
import gradio as gr
|
||||
import sentencepiece as spm
|
||||
from dataset.tokenizer import Tokenizer
|
||||
from transformers import LlamaForCausalLM, LlamaConfig
|
||||
|
||||
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/10w_vocab_wudao5_pile10.model"
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
|
||||
raw_model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=tokenizer.vocab_size,
|
||||
initializer_range=0.01,
|
||||
pad_token_id=tokenizer.pad_id,
|
||||
rms_norm_eps=1e-5,
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_dropout_prob=0.1,
|
||||
use_stable_embedding=True,
|
||||
shared_input_output_embedding=True,
|
||||
)
|
||||
)
|
||||
ckpt = torch.load(
|
||||
"data/saved_ckpt/instruction_tuning_3_epochs/37001.pt", map_location="cpu"
|
||||
)
|
||||
raw_model.load_state_dict(ckpt)
|
||||
raw_model.eval()
|
||||
model = raw_model.cuda()
|
||||
print("ready")
|
||||
|
||||
|
||||
def question_answer(prompt):
|
||||
print(prompt)
|
||||
raw_inputs = "user:{}\nsystem:".format(prompt)
|
||||
inputs_len = len(raw_inputs)
|
||||
inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.cuda()
|
||||
pred = model.generate(**inputs, max_new_tokens=512, do_sample=True)
|
||||
pred = tokenizer.decode(pred.cpu())[0]
|
||||
pred = pred[inputs_len:]
|
||||
print(pred)
|
||||
return pred
|
||||
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=question_answer,
|
||||
inputs="text",
|
||||
outputs="text",
|
||||
examples=[
|
||||
"帮我写一封邮件,内容是咨询教授本学期量子力学课程的时间表?并且希望教授推荐一些相关书籍",
|
||||
"情人节送女朋友什么礼物,预算500",
|
||||
"我今天肚子有点不舒服,晚饭有什么建议么",
|
||||
"可以总结一下小说三体的核心内容么?",
|
||||
"Can you explain to me what quantum mechanics is and how it relates to quantum computing?",
|
||||
"请帮我写一个AI驱动的幼儿教育APP的商业计划书",
|
||||
"用python实现一个快速排序",
|
||||
],
|
||||
title="Open-Llama",
|
||||
description="不基于其他预训练模型,完全使用[Open-Llama](https://github.com/Bayes-Song/Open-Llama)项目从0开始训练的Instruct-GPT模型,总训练成本不超过2w美元。由于请求需要经Gradio进行转发,可能出现请求丢失的现象,当长时间无响应(如20s以上)可刷新重试。当前体验服务生成的所有内容都是由人工智能模型生成,我们对其生成内容的准确性、完整性和功能性不做任何保证,并且其生成的内容不代表我们的态度或观点。",
|
||||
article="联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我",
|
||||
).queue(concurrency_count=1)
|
||||
demo.launch(share=True)
|
|
@@ -76,15 +76,17 @@ class Trainer:
         )

     def prepare(self):
-        _, self.model, self.optim, self.scheduler = self.accelerator.prepare(
+        (
+            self.train_loader,
+            self.model,
+            self.optim,
+            self.scheduler,
+        ) = self.accelerator.prepare(
             self.train_loader, self.raw_model, self.optim, self.scheduler
         )
-        self.train_loader_iter = iter(self.train_loader)

     def train_step(self, batch):
-        for k, v in batch.items():
-            batch[k] = v.to(self.accelerator.device, non_blocking=True)
-        out = self.model(**batch, labels=batch["input_ids"])
+        out = self.model(**batch)
         total_loss = out.loss
         losses = {"total_loss": total_loss}
         self.accelerator.backward(total_loss)
@@ -100,10 +102,11 @@ class Trainer:
         self.global_step = 0
         self.start_time = time.time()
         self.optim.zero_grad()
-        for self.data_step in range(self.config["train"]["num_training_steps"]):
+        for self.data_step, batch in enumerate(self.train_loader):
+            if self.data_step >= self.config["train"]["num_training_steps"]:
+                break
             self.model.train()
             with self.accelerator.accumulate(self.model):
-                batch = next(self.train_loader_iter)
                 losses = self.train_step(batch)
                 if self.accelerator.sync_gradients:
                     self.global_step += 1
@@ -137,7 +140,7 @@ class Trainer:
                 tokens = (
                     self.config["train"]["train_batch_size"]
                     * self.log_interval
-                    * self.config["model"]["max_length"]
+                    * self.config["data"]["seq_length"]
                 )
                 wandb.log({"Training/Token per second per gpu": tokens / cost_time})
                 for k, v in losses.items():
@@ -163,16 +166,19 @@ class Trainer:
             with torch.no_grad():
                 for data in val_set:
                     raw_inputs = data
-                    inputs_len = len(raw_inputs)
                     inputs = self.tokenizer(
-                        raw_inputs, return_tensors=True, add_special_tokens=False
+                        raw_inputs,
+                        return_tensors="pt",
+                        add_special_tokens=False,
+                        return_attention_mask=False,
                     )
+                    input_length = inputs["input_ids"].shape[1]
                     for k, v in inputs.items():
                         inputs[k] = v.to(self.accelerator.device)
                     pred = self.model.generate(
                         **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                     )
-                    pred = self.tokenizer.decode(pred.cpu())[0]
-                    pred = pred[inputs_len:]
+                    pred = pred[0, input_length:]
+                    pred = self.tokenizer.decode(pred.cpu())
                     text_table.add_data(raw_inputs, pred)
                 wandb.log({"Predictions on {}".format(self.global_step): text_table})
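Since the last hunk interleaves old and new lines fairly heavily, here is a condensed, hedged sketch of the updated validation-generation flow in solver/trainer.py: tokenize without special tokens, generate, keep only the newly generated tokens, then decode with the Hugging Face tokenizer. The standalone function form and its name are illustrative only; in the commit this logic runs inside Trainer with wandb logging.

# Sketch only: assumes an HF-style tokenizer and a causal LM; val_set is the repo's list of prompt strings.
import torch


@torch.no_grad()
def generate_samples(model, tokenizer, val_set, device):
    outputs = []
    for raw_inputs in val_set:
        inputs = tokenizer(
            raw_inputs,
            return_tensors="pt",
            add_special_tokens=False,
            return_attention_mask=False,
        )
        input_length = inputs["input_ids"].shape[1]
        inputs = {k: v.to(device) for k, v in inputs.items()}
        pred = model.generate(
            **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
        )
        # Keep only the generated continuation, then decode it.
        outputs.append(tokenizer.decode(pred[0, input_length:].cpu()))
    return outputs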