add more instruction data

parent 9f140dc99f
commit bc16df4751
--- a/configs/instruction_tuning_config.py
+++ b/configs/instruction_tuning_config.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:38:07
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:39:40
+LastEditTime: 2023-04-06 03:37:23
 FilePath: /Open-Llama/configs/instruction_tuning_config.py
 Description:
 
@@ -10,7 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 max_length = 1024
 train_batch_size = 2
-num_training_steps = 37500
+num_training_steps = 40000
 num_warmup_steps = 100
 initializer_range = 1e-2
 lr = 2e-4
@@ -22,4 +22,4 @@ log_interval = 50
 eval_interval = 500
 save_interval = 1000
 work_dir = "data/saved_ckpt/"
-ckpt_path = "data/saved_ckpt/40000.pt"
+ckpt_path = "data/saved_ckpt/83200.pt"
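Note: these values are consumed via the star-import shown in inctruction_tuning.py at the end of this commit. A minimal sketch of how they typically flow into a training setup, assuming a transformers-style warmup scheduler (the actual scheduler and call site are not shown in this diff):

    import os
    import torch
    from transformers import get_cosine_schedule_with_warmup  # assumed scheduler

    from configs.instruction_tuning_config import *  # lr, ckpt_path, num_training_steps, ...

    model = torch.nn.Linear(8, 8)  # stand-in for the real model
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # lr = 2e-4
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,      # 100
        num_training_steps=num_training_steps,  # raised from 37500 to 40000
    )
    # Instruction tuning now starts from the later pre-training checkpoint.
    if os.path.exists(ckpt_path):  # "data/saved_ckpt/83200.pt"
        state = torch.load(ckpt_path, map_location="cpu")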
--- a/dataset/data_iter.py
+++ b/dataset/data_iter.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-17 19:32:20
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-05 22:36:45
+LastEditTime: 2023-04-06 03:37:55
 FilePath: /Open-Llama/dataset/data_iter.py
 Description:
 
@@ -68,7 +68,10 @@ class DataIter(IterableDataset):
             # Transformation, including sample, tokenize, etc.
             if self.transform_dict:
                 line = self.transform_dict[dataset_name](line)
-                if isinstance(line, str):
+                # skip bad doc
+                if line is None:
+                    continue
+                elif isinstance(line, str):
                     yield line
                 # must be list of list
                 elif isinstance(line, list) and isinstance(line[0], list):
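This new branch gives transforms a way to flag a bad record: return None and DataIter drops it instead of yielding it. A self-contained toy (names are illustrative, not from the repo) showing the contract:

    def toy_transform(line):
        # hypothetical validator: treat records without a "text" field as bad docs
        if "text" not in line:
            return None
        return line["text"]

    records = [{"text": "ok"}, {"meta": "no text"}, {"text": "also ok"}]

    def iterate(records, transform):
        for line in records:
            line = transform(line)
            # skip bad doc, mirroring the new branch in DataIter
            if line is None:
                continue
            elif isinstance(line, str):
                yield line

    print(list(iterate(records, toy_transform)))  # ['ok', 'also ok']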
--- a/dataset/instruction_dataset.py
+++ b/dataset/instruction_dataset.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:02:00
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-05 22:35:24
+LastEditTime: 2023-04-06 03:33:27
 FilePath: /Open-Llama/dataset/instruction_dataset.py
 Description:
 
@@ -21,7 +21,7 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
         prompt = line["prompt"]
         if prompt.endswith("Output:"):
             prompt = prompt[:-7]
-        total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
+        total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
         out = tokenizer(total)
         input_ids = out["input_ids"]
         return [
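The turn separator moves from the special token "<s>" to a plain newline, so the user and system turns are joined as visible text rather than around a sentence boundary marker. A quick illustration with made-up strings:

    prompt, completion = "Name three fruits.", "Apple, banana, orange."
    old = "user:{}<s>system:{}".format(prompt, completion)
    new = "user:{}\nsystem:{}".format(prompt, completion)
    print(old)  # user:Name three fruits.<s>system:Apple, banana, orange.
    print(new)  # user:Name three fruits.
                # system:Apple, banana, orange.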
@@ -39,12 +39,12 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
         {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
         Split the data based on the tokenized length according to the maximum length.
         """
-        prompt = line["input"].replace("\\n", "")
+        prompt = line["instruction"].replace("\\n", "")
         prompt = prompt.strip("")
 
-        completion = line["target"].replace("\\n", "")
+        completion = line["output"].replace("\\n", "")
         completion = completion.strip("")
-        total = "user:{}<s>system:{}".format(prompt, completion)
+        total = "user:{}\nsystem:{}".format(prompt, completion)
         out = tokenizer(total)
         input_ids = out["input_ids"]
         return [
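The BELLE records are now read with the instruction/output schema instead of input/target. A synthetic record (not from the actual dataset) run through the same steps:

    line = {"instruction": "Summarize:\\nLLMs are large.", "output": "LLMs are big models."}
    prompt = line["instruction"].replace("\\n", "")  # strips the literal backslash-n in the data
    completion = line["output"].replace("\\n", "")
    total = "user:{}\nsystem:{}".format(prompt.strip(), completion.strip())
    print(total)
    # user:Summarize:LLMs are large.
    # system:LLMs are big models.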
@@ -55,9 +55,101 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
     return preprocess_belle
 
 
+def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
+    def preprocess_belle_multiturn_chat(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip("")
+
+        completion = line["output"].replace("\\n", "")
+        completion = completion.strip("")
+        chats = prompt + completion
+        chats = chats.split("Human:")
+        input_ids = []
+        for chat in chats:
+            if chat.strip() == "":
+                continue
+            res = chat.split("Assistant:")
+            if len(res) != 2:
+                continue
+            prompt, completion = res
+            prompt = prompt.strip()
+            completion = completion.strip()
+            chat = "user:{}\nsystem:{}".format(prompt, completion)
+            out = tokenizer(chat)
+            input_ids.extend(out["input_ids"])
+        if len(input_ids) == 0:
+            return None
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_belle_multiturn_chat
+
+
+def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
+    def preprocess_sharegpt(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        chats = line["conversations"]
+        if chats[0]["from"] != "human":
+            chats = chats[1:]
+        input_ids = []
+        for i in range(len(chats) // 2):
+            prompt = chats[2 * i]
+            completion = chats[2 * i + 1]
+            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
+                continue
+            prompt = prompt["value"]
+            prompt = prompt.strip()
+            completion = completion["value"]
+            completion = completion.strip()
+            chat = "user:{}\nsystem:{}".format(prompt, completion)
+            out = tokenizer(chat)
+            input_ids.extend(out["input_ids"])
+        if input_ids == []:
+            return None
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_sharegpt
+
+
+def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
+    def preprocess_instruct_code(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip("")
+
+        completion = line["answer"].replace("\\n", "")
+        completion = completion.strip("")
+        total = "user:{}\nsystem:{}".format(prompt, completion)
+        out = tokenizer(total)
+        input_ids = out["input_ids"]
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_instruct_code
+
+
 if __name__ == "__main__":
     import sentencepiece as spm
-    from datasets import IterableDataset
 
     from dataset.tokenizer import Tokenizer
     from dataset.data_iter import create_shard_kwargs, DataIter
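A smoke test for the two new multi-turn parsers, using a stub tokenizer so it runs without the sentencepiece model file; the sample records are synthetic but follow the field names the parsers expect:

    from dataset.instruction_dataset import (
        preprocess_belle_multiturn_chat_gen,
        preprocess_sharegpt_gen,
    )

    class StubTokenizer:
        def __call__(self, text):
            return {"input_ids": [ord(c) for c in text]}  # chars as fake token ids

    tok = StubTokenizer()

    belle_line = {
        "instruction": "Human: hi Assistant: hello Human: thanks Assistant: bye",
        "output": "",
    }
    print(preprocess_belle_multiturn_chat_gen(tok)(belle_line))  # segments, or None

    sharegpt_line = {
        "conversations": [
            {"from": "human", "value": "hi"},
            {"from": "gpt", "value": "hello"},
        ]
    }
    print(preprocess_sharegpt_gen(tok)(sharegpt_line))  # [[...ids...]]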
@@ -66,17 +158,21 @@ if __name__ == "__main__":
         model_file="configs/10w_vocab_wudao5_pile10.model"
     )
     tokenizer = Tokenizer(sp_model)
-    patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
+    patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
     paths = create_shard_kwargs(patterns)
     transform_dict = {
+        "self_instruct": preprocess_self_instruction_gen(tokenizer),
         "belle_1M": preprocess_belle_gen(tokenizer),
         "belle_0.5M": preprocess_belle_gen(tokenizer),
-        "self_instruct": preprocess_self_instruction_gen(tokenizer),
+        "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+        "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+        "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+        "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
     }
     data_set = DataIter(
         paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
     )
     for i, sample in enumerate(data_set):
-        print(sample, sp_model.Decode(sample))
-        if i == 20:
+        print(sp_model.decode(sample))
+        if i == 1:
             break
--- a/inctruction_tuning.py
+++ b/inctruction_tuning.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:35:01
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-05 22:47:25
+LastEditTime: 2023-04-06 03:35:31
 FilePath: /Open-Llama/inctruction_tuning.py
 Description:
 
@@ -27,6 +27,9 @@ from dataset.collate_fn import collate_fn_gen
 from dataset.instruction_dataset import (
     preprocess_belle_gen,
     preprocess_self_instruction_gen,
+    preprocess_belle_multiturn_chat_gen,
+    preprocess_instruct_code_gen,
+    preprocess_sharegpt_gen,
 )
 from configs.instruction_tuning_config import *
 
@@ -45,9 +48,13 @@ tokenizer = Tokenizer(sp_model)
 paths = create_shard_kwargs(patterns, repeat=3)
 random.shuffle(paths)
 transform_dict = {
-    "belle_1M": preprocess_belle_gen(tokenizer, max_length),
-    "belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
-    "self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
+    "self_instruct": preprocess_self_instruction_gen(tokenizer),
+    "belle_1M": preprocess_belle_gen(tokenizer),
+    "belle_0.5M": preprocess_belle_gen(tokenizer),
+    "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+    "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+    "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+    "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
 }
 data_set = DataIter(
     paths,
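One thing this dict relies on: each key must match the dataset name embedded in the shard filenames under data/instruction_data/. The parsing happens inside DataIter and is not shown in this diff, so the helper below only illustrates the assumed part-<name>-<shard> naming convention:

    def dataset_name_from_path(path: str) -> str:
        # assumed convention: "part-<name>-<shard>.jsonl.zst" -> "<name>"
        stem = path.split("/")[-1]
        return stem.removeprefix("part-").rsplit("-", 1)[0]

    p = "data/instruction_data/part-sharegpt_90K-000.jsonl.zst"
    print(dataset_name_from_path(p))  # sharegpt_90K -> routed to preprocess_sharegpt_gen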