# Source: Open-Llama/dataset/dataset.py (308 lines, 10 KiB, Python)

"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-24 20:05:21
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-24 20:05:59
FilePath: /Open-Llama/dataset/dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
import random
from glob import glob
from datasets import load_dataset, interleave_datasets
def pretrain_transform(batch):
    """Normalize a batch of raw pretraining records to a ``text`` column.

    Args:
        batch: Column-oriented batch (column name -> list of values). Two
            layouts are recognized: ``title`` + ``content`` columns (the
            wudao layout) or an already present ``text`` column.

    Returns:
        The same ``batch`` mapping. For title/content input a ``text`` column
        is added in place, holding ``title + "\\n" + content`` for every row.

    Raises:
        Exception: If neither layout is present.
    """
    # wudao preprocess
    if "title" in batch and "content" in batch:
        # Build the text row by row so any batch size works, instead of
        # asserting that the batch holds exactly one record.
        batch["text"] = [
            title + "\n" + content
            for title, content in zip(batch["title"], batch["content"])
        ]
    elif "text" not in batch:
        raise Exception("Unrecognized pretrain dataset format.")
    return batch
def _format_turn(prompt, completion):
    """Render one prompt/completion pair as ``user:<prompt>\\nsystem:<completion>``."""
    return "user:{}\nsystem:{}".format(prompt, completion)


def instruct_transform(batch):
    """Convert one instruction-tuning record into the unified ``text`` format.

    Only the first row of each column is read, so the function is meant to be
    mapped with a batch size of 1. Four input layouts are recognized:

    * ``prompt`` + ``completion`` (self instruct); a trailing ``"Output:"`` is
      removed from the prompt.
    * ``instruction`` + ``output`` (belle); a prompt containing ``"Human:"`` is
      treated as a multi-turn chat and split on ``"Human:"`` / ``"Assistant:"``.
    * ``instruction`` + ``answer`` (instruct code).
    * ``conversations`` (share gpt): a list of ``{"from", "value"}`` messages,
      paired up as (human, gpt).

    Args:
        batch: Column-oriented batch (column name -> list of values).

    Returns:
        ``{"text": [joined]}`` where ``joined`` holds every formatted turn of
        the record, separated by ``"[multiturn_sep]"`` when there are several.

    Raises:
        Exception: If the batch matches none of the layouts above.
    """
    # self instruct preprocess
    if "prompt" in batch and "completion" in batch:
        prompt = batch["prompt"][0]
        completion = batch["completion"][0]
        if prompt.endswith("Output:"):
            prompt = prompt[: -len("Output:")]
        texts = [_format_turn(prompt.strip(), completion.strip())]
    # belle preprocess
    elif "instruction" in batch and "output" in batch:
        # Remove literal backslash-n sequences (two characters), not newlines.
        prompt = batch["instruction"][0].replace("\\n", "")
        completion = batch["output"][0].replace("\\n", "")
        # multi turn chat
        if "Human:" in prompt:
            turns = []
            # Concatenate the unstripped halves so the boundary between the
            # prompt and the final answer is left untouched; each turn is
            # stripped individually below.
            for chat in (prompt + completion).split("Human:"):
                if chat.strip() == "":
                    continue
                parts = chat.split("Assistant:")
                if len(parts) != 2:
                    # Skip segments that are not exactly one question/answer.
                    continue
                turns.append(_format_turn(parts[0].strip(), parts[1].strip()))
            texts = ["[multiturn_sep]".join(turns)]
        else:
            # ``str.strip("")`` strips nothing; strip whitespace as the other
            # branches do.
            texts = [_format_turn(prompt.strip(), completion.strip())]
    # instruct code preprocess
    elif "instruction" in batch and "answer" in batch:
        prompt = batch["instruction"][0].replace("\\n", "").strip()
        completion = batch["answer"][0].replace("\\n", "").strip()
        texts = [_format_turn(prompt, completion)]
    # share gpt preprocess
    elif "conversations" in batch:
        chats = batch["conversations"][0]
        # Drop a leading non-human message so messages pair up as
        # (human, gpt); the emptiness check avoids an IndexError on an empty
        # conversation.
        if chats and chats[0]["from"] != "human":
            chats = chats[1:]
        turns = []
        for question, answer in zip(chats[0::2], chats[1::2]):
            if not (question["from"] == "human" and answer["from"] == "gpt"):
                continue
            turns.append(
                _format_turn(question["value"].strip(), answer["value"].strip())
            )
        texts = ["[multiturn_sep]".join(turns)]
    else:
        raise Exception("Unrecognized instruct dataset format.")
    return {"text": texts}
