Fix long-sequence bug: split tokenized input_ids into chunks of at most segment_max_length tokens

This commit is contained in:
LiangSong 2023-03-31 10:12:28 +08:00
parent a62ac2658f
commit b9bc7eaf35
3 changed files with 12 additions and 5 deletions

View File

@ -22,4 +22,4 @@ log_interval = 50
eval_interval = 500 eval_interval = 500
save_interval = 1000 save_interval = 1000
work_dir = "data/saved_ckpt/" work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/30000.pt" ckpt_path = "data/saved_ckpt/40000.pt"

View File

@ -21,11 +21,11 @@ def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=
past = None past = None
for i, path in paths: for i, path in paths:
dataset_name = path.split("-")[-2] dataset_name = path.split("-")[-2]
if num_processes > 1 and i % num_processes != process_index:
continue
if past != dataset_name: if past != dataset_name:
print("Loading data from {}".format(path)) print("Loading data from {}".format(path))
past = path past = path
if num_processes > 1 and i % num_processes != process_index:
continue
if path.endswith("jsonl.zst"): if path.endswith("jsonl.zst"):
with zstd.open(path, "r", encoding="utf-8") as fp: with zstd.open(path, "r", encoding="utf-8") as fp:
for line in fp: for line in fp:

View File

@ -8,6 +8,7 @@ Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
""" """
import math
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024): def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
@ -23,7 +24,10 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip()) total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total) out = tokenizer(total)
input_ids = out["input_ids"] input_ids = out["input_ids"]
return [input_ids] return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_self_instruction return preprocess_self_instruction
@ -43,7 +47,10 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
total = "user:{}<s>system:{}".format(prompt, completion) total = "user:{}<s>system:{}".format(prompt, completion)
out = tokenizer(total) out = tokenizer(total)
input_ids = out["input_ids"] input_ids = out["input_ids"]
return [input_ids] return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_belle return preprocess_belle