update header config and add padding to concat_multiple_sequence

This commit is contained in:
LiangSong 2023-04-27 23:42:11 +08:00
parent db6cdb51d0
commit 49118aad42
6 changed files with 29 additions and 27 deletions

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-26 23:58:23
LastEditTime: 2023-04-27 20:34:58
FilePath: /Open-Llama/chat_server.py
Description:

View File

@ -3,7 +3,7 @@ data:
data:
mixed: "data/instruction_data/part-*.jsonl.zst"
pad_to_max: False
sequence_sample_mode: "sample"
sequence_sample_mode: "none"
concat_multiple_sequence: True
num_sequences: 50
seq_length: 2048
@ -24,6 +24,7 @@ train:
ckpt: "data/llama_raw_ckpt/7B/extended.pth"
train_num_workers: 16
gradient_accumulation_steps: 1
prefetch_factor: 100
# global step
log_interval: 50
eval_interval: 500

View File

@ -3,9 +3,9 @@ data:
data:
mixed: "data/pretrain_data/part-*.jsonl.zst"
pad_to_max: False
sequence_sample_mode: "sample"
sequence_sample_mode: "none"
concat_multiple_sequence: True
num_sequences: 20
num_sequences: 10
seq_length: 2048
tokenizer_model_path: "configs/llama_tokenizer_extended.model"
model:
@ -24,6 +24,7 @@ train:
ckpt: "data/llama_raw_ckpt/7B/extended.pth"
train_num_workers: 16
gradient_accumulation_steps: 12
prefetch_factor: 100
# global step
log_interval: 5
eval_interval: 200

View File

@ -2,12 +2,13 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-24 20:05:21
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-24 20:05:59
LastEditTime: 2023-04-27 22:19:37
FilePath: /Open-Llama/dataset/dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch
import random
from glob import glob
@ -98,13 +99,6 @@ def split_multiturn(batch):
return {"text": batch["text"][0].split("[multiturn_sep]")}
def truncation_gen(seq_length):
def truncation(line):
return {"input_ids": line["input_ids"][:seq_length]}
return truncation
def sample_sequence_gen(seq_length, eos_token_id):
def sample_sequence(line):
doc_length = line["input_ids"].shape[0]
@ -134,16 +128,15 @@ def split_sequence_gen(seq_length):
return split_sequence
def concat_multiple_sequence_gen(seq_length):
def concat_multiple_sequence_gen(seq_length, pad_token_id):
def concat_multiple_sequence(batch):
concat_input_ids = torch.cat(batch["input_ids"], dim=0)
input_ids = []
while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
input_ids.append(
concat_input_ids[
len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
]
)
length = concat_input_ids.shape[0]
chunks = math.ceil(length / seq_length)
pad_length = chunks * seq_length - length
pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
input_ids = torch.chunk(concat_input_ids, chunks)
return {"input_ids": input_ids}
return concat_multiple_sequence
@ -170,6 +163,8 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
dataset = load_dataset(
"json", data_files=data_files, split="train", streaming=True
)
# shuffle
dataset = dataset.shuffle()
# 文本预处理转换为统一格式
if dataset_config["mode"] == "pretrain":
dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
@ -198,10 +193,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
return full_dataset
seq_length = dataset_config["seq_length"]
pad_to_max = dataset_config.get("pad_to_max", True)
sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
truncation = sequence_sample_mode == "truncation"
concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False)
# tokenize
if dataset_config.get("pad_to_max", True):
if pad_to_max:
full_dataset = full_dataset.map(
lambda x: tokenizer(
x["text"],
@ -232,11 +229,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
elif sequence_sample_mode == "none":
pass
elif sequence_sample_mode == "sample":
assert pad_to_max or concat_multiple_sequence
full_dataset = full_dataset.map(
sample_sequence_gen(seq_length, tokenizer.eos_token_id)
)
elif sequence_sample_mode == "split":
assert not dataset_config.get("concat_multiple_sequence", False)
assert not concat_multiple_sequence
full_dataset = full_dataset.map(
split_sequence_gen(seq_length), batched=True, batch_size=1
)
@ -246,10 +244,10 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
)
# concat multiple sequence
if dataset_config.get("concat_multiple_sequence", False):
if concat_multiple_sequence:
num_sequences = dataset_config["num_sequences"]
full_dataset = full_dataset.map(
concat_multiple_sequence_gen(seq_length),
concat_multiple_sequence_gen(seq_length, tokenizer.pad_token_id),
batched=True,
batch_size=num_sequences,
drop_last_batch=True,

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-24 20:05:21
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-26 23:06:55
LastEditTime: 2023-04-27 20:34:47
FilePath: /Open-Llama/solver/trainer.py
Description:

View File

@ -2,8 +2,8 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-26 23:05:47
FilePath: /Open-Llama/pretrain.py
LastEditTime: 2023-04-27 23:08:47
FilePath: /Open-Llama/train_lm.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
@ -50,6 +50,8 @@ def main(argv):
train_dataset,
batch_size=config["train"]["train_batch_size"],
num_workers=config["train"]["train_num_workers"],
prefetch_factor=config["train"].get("prefetch_factor", 2),
pin_memory=True,
)
# smaller initializer_range make training more stable
# add stabel embedding to token embedding