add xP3 dataset and belle_2M
This commit is contained in:
parent
00cbdbbf26
commit
85caa97a6a
|
@ -26,6 +26,7 @@ train:
|
||||||
train_num_workers: 16
|
train_num_workers: 16
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
prefetch_factor: 100
|
prefetch_factor: 100
|
||||||
|
train_and_eval: False
|
||||||
# global step
|
# global step
|
||||||
log_interval: 50
|
log_interval: 50
|
||||||
eval_interval: 500
|
eval_interval: 500
|
||||||
|
|
|
@ -28,6 +28,7 @@ train:
|
||||||
train_num_workers: 16
|
train_num_workers: 16
|
||||||
gradient_accumulation_steps: 12
|
gradient_accumulation_steps: 12
|
||||||
prefetch_factor: 100
|
prefetch_factor: 100
|
||||||
|
train_and_eval: True
|
||||||
# global step
|
# global step
|
||||||
log_interval: 5
|
log_interval: 5
|
||||||
eval_interval: 500
|
eval_interval: 500
|
||||||
|
|
|
@ -9,6 +9,7 @@ Description:
|
||||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
import zstandard as zstd
|
import zstandard as zstd
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
@ -20,7 +21,7 @@ write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -41,7 +42,7 @@ write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -62,7 +63,7 @@ write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -78,12 +79,33 @@ print(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset("BelleGroup/train_2M_CN")
|
||||||
|
write_path = root_dir + "/instruction_data/part-belle_2M-{}.jsonl.zst"
|
||||||
|
total_num = 0
|
||||||
|
file_num = 1
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
for line in tqdm(dataset["train"]):
|
||||||
|
line = json.dumps(line)
|
||||||
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
|
file_num += 1
|
||||||
|
wfp.close()
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
wfp.write(line.encode("utf-8"))
|
||||||
|
wfp.write(b"\n")
|
||||||
|
total_num += 1
|
||||||
|
wfp.close()
|
||||||
|
print(
|
||||||
|
"BelleGroup/train_2M_CN preprocess done. Total line: {}, Total file: {}".format(
|
||||||
|
total_num, file_num
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
dataset = load_dataset("BelleGroup/school_math_0.25M")
|
dataset = load_dataset("BelleGroup/school_math_0.25M")
|
||||||
write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
|
write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -104,7 +126,7 @@ write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jso
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -125,7 +147,7 @@ write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
for line in dataset["train"]:
|
for line in tqdm(dataset["train"]):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
@ -141,6 +163,69 @@ print(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# dataset = load_dataset("bigscience/xP3mt", "en")
|
||||||
|
# write_path = root_dir + "/instruction_data/part-bigscience/xP3mt_en-{}.jsonl.zst"
|
||||||
|
# total_num = 0
|
||||||
|
# file_num = 1
|
||||||
|
# wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
# for line in tqdm(dataset["train"]):
|
||||||
|
# line = json.dumps(line)
|
||||||
|
# if total_num % 1024 == 0 and total_num > 0:
|
||||||
|
# file_num += 1
|
||||||
|
# wfp.close()
|
||||||
|
# wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
# wfp.write(line.encode("utf-8"))
|
||||||
|
# wfp.write(b"\n")
|
||||||
|
# total_num += 1
|
||||||
|
# wfp.close()
|
||||||
|
# print(
|
||||||
|
# "bigscience/xP3mt_en preprocess done. Total line: {}, Total file: {}".format(
|
||||||
|
# total_num, file_num
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
dataset = load_dataset("bigscience/xP3mt", "code")
|
||||||
|
write_path = root_dir + "/instruction_data/part-xP3mt_code-{}.jsonl.zst"
|
||||||
|
total_num = 0
|
||||||
|
file_num = 1
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
for line in tqdm(dataset["train"]):
|
||||||
|
line = json.dumps(line)
|
||||||
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
|
file_num += 1
|
||||||
|
wfp.close()
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
wfp.write(line.encode("utf-8"))
|
||||||
|
wfp.write(b"\n")
|
||||||
|
total_num += 1
|
||||||
|
wfp.close()
|
||||||
|
print(
|
||||||
|
"bigscience/xP3mt_code preprocess done. Total line: {}, Total file: {}".format(
|
||||||
|
total_num, file_num
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset("bigscience/xP3mt", "zh")
|
||||||
|
write_path = root_dir + "/instruction_data/part-xP3mt_zh-{}.jsonl.zst"
|
||||||
|
total_num = 0
|
||||||
|
file_num = 1
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
for line in tqdm(dataset["train"]):
|
||||||
|
line = json.dumps(line)
|
||||||
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
|
file_num += 1
|
||||||
|
wfp.close()
|
||||||
|
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
|
||||||
|
wfp.write(line.encode("utf-8"))
|
||||||
|
wfp.write(b"\n")
|
||||||
|
total_num += 1
|
||||||
|
wfp.close()
|
||||||
|
print(
|
||||||
|
"bigscience/xP3mt_zh preprocess done. Total line: {}, Total file: {}".format(
|
||||||
|
total_num, file_num
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
|
write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
|
||||||
total_num = 0
|
total_num = 0
|
||||||
file_num = 1
|
file_num = 1
|
||||||
|
@ -150,7 +235,7 @@ with open("{}/sg_90k_part1_html_cleaned.json".format(root_dir), "r") as fp:
|
||||||
with open("{}/sg_90k_part2_html_cleaned.json".format(root_dir), "r") as fp:
|
with open("{}/sg_90k_part2_html_cleaned.json".format(root_dir), "r") as fp:
|
||||||
data2 = json.load(fp)
|
data2 = json.load(fp)
|
||||||
data = data1 + data2
|
data = data1 + data2
|
||||||
for line in data:
|
for line in tqdm(data):
|
||||||
line = json.dumps(line)
|
line = json.dumps(line)
|
||||||
if total_num % 1024 == 0 and total_num > 0:
|
if total_num % 1024 == 0 and total_num > 0:
|
||||||
file_num += 1
|
file_num += 1
|
||||||
|
|
|
@ -93,6 +93,12 @@ def instruct_transform(batch):
|
||||||
chat = "user:{}\nsystem:{}".format(prompt, completion)
|
chat = "user:{}\nsystem:{}".format(prompt, completion)
|
||||||
texts.append(chat)
|
texts.append(chat)
|
||||||
texts = ["[multiturn_sep]".join(texts)]
|
texts = ["[multiturn_sep]".join(texts)]
|
||||||
|
# xP3 preprocess
|
||||||
|
elif "inputs" in batch and "targets" in batch:
|
||||||
|
inputs = batch["inputs"][0]
|
||||||
|
targets = batch["targets"][0]
|
||||||
|
text = "user:{}\nsystem:{}".format(inputs.strip(), targets.strip())
|
||||||
|
texts = [text]
|
||||||
else:
|
else:
|
||||||
raise Exception("Unrecognized instruct dataset format.")
|
raise Exception("Unrecognized instruct dataset format.")
|
||||||
return {"text": texts}
|
return {"text": texts}
|
||||||
|
|
|
@ -26,6 +26,7 @@ class Trainer:
|
||||||
self.train_loader = train_loader
|
self.train_loader = train_loader
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.accelerator = accelerator
|
self.accelerator = accelerator
|
||||||
|
self.train_and_eval = config.get("train_and_eval", False)
|
||||||
self.gradient_accumulation_steps = config["train"].get(
|
self.gradient_accumulation_steps = config["train"].get(
|
||||||
"gradient_accumulation_steps", 1
|
"gradient_accumulation_steps", 1
|
||||||
)
|
)
|
||||||
|
@ -164,6 +165,7 @@ class Trainer:
|
||||||
if (
|
if (
|
||||||
self.data_step % self.eval_interval == 0
|
self.data_step % self.eval_interval == 0
|
||||||
and self.accelerator.is_main_process
|
and self.accelerator.is_main_process
|
||||||
|
and self.train_and_eval
|
||||||
):
|
):
|
||||||
self.eval()
|
self.eval()
|
||||||
# save state
|
# save state
|
||||||
|
@ -189,8 +191,10 @@ class Trainer:
|
||||||
wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()})
|
wandb.log({"Training/Loss Scale": self.optim.scaler.get_scale()})
|
||||||
wandb.log({"Training/Data Step": self.data_step})
|
wandb.log({"Training/Data Step": self.data_step})
|
||||||
wandb.log({"Training/Global Step": self.global_step})
|
wandb.log({"Training/Global Step": self.global_step})
|
||||||
|
wandb.log({"Training/Epoch": self.epoch})
|
||||||
self.accelerator.print(
|
self.accelerator.print(
|
||||||
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
|
"Epoch: {}, Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
|
||||||
|
self.epoch,
|
||||||
self.global_step,
|
self.global_step,
|
||||||
self.data_step,
|
self.data_step,
|
||||||
losses["total_loss"],
|
losses["total_loss"],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user