Add xP3 datasets and belle_2M

This commit is contained in:
LiangSong 2023-05-05 17:05:41 +08:00
parent 00cbdbbf26
commit 85caa97a6a
5 changed files with 105 additions and 8 deletions

View File

@ -26,6 +26,7 @@ train:
train_num_workers: 16
gradient_accumulation_steps: 1
prefetch_factor: 100
train_and_eval: False
# global step
log_interval: 50
eval_interval: 500

View File

@ -28,6 +28,7 @@ train:
train_num_workers: 16
gradient_accumulation_steps: 12
prefetch_factor: 100
train_and_eval: True
# global step
log_interval: 5
eval_interval: 500

View File

@ -9,6 +9,7 @@ Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import json
from tqdm import tqdm
import zstandard as zstd
from datasets import load_dataset
@ -20,7 +21,7 @@ write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -41,7 +42,7 @@ write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -62,7 +63,7 @@ write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -78,12 +79,33 @@ print(
)
)
# Export BelleGroup/train_2M_CN as zstd-compressed jsonl shards of 1024 records each.
dataset = load_dataset("BelleGroup/train_2M_CN")
write_path = root_dir + "/instruction_data/part-belle_2M-{}.jsonl.zst"
total_num = 0
file_num = 1
# NOTE(review): "wb" is binary mode, so the encoding argument presumably has no
# effect here — confirm against the zstandard.open API.
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in tqdm(dataset["train"]):
    # Rotate to the next shard once the current one holds 1024 records.
    if total_num and total_num % 1024 == 0:
        wfp.close()
        file_num += 1
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    line = json.dumps(line)
    wfp.write(line.encode("utf-8") + b"\n")
    total_num += 1
wfp.close()
print(
    "BelleGroup/train_2M_CN preprocess done. Total line: {}, Total file: {}".format(
        total_num, file_num
    )
)
dataset = load_dataset("BelleGroup/school_math_0.25M")
write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -104,7 +126,7 @@ write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jso
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -125,7 +147,7 @@ write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
total_num = 0
file_num = 1
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
for line in tqdm(dataset["train"]):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
@ -141,6 +163,69 @@ print(
)
)
# dataset = load_dataset("bigscience/xP3mt", "en")
# write_path = root_dir + "/instruction_data/part-bigscience/xP3mt_en-{}.jsonl.zst"
# total_num = 0
# file_num = 1
# wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
# for line in tqdm(dataset["train"]):
# line = json.dumps(line)
# if total_num % 1024 == 0 and total_num > 0:
# file_num += 1
# wfp.close()
# wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
# wfp.write(line.encode("utf-8"))
# wfp.write(b"\n")
# total_num += 1
# wfp.close()
# print(
# "bigscience/xP3mt_en preprocess done. Total line: {}, Total file: {}".format(
# total_num, file_num
# )
# )
# Export the "code" split of bigscience/xP3mt as zstd-compressed jsonl shards
# of 1024 records each.
dataset = load_dataset("bigscience/xP3mt", "code")
write_path = root_dir + "/instruction_data/part-xP3mt_code-{}.jsonl.zst"
total_num = 0
file_num = 1
# NOTE(review): "wb" is binary mode, so the encoding argument presumably has no
# effect here — confirm against the zstandard.open API.
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in tqdm(dataset["train"]):
    # Rotate to the next shard once the current one holds 1024 records.
    if total_num and total_num % 1024 == 0:
        wfp.close()
        file_num += 1
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    line = json.dumps(line)
    wfp.write(line.encode("utf-8") + b"\n")
    total_num += 1
wfp.close()
print(
    "bigscience/xP3mt_code preprocess done. Total line: {}, Total file: {}".format(
        total_num, file_num
    )
)
# Export the "zh" split of bigscience/xP3mt as zstd-compressed jsonl shards
# of 1024 records each.
dataset = load_dataset("bigscience/xP3mt", "zh")
write_path = root_dir + "/instruction_data/part-xP3mt_zh-{}.jsonl.zst"
total_num = 0
file_num = 1
# NOTE(review): "wb" is binary mode, so the encoding argument presumably has no
# effect here — confirm against the zstandard.open API.
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in tqdm(dataset["train"]):
    # Rotate to the next shard once the current one holds 1024 records.
    if total_num and total_num % 1024 == 0:
        wfp.close()
        file_num += 1
        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    line = json.dumps(line)
    wfp.write(line.encode("utf-8") + b"\n")
    total_num += 1
wfp.close()
print(
    "bigscience/xP3mt_zh preprocess done. Total line: {}, Total file: {}".format(
        total_num, file_num
    )
)
write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
total_num = 0
file_num = 1
@ -150,7 +235,7 @@ with open("{}/sg_90k_part1_html_cleaned.json".format(root_dir), "r") as fp:
with open("{}/sg_90k_part2_html_cleaned.json".format(root_dir), "r") as fp:
data2 = json.load(fp)
data = data1 + data2
for line in data:
for line in tqdm(data):
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1

View File

@ -93,6 +93,12 @@ def instruct_transform(batch):
chat = "user:{}\nsystem:{}".format(prompt, completion)
texts.append(chat)
texts = ["[multiturn_sep]".join(texts)]
# xP3 preprocess
elif "inputs" in batch and "targets" in batch:
inputs = batch["inputs"][0]
targets = batch["targets"][0]
text = "user:{}\nsystem:{}".format(inputs.strip(), targets.strip())
texts = [text]
else:
raise Exception("Unrecognized instruct dataset format.")
return {"text": texts}

View File

@ -26,6 +26,7 @@ class Trainer:
self.train_loader = train_loader
self.tokenizer = tokenizer
self.accelerator = accelerator
self.train_and_eval = config.get("train_and_eval", False)
self.gradient_accumulation_steps = config["train"].get(
"gradient_accumulation_steps", 1
)
@ -164,6 +165,7 @@ class Trainer:
if (
self.data_step % self.eval_interval == 0
and self.accelerator.is_main_process
and self.train_and_eval
):
self.eval()
# save state
@ -189,8 +191,10 @@ class Trainer:
wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()})
wandb.log({"Training/Data Step": self.data_step})
wandb.log({"Training/Global Step": self.global_step})
wandb.log({"Training/Epoch": self.epoch})
self.accelerator.print(
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
"Epoch: {}, Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
self.epoch,
self.global_step,
self.data_step,
losses["total_loss"],