def split_multiturn(batch):
    """Expand the first ``text`` entry of ``batch`` into one row per turn.

    The entry is split on the ``"[multiturn_sep]"`` marker; a text without
    the marker yields a single row.

    Args:
        batch: Column-oriented batch whose ``text`` column holds the joined
            conversation in its first position.

    Returns:
        ``{"text": [...]}`` with one string per turn.
    """
    joined = batch["text"][0]
    turns = joined.split("[multiturn_sep]")
    return {"text": turns}
def truncation_gen(seq_length):
    """Build a per-example mapper that keeps the first ``seq_length`` ids.

    Args:
        seq_length: Maximum number of leading ``input_ids`` entries to keep.

    Returns:
        A function taking an example with an ``input_ids`` entry and
        returning ``{"input_ids": <leading slice>}``.
    """
    head = slice(None, seq_length)

    def truncation(line):
        kept = line["input_ids"][head]
        return {"input_ids": kept}

    return truncation
def sample_sequence_gen(seq_length, eos_token_id):
    """Build a per-example mapper that samples a window of token ids.

    Args:
        seq_length: Maximum length of the sampled window.
        eos_token_id: Id written into the last position of the window when
            the window does not already end with it.

    Returns:
        A function taking an example whose ``input_ids`` exposes ``.shape``
        and ``.clone()`` (assumed to be a 1-D torch tensor) and returning
        ``{"input_ids": window}``. The example's own tensor is never modified.
    """

    def sample_sequence(line):
        tokens = line["input_ids"]
        doc_length = tokens.shape[0]
        # Documents that fit are kept whole. Longer ones start at the
        # beginning a quarter of the time and at a random offset otherwise.
        # ``or`` short-circuits, so no random number is drawn for short docs.
        if doc_length <= seq_length or random.random() < 1 / 4:
            start = 0
        else:
            start = random.randint(0, doc_length - seq_length)
        window = tokens[start : start + seq_length]
        # Guard the empty case: indexing [-1] on an empty window would raise.
        if window.shape[0] > 0 and window[-1] != eos_token_id:
            # A tensor slice shares storage with its source; clone first so
            # writing the EOS marker does not alter the caller's data.
            window = window.clone()
            window[-1] = eos_token_id
        return {"input_ids": window}

    return sample_sequence
def split_sequence_gen(seq_length):
    """Build a batched mapper that cuts one sequence into full-length chunks.

    Args:
        seq_length: Length of every emitted chunk.

    Returns:
        A function that reads the first ``input_ids`` entry of a batch and
        returns ``{"input_ids": [chunk, ...]}``. Chunks are consecutive and
        non-overlapping; a trailing remainder shorter than ``seq_length`` is
        dropped.
    """

    def split_sequence(batch):
        tokens = batch["input_ids"][0]
        total = len(tokens)
        chunks = []
        offset = 0
        # Advance one full chunk at a time while a complete chunk still fits.
        while offset + seq_length <= total:
            chunks.append(tokens[offset : offset + seq_length])
            offset += seq_length
        return {"input_ids": chunks}

    return split_sequence
