Open-Llama/dataset/instruction_dataset.py
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:33:27
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math


def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
    def preprocess_self_instruction(line):
        """
        The format of the data is roughly as follows.
        {'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
        Split the tokenized text into segments of at most segment_max_length tokens.
        """
        prompt = line["prompt"]
        # Drop the trailing "Output:" marker that self-instruct prompts end with.
        if prompt.endswith("Output:"):
            prompt = prompt[: -len("Output:")]
        total = "user:{}\nsystem:{}".format(
            prompt.strip(), line["completion"].strip()
        )
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_self_instruction
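
# A minimal sketch of the chunking behavior shared by all preprocessors below
# (hypothetical token ids; assumes the tokenizer returns a dict with "input_ids"):
#
#   fn = preprocess_self_instruction_gen(tokenizer, segment_max_length=4)
#   fn({"prompt": "Explain the origin of life on earth. Output:",
#       "completion": "Life on Earth is believed to have..."})
#   # e.g. 9 token ids -> ceil(9 / 4) = 3 segments: [[1, 2, 3, 4], [5, 6, 7, 8], [9]],
#   # every segment at most segment_max_length long.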


def preprocess_belle_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle(line):
        """
        The format of the data is roughly as follows.
        {'instruction': 'some instruction', 'output': 'some output'}
        Split the tokenized text into segments of at most segment_max_length tokens.
        """
        # Strip literal "\\n" escape sequences (backslash + n) from the raw text.
        prompt = line["instruction"].replace("\\n", "").strip()
        completion = line["output"].replace("\\n", "").strip()
        total = "user:{}\nsystem:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_belle
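
# E.g. a hypothetical record {'instruction': 'Translate: hello', 'output': '你好'}
# becomes the single training string "user:Translate: hello\nsystem:你好"
# before tokenization and chunking.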


def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
    def preprocess_belle_multiturn_chat(line):
        """
        The format of the data is roughly as follows, with turns delimited
        by "Human:" and "Assistant:" markers.
        {'instruction': 'Human: ... Assistant: ... Human: ...', 'output': '...'}
        Split the tokenized text into segments of at most segment_max_length tokens.
        """
        # Strip literal "\\n" escape sequences (backslash + n) from the raw text.
        prompt = line["instruction"].replace("\\n", "").strip()
        completion = line["output"].replace("\\n", "").strip()
        # The final assistant reply lives in "output"; concatenate it so every
        # "Human:" turn is paired with an "Assistant:" reply before splitting.
        chats = prompt + completion
        chats = chats.split("Human:")
        input_ids = []
        for chat in chats:
            if chat.strip() == "":
                continue
            res = chat.split("Assistant:")
            # Skip malformed turns that do not contain exactly one reply.
            if len(res) != 2:
                continue
            prompt, completion = res
            prompt = prompt.strip()
            completion = completion.strip()
            chat = "user:{}\nsystem:{}".format(prompt, completion)
            out = tokenizer(chat)
            input_ids.extend(out["input_ids"])
        if len(input_ids) == 0:
            return None
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_belle_multiturn_chat
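
# Worked example of the turn parsing above, on a hypothetical record:
#   {'instruction': 'Human: hi Assistant: hello Human: how are you? Assistant:',
#    'output': 'fine'}
# Concatenation yields '... Assistant:fine'; splitting on "Human:" gives
# ['', ' hi Assistant: hello ', ' how are you? Assistant:fine'], and splitting
# each non-empty piece on "Assistant:" yields the pairs ('hi', 'hello') and
# ('how are you?', 'fine'), each formatted as "user:{prompt}\nsystem:{completion}".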


def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
    def preprocess_sharegpt(line):
        """
        The format of the data is roughly as follows.
        {'conversations': [{'from': 'human', 'value': '...'}, {'from': 'gpt', 'value': '...'}]}
        Split the tokenized text into segments of at most segment_max_length tokens.
        """
        chats = line["conversations"]
        # Drop a leading non-human message so turns pair up as (human, gpt).
        if chats[0]["from"] != "human":
            chats = chats[1:]
        input_ids = []
        for i in range(len(chats) // 2):
            prompt = chats[2 * i]
            completion = chats[2 * i + 1]
            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
                continue
            prompt = prompt["value"].strip()
            completion = completion["value"].strip()
            chat = "user:{}\nsystem:{}".format(prompt, completion)
            out = tokenizer(chat)
            input_ids.extend(out["input_ids"])
        if not input_ids:
            return None
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_sharegpt
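
# Sketch of the pairing logic above, on a hypothetical record:
#   {'conversations': [{'from': 'system', 'value': 's'},
#                      {'from': 'human', 'value': 'q'},
#                      {'from': 'gpt', 'value': 'a'}]}
# The leading non-human message is dropped, then messages are consumed in
# (human, gpt) pairs, producing the single chat "user:q\nsystem:a".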


def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
    def preprocess_instruct_code(line):
        """
        The format of the data is roughly as follows.
        {'instruction': 'some instruction', 'answer': 'some answer'}
        Split the tokenized text into segments of at most segment_max_length tokens.
        """
        # Strip literal "\\n" escape sequences (backslash + n) from the raw text.
        prompt = line["instruction"].replace("\\n", "").strip()
        completion = line["answer"].replace("\\n", "").strip()
        total = "user:{}\nsystem:{}".format(prompt, completion)
        out = tokenizer(total)
        input_ids = out["input_ids"]
        return [
            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
            for i in range(math.ceil(len(input_ids) / segment_max_length))
        ]

    return preprocess_instruct_code


if __name__ == "__main__":
    import sentencepiece as spm

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "self_instruct": preprocess_self_instruction_gen(tokenizer),
        "belle_1M": preprocess_belle_gen(tokenizer),
        "belle_0.5M": preprocess_belle_gen(tokenizer),
        "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
        "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
        "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
        "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
    }
    data_set = DataIter(
        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
    )
    for i, sample in enumerate(data_set):
        print(sp_model.decode(sample))
        if i == 1:
            break
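
# Assuming the repo root is on PYTHONPATH (the imports use the "dataset."
# package prefix), the smoke test above can be run from the project root:
#   python -m dataset.instruction_dataset
# It decodes and prints the first two samples.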