unified pre-training and instruction-tuning both use train_lm and dataset
This commit is contained in:
parent 97aff0e051
commit db6cdb51d0
@@ -102,8 +102,7 @@ with gr.Blocks() as demo:
         context = context.cuda()
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
-        pred = tokenizer.decode(pred.cpu()[0])
-        pred = pred.strip()
+        pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
         print(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
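With skip_special_tokens=True the tokenizer drops markers such as </s> or <pad> during decoding, so the separate strip of the prediction above becomes unnecessary. A minimal sketch of the difference, assuming a Hugging Face tokenizer and an illustrative id sequence named ids:

    # sketch only: exact strings depend on the tokenizer and the generated ids
    raw = tokenizer.decode(ids)                              # e.g. "some answer</s>"
    clean = tokenizer.decode(ids, skip_special_tokens=True)  # e.g. "some answer"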
Binary file not shown.
@@ -1,7 +1,6 @@
 compute_environment: LOCAL_MACHINE
 deepspeed_config:
   deepspeed_multinode_launcher: standard
-  gradient_accumulation_steps: 12
   gradient_clipping: 1.0
   offload_optimizer_device: none
   offload_param_device: none
configs/instruct_config.yaml (new file, 32 lines)
@@ -0,0 +1,32 @@
+data:
+  mode: "instruct"
+  data:
+    mixed: "data/instruction_data/part-*.jsonl.zst"
+  pad_to_max: False
+  sequence_sample_mode: "sample"
+  concat_multiple_sequence: True
+  num_sequences: 50
+  seq_length: 2048
+  tokenizer_model_path: "configs/llama_tokenizer_extended.model"
+model:
+  initializer_range: 1.0e-2
+  hidden_dropout_prob: 0.1
+  attention_dropout_prob: 0.1
+  use_stable_embedding: False
+  shared_input_output_embedding: False
+train:
+  train_batch_size: 2
+  num_training_steps: 1000000
+  num_warmup_steps: 2000
+  initializer_range: 1.0e-2
+  lr: 2.0e-4
+  weight_decay: 1.0e-1
+  ckpt: "data/llama_raw_ckpt/7B/extended.pth"
+  train_num_workers: 16
+  gradient_accumulation_steps: 1
+  # global step
+  log_interval: 50
+  eval_interval: 500
+  save_interval: 1000
+  work_dir: "data/saved_ckpt/7B"
+  project_name: "Llama Instruction"
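This new config mirrors the pretrain config, so instruction tuning can go through the same entry point as pre-training. A rough sketch of how such a config is consumed, following the train_lm.py hunks later in this diff (illustrative only; the module path dataset.dataset for construct_dataset is an assumption):

    import yaml
    from transformers import LlamaTokenizer
    from dataset.dataset import construct_dataset  # assumed location of the code shown below

    with open("configs/instruct_config.yaml", "r", encoding="utf-8") as fp:
        config = yaml.load(fp, Loader=yaml.FullLoader)

    tokenizer = LlamaTokenizer(
        config["data"]["tokenizer_model_path"], pad_token="<pad>", add_eos_token=True
    )
    # mode: "instruct" routes records through instruct_transform instead of pretrain_transform
    train_dataset = construct_dataset(config["data"], tokenizer)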
configs/instruction_tuning_config.py (deleted, 25 lines)
@@ -1,25 +0,0 @@
-"""
-Author: LiangSong(sl12160010@gmail.com)
-Date: 2023-03-30 21:38:07
-LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-06 03:37:23
-FilePath: /Open-Llama/configs/instruction_tuning_config.py
-Description:
-
-Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
-"""
-max_length = 1024
-train_batch_size = 2
-num_training_steps = 40000
-num_warmup_steps = 100
-initializer_range = 1e-2
-lr = 2e-4
-weight_decay = 1e-1
-tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
-patterns = ["data/instruction_data/part-*.jsonl.zst"]
-# global step
-log_interval = 50
-eval_interval = 500
-save_interval = 1000
-work_dir = "data/saved_ckpt/"
-ckpt_path = "data/saved_ckpt/83200.pt"
@@ -2,16 +2,18 @@ data:
   mode: "pretrain"
   data:
     mixed: "data/pretrain_data/part-*.jsonl.zst"
+  pad_to_max: False
+  sequence_sample_mode: "sample"
   concat_multiple_sequence: True
-  num_sequences: 10
+  num_sequences: 20
   seq_length: 2048
   tokenizer_model_path: "configs/llama_tokenizer_extended.model"
 model:
   initializer_range: 1.0e-2
   hidden_dropout_prob: 0.1
   attention_dropout_prob: 0.1
-  use_stable_embedding: True
-  shared_input_output_embedding: True
+  use_stable_embedding: False
+  shared_input_output_embedding: False
 train:
   train_batch_size: 2
   num_training_steps: 1000000
@@ -19,11 +21,12 @@ train:
   initializer_range: 1.0e-2
   lr: 2.0e-4
   weight_decay: 1.0e-1
-  ckpt: null
+  ckpt: "data/llama_raw_ckpt/7B/extended.pth"
   train_num_workers: 16
+  gradient_accumulation_steps: 12
   # global step
   log_interval: 5
   eval_interval: 200
   save_interval: 800
-  work_dir: "data/saved_ckpt/"
+  work_dir: "data/saved_ckpt/7B"
   project_name: "Llama Pretrain"
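gradient_accumulation_steps moves out of the accelerate/DeepSpeed config and into the train section of the YAML; per the train_lm.py hunk at the end of this diff it is now read from the parsed config and handed to the Accelerator. A brief sketch of that wiring, assuming config is the dict loaded from this file:

    from accelerate import Accelerator

    accelerator = Accelerator(
        gradient_accumulation_steps=config["train"].get("gradient_accumulation_steps", 1)
    )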
@@ -14,15 +14,100 @@ from glob import glob
 from datasets import load_dataset, interleave_datasets


-def pretrain_transform(line):
-    if "title" in line and "text" not in line:
-        line["text"] = line["title"] + "\n" + line["content"]
-    return line
+def pretrain_transform(batch):
+    # wudao preprocess
+    if "title" in batch and "content" in batch:
+        assert len(batch["title"]) == 1
+        batch["text"] = [batch["title"][0] + "\n" + batch["content"][0]]
+    elif "text" in batch:
+        pass
+    else:
+        raise Exception("Unrecognized pretrain dataset format.")
+    return batch
+
+
+def instruct_transform(batch):
+    # self instruct preprocess
+    if "prompt" in batch and "completion" in batch:
+        prompt = batch["prompt"][0]
+        completion = batch["completion"][0]
+        if prompt.endswith("Output:"):
+            prompt = prompt[:-7]
+        text = "user:{}\nsystem:{}".format(prompt.strip(), completion.strip())
+        texts = [text]
+    # belle preprocess
+    elif "instruction" in batch and "output" in batch:
+        prompt = batch["instruction"][0].replace("\\n", "")
+        prompt = prompt.strip("")
+
+        completion = batch["output"][0].replace("\\n", "")
+        completion = completion.strip("")
+        # multi turn chat
+        if "Human:" in prompt:
+            texts = []
+            chats = prompt + completion
+            chats = chats.split("Human:")
+            for chat in chats:
+                if chat.strip() == "":
+                    continue
+                res = chat.split("Assistant:")
+                if len(res) != 2:
+                    continue
+                prompt, completion = res
+                prompt = prompt.strip()
+                completion = completion.strip()
+                chat = "user:{}\nsystem:{}".format(prompt, completion)
+                texts.append(chat)
+            texts = ["[multiturn_sep]".join(texts)]
+        else:
+            text = "user:{}\nsystem:{}".format(prompt, completion)
+            texts = [text]
+    # instruct code preprocess
+    elif "instruction" in batch and "answer" in batch:
+        prompt = batch["instruction"][0].replace("\\n", "")
+        prompt = prompt.strip("")
+
+        completion = batch["answer"][0].replace("\\n", "")
+        completion = completion.strip("")
+        text = "user:{}\nsystem:{}".format(prompt, completion)
+        texts = [text]
+    # share gpt preprocess
+    elif "conversations" in batch:
+        chats = batch["conversations"][0]
+        if chats[0]["from"] != "human":
+            chats = chats[1:]
+        texts = []
+        for i in range(len(chats) // 2):
+            prompt = chats[2 * i]
+            completion = chats[2 * i + 1]
+            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
+                continue
+            prompt = prompt["value"]
+            prompt = prompt.strip()
+            completion = completion["value"]
+            completion = completion.strip()
+            chat = "user:{}\nsystem:{}".format(prompt, completion)
+            texts.append(chat)
+        texts = ["[multiturn_sep]".join(texts)]
+    else:
+        raise Exception("Unrecognized instruct dataset format.")
+    return {"text": texts}
+
+
+def split_multiturn(batch):
+    return {"text": batch["text"][0].split("[multiturn_sep]")}
+
+
+def truncation_gen(seq_length):
+    def truncation(line):
+        return {"input_ids": line["input_ids"][:seq_length]}
+
+    return truncation
+
+
 def sample_sequence_gen(seq_length, eos_token_id):
     def sample_sequence(line):
-        doc_length = line["input_ids"].shape[1]
+        doc_length = line["input_ids"].shape[0]
         if doc_length <= seq_length:
             start = 0
         else:
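Both transforms normalize every source into a single "text" field in a user:/system: layout; multi-turn conversations are joined with the "[multiturn_sep]" marker and split back into separate records by split_multiturn. A small illustrative call (hypothetical record, with instruct_transform from the hunk above in scope; the sample text comes from the old self-instruct docstring):

    batch = {
        "prompt": ["Explain the origin of life on earth. Output:"],
        "completion": ["Life on Earth is believed to have..."],
    }
    print(instruct_transform(batch))
    # {'text': ['user:Explain the origin of life on earth.\nsystem:Life on Earth is believed to have...']}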
@@ -30,7 +115,7 @@ def sample_sequence_gen(seq_length, eos_token_id):
                 start = 0
             else:
                 start = random.randint(0, doc_length - seq_length)
-        input_ids = line["input_ids"][0, start : start + seq_length]
+        input_ids = line["input_ids"][start : start + seq_length]
         if input_ids[-1] != eos_token_id:
             input_ids[-1] = eos_token_id
         return {"input_ids": input_ids}
@@ -38,18 +123,28 @@ def sample_sequence_gen(seq_length, eos_token_id):
     return sample_sequence


+def split_sequence_gen(seq_length):
+    def split_sequence(batch):
+        input_ids = batch["input_ids"][0]
+        out = []
+        while len(input_ids) >= (1 + len(out)) * seq_length:
+            out.append(input_ids[len(out) * seq_length : (1 + len(out)) * seq_length])
+        return {"input_ids": out}
+
+    return split_sequence
+
+
 def concat_multiple_sequence_gen(seq_length):
     def concat_multiple_sequence(batch):
         concat_input_ids = torch.cat(batch["input_ids"], dim=0)
         input_ids = []
-        while len(concat_input_ids) > (1 + len(input_ids)) * seq_length:
+        while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
             input_ids.append(
                 concat_input_ids[
                     len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
                 ]
             )
-        out = {"input_ids": input_ids}
-        return out
+        return {"input_ids": input_ids}

     return concat_multiple_sequence
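A quick sketch of what concat_multiple_sequence_gen does when mapped with batched=True and batch_size=num_sequences: each batch of tokenized sequences is concatenated and re-cut into full seq_length windows, and with the new >= condition the final full window is kept instead of being dropped. Toy numbers, assuming the function above is in scope:

    import torch

    concat_fn = concat_multiple_sequence_gen(seq_length=4)
    batch = {"input_ids": [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8])]}
    print(concat_fn(batch))
    # {'input_ids': [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]}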
@@ -75,8 +170,13 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         dataset = load_dataset(
             "json", data_files=data_files, split="train", streaming=True
         )
+        # preprocess the raw text into a unified format
         if dataset_config["mode"] == "pretrain":
-            dataset = dataset.map(pretrain_transform)
+            dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
+        elif dataset_config["mode"] == "instruct":
+            dataset = dataset.map(instruct_transform, batched=True, batch_size=1)
+            dataset = dataset.select_columns("text")
+            dataset = dataset.map(split_multiturn, batched=True, batch_size=1)
         else:
             raise Exception(
                 "Dataset mode: {} not found.".format(dataset_config["mode"])
@@ -84,6 +184,7 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         datasets.append(dataset)
         probabilities.append(dataset.n_shards)
     probabilities_sum = sum(probabilities)
+    # sample from the multiple data sources according to these probabilities
     probabilities = [p / probabilities_sum for p in probabilities]
     if len(datasets) > 1:
         full_dataset = interleave_datasets(
@@ -91,26 +192,16 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         )
     else:
         full_dataset = datasets[0]

+    # to visualize
     if return_raw_text:
         return full_dataset

     seq_length = dataset_config["seq_length"]
-    if dataset_config.get("concat_multiple_sequence", False):
-        num_sequences = dataset_config["num_sequences"]
-        full_dataset = full_dataset.map(
-            lambda x: tokenizer(
-                x["text"], return_tensors="pt", return_attention_mask=False
-            )
-        )
-        full_dataset = full_dataset.map(
-            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
-        )
-        full_dataset = full_dataset.select_columns("input_ids")
-        full_dataset = full_dataset.map(
-            concat_multiple_sequence_gen(seq_length),
-            batched=True,
-            batch_size=num_sequences,
-        )
-    else:
+    sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
+    truncation = sequence_sample_mode == "truncation"
+    # tokenize
+    if dataset_config.get("pad_to_max", True):
         full_dataset = full_dataset.map(
             lambda x: tokenizer(
                 x["text"],
@@ -118,12 +209,57 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
                 return_attention_mask=False,
                 padding="max_length",
                 max_length=seq_length,
-                truncation=True,
+                truncation=truncation,
             )
         )
+    else:
+        full_dataset = full_dataset.map(
+            lambda x: tokenizer(
+                x["text"],
+                return_tensors="pt",
+                return_attention_mask=False,
+                truncation=truncation,
+            )
+        )
+
+    # format
     full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
     full_dataset = full_dataset.select_columns("input_ids")
+
+    # sequence_sample
+    if sequence_sample_mode == "truncation":
+        pass
+    elif sequence_sample_mode == "none":
+        pass
+    elif sequence_sample_mode == "sample":
+        full_dataset = full_dataset.map(
+            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
+        )
+    elif sequence_sample_mode == "split":
+        assert not dataset_config.get("concat_multiple_sequence", False)
+        full_dataset = full_dataset.map(
+            split_sequence_gen(seq_length), batched=True, batch_size=1
+        )
+    else:
+        raise Exception(
+            "Unknown sequence_sample mode: {}.".format(sequence_sample_mode)
+        )
+
+    # concat multiple sequence
+    if dataset_config.get("concat_multiple_sequence", False):
+        num_sequences = dataset_config["num_sequences"]
+        full_dataset = full_dataset.map(
+            concat_multiple_sequence_gen(seq_length),
+            batched=True,
+            batch_size=num_sequences,
+            drop_last_batch=True,
+        )
+
+    # add label
     full_dataset = full_dataset.map(get_labels_gen(tokenizer.pad_token_id))
+
+    # shuffle
+    full_dataset = full_dataset.shuffle()
     return full_dataset
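For the new "split" mode above, split_sequence_gen (added earlier in this diff) cuts each tokenized document into non-overlapping seq_length chunks and drops the remainder, whereas "sample" keeps one randomly placed window per document. A toy illustration with hypothetical ids, assuming the function is in scope:

    split_fn = split_sequence_gen(seq_length=4)
    batch = {"input_ids": [list(range(10))]}  # one tokenized document with 10 ids
    print(split_fn(batch))
    # {'input_ids': [[0, 1, 2, 3], [4, 5, 6, 7]]}  (the trailing partial chunk is dropped)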
@@ -135,7 +271,9 @@ if __name__ == "__main__":

     data_config = {
         "mode": "pretrain",
-        "data": {"wudao": "data/pretrain_data/part-wudao*.jsonl.zst"},
+        "data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
+        "pad_to_max": False,
+        "sequence_sample_mode": "sample",
         "concat_multiple_sequence": True,
         "num_sequences": 10,
         "seq_length": 2048,
@@ -153,8 +291,8 @@ if __name__ == "__main__":
         # raw_text = normalize("NFKC", raw_text)
         input_ids = tokenizer(
             line["text"], return_tensors="pt", return_attention_mask=False
-        )["input_ids"][0, :-1]
-        decode_text = tokenizer.decode(input_ids)
+        )["input_ids"][0]
+        decode_text = tokenizer.decode(input_ids, skip_special_tokens=True)
         if raw_text != decode_text and "▁" not in raw_text:
             print(raw_text, "\n", decode_text)
         if i == 3000:
dataset/instruction_dataset.py (deleted, 178 lines)
@@ -1,178 +0,0 @@
-"""
-Author: LiangSong(sl12160010@gmail.com)
-Date: 2023-03-30 21:02:00
-LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-06 03:33:27
-FilePath: /Open-Llama/dataset/instruction_dataset.py
-Description:
-
-Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
-"""
-import math
-
-
-def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
-    def preprocess_self_instruction(line):
-        """
-        The format of the data is roughly as follows.
-        {'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
-        Split the data based on the tokenized length according to the maximum length.
-        """
-        prompt = line["prompt"]
-        if prompt.endswith("Output:"):
-            prompt = prompt[:-7]
-        total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
-        out = tokenizer(total)
-        input_ids = out["input_ids"]
-        return [
-            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
-            for i in range(math.ceil(len(input_ids) / segment_max_length))
-        ]
-
-    return preprocess_self_instruction
-
-
-def preprocess_belle_gen(tokenizer, segment_max_length=1024):
-    def preprocess_belle(line):
-        """
-        The format of the data is roughly as follows.
-        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
-        Split the data based on the tokenized length according to the maximum length.
-        """
-        prompt = line["instruction"].replace("\\n", "")
-        prompt = prompt.strip("")
-
-        completion = line["output"].replace("\\n", "")
-        completion = completion.strip("")
-        total = "user:{}\nsystem:{}".format(prompt, completion)
-        out = tokenizer(total)
-        input_ids = out["input_ids"]
-        return [
-            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
-            for i in range(math.ceil(len(input_ids) / segment_max_length))
-        ]
-
-    return preprocess_belle
-
-
-def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
-    def preprocess_belle_multiturn_chat(line):
-        """
-        The format of the data is roughly as follows.
-        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
-        Split the data based on the tokenized length according to the maximum length.
-        """
-        prompt = line["instruction"].replace("\\n", "")
-        prompt = prompt.strip("")
-
-        completion = line["output"].replace("\\n", "")
-        completion = completion.strip("")
-        chats = prompt + completion
-        chats = chats.split("Human:")
-        input_ids = []
-        for chat in chats:
-            if chat.strip() == "":
-                continue
-            res = chat.split("Assistant:")
-            if len(res) != 2:
-                continue
-            prompt, completion = res
-            prompt = prompt.strip()
-            completion = completion.strip()
-            chat = "user:{}\nsystem:{}".format(prompt, completion)
-            out = tokenizer(chat)
-            input_ids.extend(out["input_ids"])
-        if len(input_ids) == 0:
-            return None
-        return [
-            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
-            for i in range(math.ceil(len(input_ids) / segment_max_length))
-        ]
-
-    return preprocess_belle_multiturn_chat
-
-
-def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
-    def preprocess_sharegpt(line):
-        """
-        The format of the data is roughly as follows.
-        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
-        Split the data based on the tokenized length according to the maximum length.
-        """
-        chats = line["conversations"]
-        if chats[0]["from"] != "human":
-            chats = chats[1:]
-        input_ids = []
-        for i in range(len(chats) // 2):
-            prompt = chats[2 * i]
-            completion = chats[2 * i + 1]
-            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
-                continue
-            prompt = prompt["value"]
-            prompt = prompt.strip()
-            completion = completion["value"]
-            completion = completion.strip()
-            chat = "user:{}\nsystem:{}".format(prompt, completion)
-            out = tokenizer(chat)
-            input_ids.extend(out["input_ids"])
-        if input_ids == []:
-            return None
-        return [
-            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
-            for i in range(math.ceil(len(input_ids) / segment_max_length))
-        ]
-
-    return preprocess_sharegpt
-
-
-def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
-    def preprocess_instruct_code(line):
-        """
-        The format of the data is roughly as follows.
-        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
-        Split the data based on the tokenized length according to the maximum length.
-        """
-        prompt = line["instruction"].replace("\\n", "")
-        prompt = prompt.strip("")
-
-        completion = line["answer"].replace("\\n", "")
-        completion = completion.strip("")
-        total = "user:{}\nsystem:{}".format(prompt, completion)
-        out = tokenizer(total)
-        input_ids = out["input_ids"]
-        return [
-            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
-            for i in range(math.ceil(len(input_ids) / segment_max_length))
-        ]
-
-    return preprocess_instruct_code
-
-
-if __name__ == "__main__":
-    import sentencepiece as spm
-
-    from dataset.tokenizer import Tokenizer
-    from dataset.data_iter import create_shard_kwargs, DataIter
-
-    sp_model = spm.SentencePieceProcessor(
-        model_file="configs/10w_vocab_wudao5_pile10.model"
-    )
-    tokenizer = Tokenizer(sp_model)
-    patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
-    paths = create_shard_kwargs(patterns)
-    transform_dict = {
-        "self_instruct": preprocess_self_instruction_gen(tokenizer),
-        "belle_1M": preprocess_belle_gen(tokenizer),
-        "belle_0.5M": preprocess_belle_gen(tokenizer),
-        "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
-        "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
-        "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
-        "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
-    }
-    data_set = DataIter(
-        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
-    )
-    for i, sample in enumerate(data_set):
-        print(sp_model.decode(sample))
-        if i == 1:
-            break
inctruction_tuning.py (deleted, 189 lines)
@@ -1,189 +0,0 @@
-"""
-Author: LiangSong(sl12160010@gmail.com)
-Date: 2023-03-30 21:35:01
-LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-06 03:35:31
-FilePath: /Open-Llama/inctruction_tuning.py
-Description:
-
-Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
-"""
-import os
-import time
-import wandb
-import torch
-import random
-import sentencepiece as spm
-from torchinfo import summary
-from accelerate import Accelerator
-from torch.utils.data import DataLoader
-from deepspeed.ops.adam import FusedAdam
-from transformers import (
-    OpenLlamaForCausalLM,
-    OpenLlamaConfig,
-    get_cosine_schedule_with_warmup,
-)
-
-from dataset.validation import val_set
-from dataset.tokenizer import Tokenizer
-from dataset.data_iter import create_shard_kwargs, DataIter
-from dataset.collate_fn import collate_fn_gen
-from dataset.instruction_dataset import (
-    preprocess_belle_gen,
-    preprocess_self_instruction_gen,
-    preprocess_belle_multiturn_chat_gen,
-    preprocess_instruct_code_gen,
-    preprocess_sharegpt_gen,
-)
-from configs.instruction_tuning_config import *
-
-accelerator = Accelerator()
-
-if accelerator.is_main_process:
-    wandb.init(project="LLAMA Instruction")
-
-log_interval *= accelerator.gradient_accumulation_steps
-eval_interval *= accelerator.gradient_accumulation_steps
-save_interval *= accelerator.gradient_accumulation_steps
-
-sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
-tokenizer = Tokenizer(sp_model)
-
-paths = create_shard_kwargs(patterns, repeat=3)
-random.shuffle(paths)
-transform_dict = {
-    "self_instruct": preprocess_self_instruction_gen(tokenizer),
-    "belle_1M": preprocess_belle_gen(tokenizer),
-    "belle_0.5M": preprocess_belle_gen(tokenizer),
-    "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
-    "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
-    "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
-    "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
-}
-data_set = DataIter(
-    paths,
-    transform_dict=transform_dict,
-    concat_docs=True,
-    max_length=max_length,
-    process_index=accelerator.process_index,
-    num_processes=accelerator.num_processes,
-)
-train_loader = DataLoader(
-    data_set,
-    batch_size=train_batch_size,
-    # If num_workers is greater than 1, duplicate data may occur.
-    num_workers=0,
-    collate_fn=collate_fn_gen(tokenizer, max_length),
-    drop_last=True,
-)
-# smaller initializer_range make training more stable
-# add stabel embedding to token embedding
-raw_model = OpenLlamaForCausalLM(
-    OpenLlamaConfig(
-        vocab_size=tokenizer.vocab_size,
-        initializer_range=initializer_range,
-        pad_token_id=tokenizer.pad_id,
-        rms_norm_eps=1e-5,
-        hidden_dropout_prob=0.1,
-        attention_dropout_prob=0.1,
-        use_stable_embedding=True,
-        shared_input_output_embedding=True,
-    )
-)
-ckpt = torch.load(ckpt_path, map_location="cpu")
-raw_model.load_state_dict(ckpt)
-raw_model.eval()
-with torch.no_grad():
-    summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
-no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
-optimizer_grouped_parameters = [
-    {
-        "params": [
-            p
-            for n, p in raw_model.named_parameters()
-            if not any(nd in n for nd in no_decay)
-        ],
-        "weight_decay": weight_decay,
-    },
-    {
-        "params": [
-            p
-            for n, p in raw_model.named_parameters()
-            if any(nd in n for nd in no_decay)
-        ],
-        "weight_decay": 0.0,
-    },
-]
-optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
-optim.zero_grad()
-factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
-scheduler = get_cosine_schedule_with_warmup(
-    optim,
-    num_warmup_steps=num_warmup_steps * factor,
-    num_training_steps=num_training_steps * factor,
-)
-
-_, model, optim, scheduler = accelerator.prepare(
-    train_loader, raw_model, optim, scheduler
-)
-print("start training...")
-train_loader_iter = iter(train_loader)
-global_step = 0
-start_time = time.time()
-for data_step in range(num_training_steps):
-    model.train()
-    with accelerator.accumulate(model):
-        batch = next(train_loader_iter)
-        for k, v in batch.items():
-            batch[k] = v.to(accelerator.device, non_blocking=True)
-        out = model(**batch, labels=batch["input_ids"])
-        total_loss = out.loss
-        losses = {"total_loss": total_loss}
-        accelerator.backward(total_loss)
-        optim.step()
-        scheduler.step()
-        optim.zero_grad()
-        if accelerator.sync_gradients:
-            global_step += 1
-    if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
-        cost_time = time.time() - start_time
-        start_time = time.time()
-        tokens = train_batch_size * log_interval * max_length
-        wandb.log({"Training/Token per second per gpu": tokens / cost_time})
-        for k, v in losses.items():
-            wandb.log({"Losses/{}".format(k): v})
-        current_lr = optim.param_groups[0]["lr"]
-        wandb.log({"Training/LR": current_lr})
-        if optim.scaler is not None:
-            wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
-        wandb.log({"Training/Data Step": data_step})
-        wandb.log({"Training/Global Step": global_step})
-        accelerator.print(
-            "Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
-                global_step, data_step, losses["total_loss"], tokens / cost_time
-            )
-        )
-    if data_step % eval_interval == 0 and accelerator.is_main_process:
-        text_table = wandb.Table(columns=["question", "pred"])
-        model.eval()
-        with torch.no_grad():
-            for data in val_set:
-                raw_inputs = data
-                inputs_len = len(raw_inputs)
-                inputs = tokenizer(
-                    raw_inputs, return_tensors=True, add_special_tokens=False
-                )
-                for k, v in inputs.items():
-                    inputs[k] = v.to(accelerator.device)
-                pred = model.generate(
-                    **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
-                )
-                pred = tokenizer.decode(pred.cpu())[0]
-                pred = pred[inputs_len:]
-                text_table.add_data(raw_inputs, pred)
-        wandb.log({"Predictions on {}".format(global_step): text_table})
-    if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
-        if not os.path.isdir(work_dir):
-            os.mkdir(work_dir)
-        torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
-wandb.finish()
@@ -191,6 +191,6 @@ class Trainer:
                     **inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
                 )
                 pred = pred[0, input_length:]
-                pred = self.tokenizer.decode(pred.cpu())
+                pred = self.tokenizer.decode(pred.cpu(), skip_special_tokens=True)
                 text_table.add_data(raw_inputs, pred)
         wandb.log({"Predictions on {}".format(self.global_step): text_table})
@@ -8,7 +8,6 @@ Description:

 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
-import os
 import yaml
 import torch
 from absl import app
@@ -26,10 +25,14 @@ flags.DEFINE_string("config", None, "Training config path")


 def main(argv):
-    accelerator = Accelerator()
-
     with open(FLAGS.config, "r", encoding="utf-8") as fp:
         config = yaml.load(fp, Loader=yaml.FullLoader)

+    accelerator = Accelerator(
+        gradient_accumulation_steps=config["train"].get(
+            "gradient_accumulation_steps", 1
+        )
+    )
     tokenizer = LlamaTokenizer(
         config["data"]["tokenizer_model_path"],
         pad_token="<pad>",
@@ -37,14 +40,14 @@ def main(argv):
         add_eos_token=True,
     )
     data_config = config["data"]
-    pretrain_dataset = construct_dataset(data_config, tokenizer)
-    pretrain_dataset = split_dataset_by_node(
-        pretrain_dataset,
-        rank=int(os.environ["RANK"]),
-        world_size=int(os.environ["WORLD_SIZE"]),
+    train_dataset = construct_dataset(data_config, tokenizer)
+    train_dataset = split_dataset_by_node(
+        train_dataset,
+        rank=accelerator.process_index,
+        world_size=accelerator.num_processes,
     )
     train_loader = DataLoader(
-        pretrain_dataset,
+        train_dataset,
         batch_size=config["train"]["train_batch_size"],
         num_workers=config["train"]["train_num_workers"],
     )
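The sharding change means rank and world size now come from the Accelerator instead of environment variables, so the same script also runs single-process. A brief sketch of the call, assuming a datasets streaming dataset named train_dataset and the accelerator created above:

    from datasets.distributed import split_dataset_by_node

    train_dataset = split_dataset_by_node(
        train_dataset,
        rank=accelerator.process_index,        # previously int(os.environ["RANK"])
        world_size=accelerator.num_processes,  # previously int(os.environ["WORLD_SIZE"])
    )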