update header config and add padding to concat_multiple_sequence
parent db6cdb51d0
commit 49118aad42

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:58:23
+LastEditTime: 2023-04-27 20:34:58
 FilePath: /Open-Llama/chat_server.py
 Description:
 

@@ -3,7 +3,7 @@ data:
   data:
     mixed: "data/instruction_data/part-*.jsonl.zst"
   pad_to_max: False
-  sequence_sample_mode: "sample"
+  sequence_sample_mode: "none"
   concat_multiple_sequence: True
   num_sequences: 50
   seq_length: 2048

@@ -24,6 +24,7 @@ train:
   ckpt: "data/llama_raw_ckpt/7B/extended.pth"
   train_num_workers: 16
   gradient_accumulation_steps: 1
+  prefetch_factor: 100
   # global step
   log_interval: 50
   eval_interval: 500

@@ -3,9 +3,9 @@ data:
   data:
     mixed: "data/pretrain_data/part-*.jsonl.zst"
   pad_to_max: False
-  sequence_sample_mode: "sample"
+  sequence_sample_mode: "none"
   concat_multiple_sequence: True
-  num_sequences: 20
+  num_sequences: 10
   seq_length: 2048
   tokenizer_model_path: "configs/llama_tokenizer_extended.model"
 model:

@@ -24,6 +24,7 @@ train:
   ckpt: "data/llama_raw_ckpt/7B/extended.pth"
   train_num_workers: 16
   gradient_accumulation_steps: 12
+  prefetch_factor: 100
   # global step
   log_interval: 5
   eval_interval: 200

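Both config hunks above (the instruction-tuning data config and the pretraining data config) switch sequence_sample_mode to "none", keep concat_multiple_sequence enabled, and add a prefetch_factor under the train section. A minimal sketch of how the data section is consumed, mirroring the construct_dataset changes further down in this diff (the dict literal is illustrative, filled with the pretraining values):

# Illustrative only: a plain dict standing in for the parsed "data" section of the YAML.
dataset_config = {
    "mixed": "data/pretrain_data/part-*.jsonl.zst",
    "pad_to_max": False,
    "sequence_sample_mode": "none",
    "concat_multiple_sequence": True,
    "num_sequences": 10,
    "seq_length": 2048,
}

# Derived flags, as computed in construct_dataset below.
pad_to_max = dataset_config.get("pad_to_max", True)
sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
truncation = sequence_sample_mode == "truncation"
concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False)
print(pad_to_max, sequence_sample_mode, truncation, concat_multiple_sequence)
# False none False True
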
@@ -2,12 +2,13 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:05:59
+LastEditTime: 2023-04-27 22:19:37
 FilePath: /Open-Llama/dataset/dataset.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
+import math
 import torch
 import random
 from glob import glob

@@ -98,13 +99,6 @@ def split_multiturn(batch):
     return {"text": batch["text"][0].split("[multiturn_sep]")}
 
 
-def truncation_gen(seq_length):
-    def truncation(line):
-        return {"input_ids": line["input_ids"][:seq_length]}
-
-    return truncation
-
-
 def sample_sequence_gen(seq_length, eos_token_id):
     def sample_sequence(line):
         doc_length = line["input_ids"].shape[0]

@@ -134,16 +128,15 @@ def split_sequence_gen(seq_length):
     return split_sequence
 
 
-def concat_multiple_sequence_gen(seq_length):
+def concat_multiple_sequence_gen(seq_length, pad_token_id):
     def concat_multiple_sequence(batch):
         concat_input_ids = torch.cat(batch["input_ids"], dim=0)
-        input_ids = []
-        while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
-            input_ids.append(
-                concat_input_ids[
-                    len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
-                ]
-            )
+        length = concat_input_ids.shape[0]
+        chunks = math.ceil(length / seq_length)
+        pad_length = chunks * seq_length - length
+        pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
+        concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
+        input_ids = torch.chunk(concat_input_ids, chunks)
         return {"input_ids": input_ids}
 
     return concat_multiple_sequence

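The rewritten packer above pads the concatenated token stream up to a multiple of seq_length before chunking, so the tail of each group is kept instead of being dropped by the old while-loop slicing. A minimal, self-contained sketch of that behaviour on toy values (seq_length, pad_token_id, and the input tensors are illustrative, not taken from the repo):

import math
import torch


def concat_multiple_sequence_gen(seq_length, pad_token_id):
    def concat_multiple_sequence(batch):
        # concatenate all documents in the batch into one long stream
        concat_input_ids = torch.cat(batch["input_ids"], dim=0)
        length = concat_input_ids.shape[0]
        chunks = math.ceil(length / seq_length)
        # pad the tail so every chunk comes out exactly seq_length long
        pad_length = chunks * seq_length - length
        pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
        concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
        input_ids = torch.chunk(concat_input_ids, chunks)
        return {"input_ids": input_ids}

    return concat_multiple_sequence


pack = concat_multiple_sequence_gen(seq_length=8, pad_token_id=0)
batch = {"input_ids": [torch.arange(1, 6), torch.arange(6, 19)]}  # 5 + 13 = 18 tokens
out = pack(batch)
print([chunk.tolist() for chunk in out["input_ids"]])
# ceil(18 / 8) = 3 chunks of 8 tokens; the last chunk ends with six pad tokens (0)
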
@@ -170,6 +163,8 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
     dataset = load_dataset(
         "json", data_files=data_files, split="train", streaming=True
     )
+    # shuffle
+    dataset = dataset.shuffle()
     # preprocess the raw text into a unified format
     if dataset_config["mode"] == "pretrain":
         dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)

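Because the dataset is loaded with streaming=True, the added dataset.shuffle() is the datasets library's buffer-based approximate shuffle rather than a full in-memory permutation. A hedged sketch of the same call with its optional knobs spelled out (the path, seed, and buffer size are illustrative; the diff itself uses the defaults):

from datasets import load_dataset

ds = load_dataset("json", data_files="data/*.jsonl.zst", split="train", streaming=True)
# shuffles the shard order and draws examples from a fixed-size buffer
ds = ds.shuffle(seed=42, buffer_size=1000)
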
@@ -198,10 +193,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         return full_dataset
 
     seq_length = dataset_config["seq_length"]
+    pad_to_max = dataset_config.get("pad_to_max", True)
     sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
     truncation = sequence_sample_mode == "truncation"
+    concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False)
     # tokenize
-    if dataset_config.get("pad_to_max", True):
+    if pad_to_max:
         full_dataset = full_dataset.map(
             lambda x: tokenizer(
                 x["text"],

@@ -232,11 +229,12 @@
     elif sequence_sample_mode == "none":
         pass
     elif sequence_sample_mode == "sample":
+        assert pad_to_max or concat_multiple_sequence
         full_dataset = full_dataset.map(
             sample_sequence_gen(seq_length, tokenizer.eos_token_id)
         )
     elif sequence_sample_mode == "split":
-        assert not dataset_config.get("concat_multiple_sequence", False)
+        assert not concat_multiple_sequence
         full_dataset = full_dataset.map(
             split_sequence_gen(seq_length), batched=True, batch_size=1
         )

@@ -246,10 +244,10 @@
         )
 
     # concat multiple sequence
-    if dataset_config.get("concat_multiple_sequence", False):
+    if concat_multiple_sequence:
         num_sequences = dataset_config["num_sequences"]
         full_dataset = full_dataset.map(
-            concat_multiple_sequence_gen(seq_length),
+            concat_multiple_sequence_gen(seq_length, tokenizer.pad_token_id),
             batched=True,
             batch_size=num_sequences,
             drop_last_batch=True,

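In the map call above, num_sequences is used as the batch size, so that many tokenized documents are handed to the packer at a time, and drop_last_batch=True discards a final incomplete group. A rough, hedged emulation of those semantics in plain Python, reusing concat_multiple_sequence_gen from the sketch earlier on this page (all values are illustrative):

import torch

docs = [torch.arange(n) for n in (5, 13, 9, 7, 4)]  # stand-ins for tokenized documents
num_sequences, seq_length, pad_token_id = 2, 8, 0
pack = concat_multiple_sequence_gen(seq_length, pad_token_id)  # from the earlier sketch

packed = []
# group documents num_sequences at a time; the trailing incomplete group is dropped,
# mirroring drop_last_batch=True in the real map call
for i in range(0, len(docs) - len(docs) % num_sequences, num_sequences):
    out = pack({"input_ids": docs[i : i + num_sequences]})
    packed.extend(out["input_ids"])
print(len(packed), [len(x) for x in packed])  # 5 [8, 8, 8, 8, 8]
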
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:06:55
+LastEditTime: 2023-04-27 20:34:47
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 

@@ -2,8 +2,8 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:05:47
-FilePath: /Open-Llama/pretrain.py
+LastEditTime: 2023-04-27 23:08:47
+FilePath: /Open-Llama/train_lm.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.

@@ -50,6 +50,8 @@ def main(argv):
         train_dataset,
         batch_size=config["train"]["train_batch_size"],
         num_workers=config["train"]["train_num_workers"],
+        prefetch_factor=config["train"].get("prefetch_factor", 2),
+        pin_memory=True,
     )
     # a smaller initializer_range makes training more stable
     # add stable embedding to token embedding

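The two new DataLoader arguments are read from the train config (prefetch_factor falls back to 2 when the key is absent, and the configs above set it to 100). A small, hedged sketch of their effect with a toy in-memory dataset (names and sizes are illustrative; prefetch_factor only takes effect when num_workers > 0):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    data = TensorDataset(torch.randint(0, 100, (64, 16)))  # 64 toy "sequences" of length 16
    loader = DataLoader(
        data,
        batch_size=2,
        num_workers=2,        # worker processes; needed for prefetch_factor to matter
        prefetch_factor=100,  # number of batches each worker loads in advance
        pin_memory=True,      # page-locked host memory speeds up copies to the GPU
    )
    for (batch,) in loader:
        print(batch.shape)    # torch.Size([2, 16])
        break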