def concat_multiple_sequence_gen(seq_length):
    """Build a batched mapper that packs several sequences into fixed chunks.

    Args:
        seq_length: Length of every emitted chunk.

    Returns:
        A function that concatenates all ``input_ids`` tensors of a batch
        along dimension 0 and returns ``{"input_ids": [chunk, ...]}`` with
        consecutive chunks of ``seq_length`` ids. Ids left over after the
        last full chunk are dropped.
    """

    def concat_multiple_sequence(batch):
        flat = torch.cat(batch["input_ids"], dim=0)
        # Number of complete chunks available in the concatenated stream.
        full_chunks = len(flat) // seq_length
        chunks = [
            flat[idx * seq_length : (idx + 1) * seq_length]
            for idx in range(full_chunks)
        ]
        return {"input_ids": chunks}

    return concat_multiple_sequence
def get_labels_gen(pad_token_id):
    """Build a per-example mapper that derives ``labels`` from ``input_ids``.

    Args:
        pad_token_id: Token id whose positions are replaced by ``-100`` in
            the labels.

    Returns:
        A function returning ``{"labels": labels}``, where ``labels`` is a
        clone of the example's ``input_ids`` with every ``pad_token_id``
        position set to ``-100``. ``input_ids`` itself is left unchanged.
    """

    def get_labels(line):
        labels = line["input_ids"].clone()
        is_pad = labels == pad_token_id
        labels[is_pad] = -100
        return {"labels": labels}

    return get_labels
def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
    """Build the streaming dataset described by ``dataset_config``.

    Steps: load the JSON files, normalize records to a ``text`` column,
    tokenize, apply the sequence sampling mode, optionally concatenate
    several sequences, add ``labels`` and shuffle.

    Args:
        dataset_config: Mapping with the keys read here:
            ``data`` (name -> glob pattern; exactly one entry is accepted),
            ``mode`` (``"pretrain"`` or ``"instruct"``),
            ``seq_length``,
            ``sequence_sample_mode`` (``"truncation"`` (default), ``"none"``,
            ``"sample"`` or ``"split"``),
            ``pad_to_max`` (default ``True``),
            ``concat_multiple_sequence`` (default ``False``) and, when that
            is enabled, ``num_sequences``.
        tokenizer: Callable tokenizer accepting the keyword arguments used
            below and exposing ``eos_token_id`` and ``pad_token_id``.
        return_raw_text: When true, return the dataset right after text
            normalization, before any tokenization.

    Returns:
        The dataset produced by the chained ``map`` / ``select_columns`` /
        ``shuffle`` calls, or the text-only dataset if ``return_raw_text``.

    Raises:
        AssertionError: If ``data`` does not hold exactly one entry, a glob
            pattern matches no file, or ``"split"`` mode is combined with
            ``concat_multiple_sequence``.
        Exception: For an unknown ``mode`` or ``sequence_sample_mode``.
    """
    datasets = []
    probabilities = []
    # Only a single data source is used for now: with several sources,
    # multi-process reading is not possible, which makes loading slow.
    assert len(dataset_config["data"]) == 1
    for pattern in dataset_config["data"].values():
        data_files = glob(pattern)
        assert len(data_files) > 0
        dataset = load_dataset(
            "json", data_files=data_files, split="train", streaming=True
        )
        # Text preprocessing: convert every source to the unified format.
        if dataset_config["mode"] == "pretrain":
            dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
        elif dataset_config["mode"] == "instruct":
            dataset = dataset.map(instruct_transform, batched=True, batch_size=1)
            dataset = dataset.select_columns("text")
            dataset = dataset.map(split_multiturn, batched=True, batch_size=1)
        else:
            raise Exception(
                "Dataset mode: {} not found.".format(dataset_config["mode"])
            )
        datasets.append(dataset)
        probabilities.append(dataset.n_shards)
    probabilities_sum = sum(probabilities)
    # Sample the data parts with probability proportional to shard count.
    probabilities = [p / probabilities_sum for p in probabilities]
    if len(datasets) > 1:
        full_dataset = interleave_datasets(
            datasets, probabilities=probabilities, seed=42
        )
    else:
        full_dataset = datasets[0]

    # to visualize
    if return_raw_text:
        return full_dataset

    seq_length = dataset_config["seq_length"]
    sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
    truncation = sequence_sample_mode == "truncation"
    pad_to_max = dataset_config.get("pad_to_max", True)

    # tokenize
    tokenize_kwargs = {
        "return_tensors": "pt",
        "return_attention_mask": False,
        "truncation": truncation,
    }
    if pad_to_max:
        tokenize_kwargs["padding"] = "max_length"
    if pad_to_max or truncation:
        # ``truncation=True`` only cuts to ``seq_length`` when ``max_length``
        # is supplied, so pass it in the unpadded truncation case as well.
        # It is deliberately omitted when neither padding nor truncation is
        # requested, to avoid turning truncation on implicitly.
        tokenize_kwargs["max_length"] = seq_length

    def tokenize(example):
        return tokenizer(example["text"], **tokenize_kwargs)

    full_dataset = full_dataset.map(tokenize)

    # format: keep the first row of the tokenizer output as the id sequence
    full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
    full_dataset = full_dataset.select_columns("input_ids")

    # sequence_sample
    if sequence_sample_mode in ("truncation", "none"):
        # Truncation already happened in the tokenizer; "none" keeps the
        # sequences as they are.
        pass
    elif sequence_sample_mode == "sample":
        full_dataset = full_dataset.map(
            sample_sequence_gen(seq_length, tokenizer.eos_token_id)
        )
    elif sequence_sample_mode == "split":
        assert not dataset_config.get("concat_multiple_sequence", False)
        full_dataset = full_dataset.map(
            split_sequence_gen(seq_length), batched=True, batch_size=1
        )
    else:
        raise Exception(
            "Unknown sequence_sample mode: {}.".format(sequence_sample_mode)
        )

    # concat multiple sequence
    if dataset_config.get("concat_multiple_sequence", False):
        num_sequences = dataset_config["num_sequences"]
        full_dataset = full_dataset.map(
            concat_multiple_sequence_gen(seq_length),
            batched=True,
            batch_size=num_sequences,
            drop_last_batch=True,
        )

    # add label
    full_dataset = full_dataset.map(get_labels_gen(tokenizer.pad_token_id))
    # shuffle
    full_dataset = full_dataset.shuffle()
    return full_dataset
