unified pre-training and instruction-tuning both use train_lm and dataset
This commit is contained in:
parent 97aff0e051
commit db6cdb51d0
@@ -102,8 +102,7 @@ with gr.Blocks() as demo:
        context = context.cuda()
        pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
        pred = pred[:, inputs_len:]
        pred = tokenizer.decode(pred.cpu()[0])
        pred = pred.strip()
        pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
        print(pred)
        bot_message = parse_codeblock(pred)
        history[-1][1] = bot_message

Binary file not shown.
@@ -1,7 +1,6 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 12
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
32 configs/instruct_config.yaml Normal file
@@ -0,0 +1,32 @@
data:
  mode: "instruct"
  data:
    mixed: "data/instruction_data/part-*.jsonl.zst"
  pad_to_max: False
  sequence_sample_mode: "sample"
  concat_multiple_sequence: True
  num_sequences: 50
  seq_length: 2048
  tokenizer_model_path: "configs/llama_tokenizer_extended.model"
model:
  initializer_range: 1.0e-2
  hidden_dropout_prob: 0.1
  attention_dropout_prob: 0.1
  use_stable_embedding: False
  shared_input_output_embedding: False
train:
  train_batch_size: 2
  num_training_steps: 1000000
  num_warmup_steps: 2000
  initializer_range: 1.0e-2
  lr: 2.0e-4
  weight_decay: 1.0e-1
  ckpt: "data/llama_raw_ckpt/7B/extended.pth"
  train_num_workers: 16
  gradient_accumulation_steps: 1
  # global step
  log_interval: 50
  eval_interval: 500
  save_interval: 1000
  work_dir: "data/saved_ckpt/7B"
  project_name: "Llama Instruction"
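For orientation, this is roughly how the unified entry point consumes a config like the one above (a minimal sketch based on the train_lm.py hunks at the bottom of this diff; the construct_dataset import path is an assumption):

# Editor's sketch, not part of the commit: config handling as in train_lm.py.
import yaml
from transformers import LlamaTokenizer
from dataset.dataset import construct_dataset  # module path assumed

with open("configs/instruct_config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.load(fp, Loader=yaml.FullLoader)

tokenizer = LlamaTokenizer(
    config["data"]["tokenizer_model_path"],
    pad_token="<pad>",
    add_eos_token=True,
)
train_dataset = construct_dataset(config["data"], tokenizer)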
@@ -1,25 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:37:23
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 40000
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/instruction_data/part-*.jsonl.zst"]
# global step
log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/83200.pt"
@@ -2,16 +2,18 @@ data:
  mode: "pretrain"
  data:
    mixed: "data/pretrain_data/part-*.jsonl.zst"
  pad_to_max: False
  sequence_sample_mode: "sample"
  concat_multiple_sequence: True
  num_sequences: 10
  num_sequences: 20
  seq_length: 2048
  tokenizer_model_path: "configs/llama_tokenizer_extended.model"
model:
  initializer_range: 1.0e-2
  hidden_dropout_prob: 0.1
  attention_dropout_prob: 0.1
  use_stable_embedding: True
  shared_input_output_embedding: True
  use_stable_embedding: False
  shared_input_output_embedding: False
train:
  train_batch_size: 2
  num_training_steps: 1000000
@@ -19,11 +21,12 @@ train:
  initializer_range: 1.0e-2
  lr: 2.0e-4
  weight_decay: 1.0e-1
  ckpt: null
  ckpt: "data/llama_raw_ckpt/7B/extended.pth"
  train_num_workers: 16
  gradient_accumulation_steps: 12
  # global step
  log_interval: 5
  eval_interval: 200
  save_interval: 800
  work_dir: "data/saved_ckpt/"
  work_dir: "data/saved_ckpt/7B"
  project_name: "Llama Pretrain"
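As a quick sanity check on this config, the token budget per optimizer step works out as below (a back-of-envelope sketch; the GPU count is an assumption for illustration, the config does not fix it):

# Editor's sketch, not part of the commit.
train_batch_size = 2
gradient_accumulation_steps = 12
seq_length = 2048
num_gpus = 8  # assumption; not specified anywhere in the config

tokens_per_step = train_batch_size * gradient_accumulation_steps * seq_length * num_gpus
print(tokens_per_step)  # 393216, i.e. ~0.39M tokens per optimizer step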
@@ -14,15 +14,100 @@ from glob import glob
from datasets import load_dataset, interleave_datasets


def pretrain_transform(line):
    if "title" in line and "text" not in line:
        line["text"] = line["title"] + "\n" + line["content"]
    return line
def pretrain_transform(batch):
    # wudao preprocess
    if "title" in batch and "content" in batch:
        assert len(batch["title"]) == 1
        batch["text"] = [batch["title"][0] + "\n" + batch["content"][0]]
    elif "text" in batch:
        pass
    else:
        raise Exception("Unrecognized pretrain dataset format.")
    return batch


def instruct_transform(batch):
    # self instruct preprocess
    if "prompt" in batch and "completion" in batch:
        prompt = batch["prompt"][0]
        completion = batch["completion"][0]
        if prompt.endswith("Output:"):
            prompt = prompt[:-7]
        text = "user:{}\nsystem:{}".format(prompt.strip(), completion.strip())
        texts = [text]
    # belle preprocess
    elif "instruction" in batch and "output" in batch:
        prompt = batch["instruction"][0].replace("\\n", "")
        prompt = prompt.strip("")

        completion = batch["output"][0].replace("\\n", "")
        completion = completion.strip("")
        # multi turn chat
        if "Human:" in prompt:
            texts = []
            chats = prompt + completion
            chats = chats.split("Human:")
            for chat in chats:
                if chat.strip() == "":
                    continue
                res = chat.split("Assistant:")
                if len(res) != 2:
                    continue
                prompt, completion = res
                prompt = prompt.strip()
                completion = completion.strip()
                chat = "user:{}\nsystem:{}".format(prompt, completion)
                texts.append(chat)
            texts = ["[multiturn_sep]".join(texts)]
        else:
            text = "user:{}\nsystem:{}".format(prompt, completion)
            texts = [text]
    # instruct code preprocess
    elif "instruction" in batch and "answer" in batch:
        prompt = batch["instruction"][0].replace("\\n", "")
        prompt = prompt.strip("")

        completion = batch["answer"][0].replace("\\n", "")
        completion = completion.strip("")
        text = "user:{}\nsystem:{}".format(prompt, completion)
        texts = [text]
    # share gpt preprocess
    elif "conversations" in batch:
        chats = batch["conversations"][0]
        if chats[0]["from"] != "human":
            chats = chats[1:]
        texts = []
        for i in range(len(chats) // 2):
            prompt = chats[2 * i]
            completion = chats[2 * i + 1]
            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
                continue
            prompt = prompt["value"]
            prompt = prompt.strip()
            completion = completion["value"]
            completion = completion.strip()
            chat = "user:{}\nsystem:{}".format(prompt, completion)
            texts.append(chat)
        texts = ["[multiturn_sep]".join(texts)]
    else:
        raise Exception("Unrecognized instruct dataset format.")
    return {"text": texts}


def split_multiturn(batch):
    return {"text": batch["text"][0].split("[multiturn_sep]")}


def truncation_gen(seq_length):
    def truncation(line):
        return {"input_ids": line["input_ids"][:seq_length]}

    return truncation


def sample_sequence_gen(seq_length, eos_token_id):
    def sample_sequence(line):
        doc_length = line["input_ids"].shape[1]
        doc_length = line["input_ids"].shape[0]
        if doc_length <= seq_length:
            start = 0
        else:
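To make the batched (batch_size=1) transforms above concrete, a toy run of instruct_transform on a self-instruct style record (sample text borrowed from the docstring of the removed instruction_dataset.py below):

# Editor's sketch, not part of the commit; assumes instruct_transform from the hunk above.
batch = {
    "prompt": ["Explain the origin of life on earth. Output:"],
    "completion": ["Life on Earth is believed to have begun in the oceans."],
}
out = instruct_transform(batch)
# The trailing "Output:" is stripped and prompt/completion are joined:
# {"text": ["user:Explain the origin of life on earth.\nsystem:Life on Earth is believed to have begun in the oceans."]}
print(out)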
@@ -30,7 +115,7 @@ def sample_sequence_gen(seq_length, eos_token_id):
            start = 0
        else:
            start = random.randint(0, doc_length - seq_length)
        input_ids = line["input_ids"][0, start : start + seq_length]
        input_ids = line["input_ids"][start : start + seq_length]
        if input_ids[-1] != eos_token_id:
            input_ids[-1] = eos_token_id
        return {"input_ids": input_ids}
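The shape[1] to shape[0] fix above reflects that input_ids now arrives as a 1-D tensor per example; a small sketch of the sampling behaviour under that assumption:

# Editor's sketch, not part of the commit; assumes sample_sequence_gen from the hunk above.
import torch

seq_length, eos_token_id = 4, 2
line = {"input_ids": torch.tensor([5, 6, 7, 8, 9, 10])}  # doc_length = 6 > 4
# start is drawn uniformly from [0, doc_length - seq_length] = [0, 2]; the window's
# last token is overwritten with eos if it is not already eos.
sample = sample_sequence_gen(seq_length, eos_token_id)(line)
print(sample["input_ids"])  # e.g. tensor([6, 7, 8, 2])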
@@ -38,18 +123,28 @@ def sample_sequence_gen(seq_length, eos_token_id):
    return sample_sequence


def split_sequence_gen(seq_length):
    def split_sequence(batch):
        input_ids = batch["input_ids"][0]
        out = []
        while len(input_ids) >= (1 + len(out)) * seq_length:
            out.append(input_ids[len(out) * seq_length : (1 + len(out)) * seq_length])
        return {"input_ids": out}

    return split_sequence


def concat_multiple_sequence_gen(seq_length):
    def concat_multiple_sequence(batch):
        concat_input_ids = torch.cat(batch["input_ids"], dim=0)
        input_ids = []
        while len(concat_input_ids) > (1 + len(input_ids)) * seq_length:
        while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
            input_ids.append(
                concat_input_ids[
                    len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
                ]
            )
        out = {"input_ids": input_ids}
        return out
        return {"input_ids": input_ids}

    return concat_multiple_sequence
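The change from strict > to >= above matters exactly when the concatenated length is a whole multiple of seq_length; a toy illustration:

# Editor's sketch, not part of the commit; assumes concat_multiple_sequence_gen from the hunk above.
import torch

seq_length = 4
batch = {"input_ids": [torch.arange(6), torch.arange(6, 12)]}  # 12 tokens total
out = concat_multiple_sequence_gen(seq_length)(batch)
# With '>=' the loop emits all 3 full windows; the old strict '>' silently
# dropped the final, exactly-full window [8..11].
print([t.tolist() for t in out["input_ids"]])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]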
@@ -75,8 +170,13 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
        dataset = load_dataset(
            "json", data_files=data_files, split="train", streaming=True
        )
        # preprocess the text into a unified format
        if dataset_config["mode"] == "pretrain":
            dataset = dataset.map(pretrain_transform)
            dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
        elif dataset_config["mode"] == "instruct":
            dataset = dataset.map(instruct_transform, batched=True, batch_size=1)
            dataset = dataset.select_columns("text")
            dataset = dataset.map(split_multiturn, batched=True, batch_size=1)
        else:
            raise Exception(
                "Dataset mode: {} not found.".format(dataset_config["mode"])
@@ -84,6 +184,7 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
        datasets.append(dataset)
        probabilities.append(dataset.n_shards)
    probabilities_sum = sum(probabilities)
    # sample across the multiple data parts according to these probabilities
    probabilities = [p / probabilities_sum for p in probabilities]
    if len(datasets) > 1:
        full_dataset = interleave_datasets(
@@ -91,26 +192,16 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
        )
    else:
        full_dataset = datasets[0]

    # to visualize
    if return_raw_text:
        return full_dataset

    seq_length = dataset_config["seq_length"]
    if dataset_config.get("concat_multiple_sequence", False):
        num_sequences = dataset_config["num_sequences"]
        full_dataset = full_dataset.map(
            lambda x: tokenizer(
                x["text"], return_tensors="pt", return_attention_mask=False
            )
        )
        full_dataset = full_dataset.map(
            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
        )
        full_dataset = full_dataset.select_columns("input_ids")
        full_dataset = full_dataset.map(
            concat_multiple_sequence_gen(seq_length),
            batched=True,
            batch_size=num_sequences,
        )
    else:
    sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
    truncation = sequence_sample_mode == "truncation"
    # tokenize
    if dataset_config.get("pad_to_max", True):
        full_dataset = full_dataset.map(
            lambda x: tokenizer(
                x["text"],
@@ -118,12 +209,57 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
                return_attention_mask=False,
                padding="max_length",
                max_length=seq_length,
                truncation=True,
                truncation=truncation,
            )
        )
        full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
        full_dataset = full_dataset.select_columns("input_ids")
    else:
        full_dataset = full_dataset.map(
            lambda x: tokenizer(
                x["text"],
                return_tensors="pt",
                return_attention_mask=False,
                truncation=truncation,
            )
        )

    # format
    full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
    full_dataset = full_dataset.select_columns("input_ids")

    # sequence_sample
    if sequence_sample_mode == "truncation":
        pass
    elif sequence_sample_mode == "none":
        pass
    elif sequence_sample_mode == "sample":
        full_dataset = full_dataset.map(
            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
        )
    elif sequence_sample_mode == "split":
        assert not dataset_config.get("concat_multiple_sequence", False)
        full_dataset = full_dataset.map(
            split_sequence_gen(seq_length), batched=True, batch_size=1
        )
    else:
        raise Exception(
            "Unknown sequence_sample mode: {}.".format(sequence_sample_mode)
        )

    # concat multiple sequence
    if dataset_config.get("concat_multiple_sequence", False):
        num_sequences = dataset_config["num_sequences"]
        full_dataset = full_dataset.map(
            concat_multiple_sequence_gen(seq_length),
            batched=True,
            batch_size=num_sequences,
            drop_last_batch=True,
        )

    # add label
    full_dataset = full_dataset.map(get_labels_gen(tokenizer.pad_token_id))

    # shuffle
    full_dataset = full_dataset.shuffle()
    return full_dataset
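Putting the restructured pipeline together, a hedged end-to-end call (values mirror the instruct_config.yaml added in this commit; the tokenizer is constructed as in train_lm.py):

# Editor's sketch, not part of the commit: transform -> tokenize -> sequence_sample
# -> concat -> add label -> shuffle, all on a streaming dataset.
data_config = {
    "mode": "instruct",
    "data": {"mixed": "data/instruction_data/part-*.jsonl.zst"},
    "pad_to_max": False,
    "sequence_sample_mode": "sample",
    "concat_multiple_sequence": True,
    "num_sequences": 50,
    "seq_length": 2048,
    "tokenizer_model_path": "configs/llama_tokenizer_extended.model",
}
full_dataset = construct_dataset(data_config, tokenizer)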
@@ -135,7 +271,9 @@ if __name__ == "__main__":

    data_config = {
        "mode": "pretrain",
        "data": {"wudao": "data/pretrain_data/part-wudao*.jsonl.zst"},
        "data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
        "pad_to_max": False,
        "sequence_sample_mode": "sample",
        "concat_multiple_sequence": True,
        "num_sequences": 10,
        "seq_length": 2048,
@@ -153,8 +291,8 @@ if __name__ == "__main__":
        # raw_text = normalize("NFKC", raw_text)
        input_ids = tokenizer(
            line["text"], return_tensors="pt", return_attention_mask=False
        )["input_ids"][0, :-1]
        decode_text = tokenizer.decode(input_ids)
        )["input_ids"][0]
        decode_text = tokenizer.decode(input_ids, skip_special_tokens=True)
        if raw_text != decode_text and "▁" not in raw_text:
            print(raw_text, "\n", decode_text)
        if i == 3000:
@@ -1,178 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:33:27
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math


def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
    def preprocess_self_instruction(line):
        """
        The format of the data is roughly as follows.
        {'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["prompt"]
        if prompt.endswith("Output:"):
            prompt = prompt[:-7]
        total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_self_instruction


def preprocess_belle_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle(line):
        """
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["instruction"].replace("\\n", "")
        prompt = prompt.strip("")

        completion = line["output"].replace("\\n", "")
        completion = completion.strip("")
        total = "user:{}\nsystem:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_belle


def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle_multiturn_chat(line):
        """
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["instruction"].replace("\\n", "")
        prompt = prompt.strip("")

        completion = line["output"].replace("\\n", "")
        completion = completion.strip("")
        chats = prompt + completion
        chats = chats.split("Human:")
        input_ids = []
        for chat in chats:
            if chat.strip() == "":
                continue
            res = chat.split("Assistant:")
            if len(res) != 2:
                continue
            prompt, completion = res
            prompt = prompt.strip()
            completion = completion.strip()
            chat = "user:{}\nsystem:{}".format(prompt, completion)
            out = tokenizer(chat)
            input_ids.extend(out["input_ids"])
        if len(input_ids) == 0:
            return None
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_belle_multiturn_chat


def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
    def preprocess_sharegpt(line):
        """
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        """
        chats = line["conversations"]
        if chats[0]["from"] != "human":
            chats = chats[1:]
        input_ids = []
        for i in range(len(chats) // 2):
            prompt = chats[2 * i]
            completion = chats[2 * i + 1]
            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
                continue
            prompt = prompt["value"]
            prompt = prompt.strip()
            completion = completion["value"]
            completion = completion.strip()
            chat = "user:{}\nsystem:{}".format(prompt, completion)
            out = tokenizer(chat)
            input_ids.extend(out["input_ids"])
        if input_ids == []:
            return None
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_sharegpt


def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
    def preprocess_instruct_code(line):
        """
        The format of the data is roughly as follows.
        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
        Split the data based on the tokenized length according to the maximum length.
        """
        prompt = line["instruction"].replace("\\n", "")
        prompt = prompt.strip("")

        completion = line["answer"].replace("\\n", "")
        completion = completion.strip("")
        total = "user:{}\nsystem:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_instruct_code


if __name__ == "__main__":
    import sentencepiece as spm

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "self_instruct": preprocess_self_instruction_gen(tokenizer),
        "belle_1M": preprocess_belle_gen(tokenizer),
        "belle_0.5M": preprocess_belle_gen(tokenizer),
        "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
        "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
        "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
        "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
    }
    data_set = DataIter(
        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
    )
    for i, sample in enumerate(data_set):
        print(sp_model.decode(sample))
        if i == 1:
            break
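All of the removed preprocess_* helpers above ended with the same chunking step, splitting token ids into segment_max_length pieces; for reference, a toy run of that list comprehension:

# Editor's sketch, not part of the commit.
import math

segment_max_length = 4
input_ids = list(range(10))  # 10 token ids -> ceil(10 / 4) = 3 segments
segments = [
    input_ids[i * segment_max_length : (i + 1) * segment_max_length]
    for i in range(math.ceil(len(input_ids) / segment_max_length))
]
print(segments)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]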
@@ -1,189 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:35:31
FilePath: /Open-Llama/inctruction_tuning.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import (
    OpenLlamaForCausalLM,
    OpenLlamaConfig,
    get_cosine_schedule_with_warmup,
)

from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.instruction_dataset import (
    preprocess_belle_gen,
    preprocess_self_instruction_gen,
    preprocess_belle_multiturn_chat_gen,
    preprocess_instruct_code_gen,
    preprocess_sharegpt_gen,
)
from configs.instruction_tuning_config import *

accelerator = Accelerator()

if accelerator.is_main_process:
    wandb.init(project="LLAMA Instruction")

log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps

sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)

paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
    "self_instruct": preprocess_self_instruction_gen(tokenizer),
    "belle_1M": preprocess_belle_gen(tokenizer),
    "belle_0.5M": preprocess_belle_gen(tokenizer),
    "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
    "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
    "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
    "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
data_set = DataIter(
    paths,
    transform_dict=transform_dict,
    concat_docs=True,
    max_length=max_length,
    process_index=accelerator.process_index,
    num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
    data_set,
    batch_size=train_batch_size,
    # If num_workers is greater than 1, duplicate data may occur.
    num_workers=0,
    collate_fn=collate_fn_gen(tokenizer, max_length),
    drop_last=True,
)
# a smaller initializer_range makes training more stable
# add stable embedding to token embedding
raw_model = OpenLlamaForCausalLM(
    OpenLlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=initializer_range,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        "weight_decay": weight_decay,
    },
    {
        "params": [
            p
            for n, p in raw_model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        "weight_decay": 0.0,
    },
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
    optim,
    num_warmup_steps=num_warmup_steps * factor,
    num_training_steps=num_training_steps * factor,
)

_, model, optim, scheduler = accelerator.prepare(
    train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
    model.train()
    with accelerator.accumulate(model):
        batch = next(train_loader_iter)
        for k, v in batch.items():
            batch[k] = v.to(accelerator.device, non_blocking=True)
        out = model(**batch, labels=batch["input_ids"])
        total_loss = out.loss
        losses = {"total_loss": total_loss}
        accelerator.backward(total_loss)
        optim.step()
        scheduler.step()
        optim.zero_grad()
        if accelerator.sync_gradients:
            global_step += 1
    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
        cost_time = time.time() - start_time
        start_time = time.time()
        tokens = train_batch_size * log_interval * max_length
        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
        for k, v in losses.items():
            wandb.log({"Losses/{}".format(k): v})
        current_lr = optim.param_groups[0]["lr"]
        wandb.log({"Training/LR": current_lr})
        if optim.scaler is not None:
            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
        wandb.log({"Training/Data Step": data_step})
        wandb.log({"Training/Global Step": global_step})
        accelerator.print(
            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
                global_step, data_step, losses["total_loss"], tokens / cost_time
            )
        )
    if data_step % eval_interval == 0 and accelerator.is_main_process:
        text_table = wandb.Table(columns=["question", "pred"])
        model.eval()
        with torch.no_grad():
            for data in val_set:
                raw_inputs = data
                inputs_len = len(raw_inputs)
                inputs = tokenizer(
                    raw_inputs, return_tensors=True, add_special_tokens=False
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(accelerator.device)
                pred = model.generate(
                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                )
                pred = tokenizer.decode(pred.cpu())[0]
                pred = pred[inputs_len:]
                text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(global_step): text_table})
    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
        if not os.path.isdir(work_dir):
            os.mkdir(work_dir)
        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()
@@ -191,6 +191,6 @@ class Trainer:
                **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
            )
            pred = pred[0, input_length:]
            pred = self.tokenizer.decode(pred.cpu())
            pred = self.tokenizer.decode(pred.cpu(), skip_special_tokens=True)
            text_table.add_data(raw_inputs, pred)
        wandb.log({"Predictions on {}".format(self.global_step): text_table})
@@ -8,7 +8,6 @@ Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import yaml
import torch
from absl import app
@@ -26,10 +25,14 @@ flags.DEFINE_string("config", None, "Training config path")


def main(argv):
    accelerator = Accelerator()

    with open(FLAGS.config, "r", encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)

    accelerator = Accelerator(
        gradient_accumulation_steps=config["train"].get(
            "gradient_accumulation_steps", 1
        )
    )
    tokenizer = LlamaTokenizer(
        config["data"]["tokenizer_model_path"],
        pad_token="<pad>",
@@ -37,14 +40,14 @@ def main(argv):
        add_eos_token=True,
    )
    data_config = config["data"]
    pretrain_dataset = construct_dataset(data_config, tokenizer)
    pretrain_dataset = split_dataset_by_node(
        pretrain_dataset,
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    train_dataset = construct_dataset(data_config, tokenizer)
    train_dataset = split_dataset_by_node(
        train_dataset,
        rank=accelerator.process_index,
        world_size=accelerator.num_processes,
    )
    train_loader = DataLoader(
        pretrain_dataset,
        train_dataset,
        batch_size=config["train"]["train_batch_size"],
        num_workers=config["train"]["train_num_workers"],
    )
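The sharding change above swaps raw RANK/WORLD_SIZE environment variables for accelerate's own process bookkeeping; a minimal sketch of the new wiring (split_dataset_by_node lives in datasets.distributed; the streaming dataset object is assumed to be in scope):

# Editor's sketch, not part of the commit.
from accelerate import Accelerator
from datasets.distributed import split_dataset_by_node

accelerator = Accelerator(gradient_accumulation_steps=12)
# Previously: rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]),
# which requires launcher-set env vars; process_index/num_processes work regardless
# of how the job was launched.
train_dataset = split_dataset_by_node(
    train_dataset,  # a streaming dataset from construct_dataset (assumed in scope)
    rank=accelerator.process_index,
    world_size=accelerator.num_processes,
)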