fix long seq bug
This commit is contained in:
parent
a62ac2658f
commit
b9bc7eaf35
|
@ -22,4 +22,4 @@ log_interval = 50
|
||||||
eval_interval = 500
|
eval_interval = 500
|
||||||
save_interval = 1000
|
save_interval = 1000
|
||||||
work_dir = "data/saved_ckpt/"
|
work_dir = "data/saved_ckpt/"
|
||||||
ckpt_path = "data/saved_ckpt/30000.pt"
|
ckpt_path = "data/saved_ckpt/40000.pt"
|
||||||
|
|
|
@ -21,11 +21,11 @@ def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=
|
||||||
past = None
|
past = None
|
||||||
for i, path in paths:
|
for i, path in paths:
|
||||||
dataset_name = path.split("-")[-2]
|
dataset_name = path.split("-")[-2]
|
||||||
|
if num_processes > 1 and i % num_processes != process_index:
|
||||||
|
continue
|
||||||
if past != dataset_name:
|
if past != dataset_name:
|
||||||
print("Loading data from {}".format(path))
|
print("Loading data from {}".format(path))
|
||||||
past = path
|
past = path
|
||||||
if num_processes > 1 and i % num_processes != process_index:
|
|
||||||
continue
|
|
||||||
if path.endswith("jsonl.zst"):
|
if path.endswith("jsonl.zst"):
|
||||||
with zstd.open(path, "r", encoding="utf-8") as fp:
|
with zstd.open(path, "r", encoding="utf-8") as fp:
|
||||||
for line in fp:
|
for line in fp:
|
||||||
|
|
|
@ -8,6 +8,7 @@ Description:
|
||||||
|
|
||||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
"""
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
|
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
|
||||||
|
@ -23,7 +24,10 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
|
||||||
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
|
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
|
||||||
out = tokenizer(total)
|
out = tokenizer(total)
|
||||||
input_ids = out["input_ids"]
|
input_ids = out["input_ids"]
|
||||||
return [input_ids]
|
return [
|
||||||
|
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
|
||||||
|
for i in range(math.ceil(len(input_ids) / segment_max_length))
|
||||||
|
]
|
||||||
|
|
||||||
return preprocess_self_instruction
|
return preprocess_self_instruction
|
||||||
|
|
||||||
|
@ -43,7 +47,10 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
|
||||||
total = "user:{}<s>system:{}".format(prompt, completion)
|
total = "user:{}<s>system:{}".format(prompt, completion)
|
||||||
out = tokenizer(total)
|
out = tokenizer(total)
|
||||||
input_ids = out["input_ids"]
|
input_ids = out["input_ids"]
|
||||||
return [input_ids]
|
return [
|
||||||
|
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
|
||||||
|
for i in range(math.ceil(len(input_ids) / segment_max_length))
|
||||||
|
]
|
||||||
|
|
||||||
return preprocess_belle
|
return preprocess_belle
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user