if __name__ == "__main__":
    # Manual smoke test: (1) check that tokenizing and decoding raw text
    # round-trips, (2) build the full pipeline and print one loader batch.
    import time

    from torch.utils.data import DataLoader
    from transformers import LlamaTokenizer

    data_config = {
        "mode": "pretrain",
        "data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
        "pad_to_max": False,
        "sequence_sample_mode": "sample",
        "concat_multiple_sequence": True,
        "num_sequences": 10,
        "seq_length": 2048,
    }
    tokenizer = LlamaTokenizer(
        "configs/llama_tokenizer_extended.model",
        pad_token="<pad>",
        add_bos_token=False,
        add_eos_token=True,
    )

    # Pass 1: raw text only (return_raw_text=True), checked on the first
    # 3001 records.
    pretrain_dataset = construct_dataset(data_config, tokenizer, True)
    start = time.time()
    for i, line in enumerate(pretrain_dataset):
        raw_text = line["text"]
        input_ids = tokenizer(
            raw_text, return_tensors="pt", return_attention_mask=False
        )["input_ids"][0]
        decode_text = tokenizer.decode(input_ids, skip_special_tokens=True)
        # NOTE(review): the literal tested here was the empty string, which is
        # contained in every string and made this report unreachable.
        # Presumably the SentencePiece word-boundary marker U+2581 was lost
        # from the literal -- confirm against the upstream file.
        if raw_text != decode_text and "\u2581" not in raw_text:
            print(raw_text, "\n", decode_text)
        if i == 3000:
            break
    print("all checked in {} seconds.".format(time.time() - start))

    # Pass 2: full tokenized pipeline; print the first batch and stop.
    pretrain_dataset = construct_dataset(data_config, tokenizer)
    print(pretrain_dataset.n_shards)
    pretrain_loader = DataLoader(pretrain_dataset, batch_size=2, num_workers=16)
    for batch in pretrain_loader:
        for k, v in batch.items():
            print(k, v.shape, "\n", v)
        break