Fix long-sequence bug: split tokenized input_ids into segments of at most segment_max_length tokens

This commit is contained in:
LiangSong 2023-03-31 10:12:28 +08:00
parent a62ac2658f
commit b9bc7eaf35
3 changed files with 12 additions and 5 deletions

View File

@ -22,4 +22,4 @@ log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/30000.pt"
ckpt_path = "data/saved_ckpt/40000.pt"

View File

@ -21,11 +21,11 @@ def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=
past = None
for i, path in paths:
dataset_name = path.split("-")[-2]
if num_processes > 1 and i % num_processes != process_index:
continue
if past != dataset_name:
print("Loading data from {}".format(path))
past = path
if num_processes > 1 and i % num_processes != process_index:
continue
if path.endswith("jsonl.zst"):
with zstd.open(path, "r", encoding="utf-8") as fp:
for line in fp:

View File

@ -8,6 +8,7 @@ Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
@ -23,7 +24,10 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total)
input_ids = out["input_ids"]
return [input_ids]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_self_instruction
@ -43,7 +47,10 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
total = "user:{}<s>system:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [input_ids]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_belle