update header config and add padding to concat_multiple_sequence
commit 49118aad42
parent db6cdb51d0
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:58:23
+LastEditTime: 2023-04-27 20:34:58
 FilePath: /Open-Llama/chat_server.py
 Description:

@@ -3,7 +3,7 @@ data:
   data:
     mixed: "data/instruction_data/part-*.jsonl.zst"
   pad_to_max: False
-  sequence_sample_mode: "sample"
+  sequence_sample_mode: "none"
   concat_multiple_sequence: True
   num_sequences: 50
   seq_length: 2048
@@ -24,6 +24,7 @@ train:
   ckpt: "data/llama_raw_ckpt/7B/extended.pth"
   train_num_workers: 16
   gradient_accumulation_steps: 1
+  prefetch_factor: 100
   # global step
   log_interval: 50
   eval_interval: 500
@@ -3,9 +3,9 @@ data:
   data:
     mixed: "data/pretrain_data/part-*.jsonl.zst"
   pad_to_max: False
-  sequence_sample_mode: "sample"
+  sequence_sample_mode: "none"
   concat_multiple_sequence: True
-  num_sequences: 20
+  num_sequences: 10
   seq_length: 2048
   tokenizer_model_path: "configs/llama_tokenizer_extended.model"
 model:
@@ -24,6 +24,7 @@ train:
   ckpt: "data/llama_raw_ckpt/7B/extended.pth"
   train_num_workers: 16
   gradient_accumulation_steps: 12
+  prefetch_factor: 100
   # global step
   log_interval: 5
   eval_interval: 200
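Taken together, the two YAML hunks above switch both the instruction-tuning and the pretraining pipelines from sequence_sample_mode "sample" to "none" (each tokenized document is passed through unchanged and the fixed sequence length comes from concatenation and chunking instead, see the dataset.py hunks below), lower num_sequences for pretraining from 20 to 10, and add a prefetch_factor read by the training script. A minimal sketch of the resulting data section as the Python dict that construct_dataset(dataset_config, tokenizer) receives; the "mode" key and the exact nesting are inferred from dataset.py and the YAML, not spelled out in this diff:

# Sketch only: pretraining data settings after this commit, written as a dict.
pretrain_data_config = {
    "mode": "pretrain",                  # inferred from dataset_config["mode"] in dataset.py
    "data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
    "pad_to_max": False,
    "sequence_sample_mode": "none",      # was "sample"
    "concat_multiple_sequence": True,
    "num_sequences": 10,                 # was 20
    "seq_length": 2048,
    "tokenizer_model_path": "configs/llama_tokenizer_extended.model",
}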
@@ -2,12 +2,13 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:05:59
+LastEditTime: 2023-04-27 22:19:37
 FilePath: /Open-Llama/dataset/dataset.py
 Description:

 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
+import math
 import torch
 import random
 from glob import glob
@@ -98,13 +99,6 @@ def split_multiturn(batch):
     return {"text": batch["text"][0].split("[multiturn_sep]")}


-def truncation_gen(seq_length):
-    def truncation(line):
-        return {"input_ids": line["input_ids"][:seq_length]}
-
-    return truncation
-
-
 def sample_sequence_gen(seq_length, eos_token_id):
     def sample_sequence(line):
         doc_length = line["input_ids"].shape[0]
@@ -134,16 +128,15 @@ def split_sequence_gen(seq_length):
     return split_sequence


-def concat_multiple_sequence_gen(seq_length):
+def concat_multiple_sequence_gen(seq_length, pad_token_id):
     def concat_multiple_sequence(batch):
         concat_input_ids = torch.cat(batch["input_ids"], dim=0)
-        input_ids = []
-        while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
-            input_ids.append(
-                concat_input_ids[
-                    len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
-                ]
-            )
+        length = concat_input_ids.shape[0]
+        chunks = math.ceil(length / seq_length)
+        pad_length = chunks * seq_length - length
+        pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
+        concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
+        input_ids = torch.chunk(concat_input_ids, chunks)
         return {"input_ids": input_ids}

     return concat_multiple_sequence
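The rewritten concat_multiple_sequence no longer drops the tail that does not fill a whole seq_length window; it pads the concatenated ids up to a multiple of seq_length with pad_token_id and then chunks them. A small self-contained check of the same arithmetic (the tensor values are made up; pad_token_id = 0 is an assumption, the real value comes from tokenizer.pad_token_id):

import math

import torch

seq_length = 2048
pad_token_id = 0  # assumed for the example; the repo passes tokenizer.pad_token_id
concat_input_ids = torch.randint(1, 32000, (5000,))  # e.g. a batch concatenated to 5000 tokens

length = concat_input_ids.shape[0]            # 5000
chunks = math.ceil(length / seq_length)       # ceil(5000 / 2048) = 3
pad_length = chunks * seq_length - length     # 3 * 2048 - 5000 = 1144
pad = torch.ones(pad_length, dtype=concat_input_ids.dtype) * pad_token_id
concat_input_ids = torch.cat([concat_input_ids, pad], dim=0)
input_ids = torch.chunk(concat_input_ids, chunks)

assert all(seq.shape[0] == seq_length for seq in input_ids)  # 3 sequences of 2048 tokens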
@@ -170,6 +163,8 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         dataset = load_dataset(
             "json", data_files=data_files, split="train", streaming=True
         )
+        # shuffle
+        dataset = dataset.shuffle()
         # Preprocess the text into a unified format
         if dataset_config["mode"] == "pretrain":
             dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
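The added dataset.shuffle() runs on a streaming dataset (load_dataset(..., streaming=True) returns an IterableDataset), so it is a buffered, approximate shuffle rather than a global permutation. A standalone illustration assuming the Hugging Face datasets defaults (buffer_size=1000); the seed and the single glob pattern are only for the example:

from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files="data/pretrain_data/part-*.jsonl.zst",
    split="train",
    streaming=True,
)
# Buffered shuffle: shard order is randomized and examples are drawn from a
# fixed-size buffer, so the result is approximate by design.
ds = ds.shuffle(seed=42, buffer_size=1000)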
@@ -198,10 +193,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         return full_dataset

     seq_length = dataset_config["seq_length"]
+    pad_to_max = dataset_config.get("pad_to_max", True)
     sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
     truncation = sequence_sample_mode == "truncation"
+    concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False)
     # tokenize
-    if dataset_config.get("pad_to_max", True):
+    if pad_to_max:
         full_dataset = full_dataset.map(
             lambda x: tokenizer(
                 x["text"],
@@ -232,11 +229,12 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
     elif sequence_sample_mode == "none":
         pass
     elif sequence_sample_mode == "sample":
+        assert pad_to_max or concat_multiple_sequence
         full_dataset = full_dataset.map(
             sample_sequence_gen(seq_length, tokenizer.eos_token_id)
         )
     elif sequence_sample_mode == "split":
-        assert not dataset_config.get("concat_multiple_sequence", False)
+        assert not concat_multiple_sequence
         full_dataset = full_dataset.map(
             split_sequence_gen(seq_length), batched=True, batch_size=1
         )
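The new assert makes "sample" mode require either pad_to_max or concat_multiple_sequence, presumably because sampling alone does not guarantee fixed-length rows, while "split" still refuses to be combined with concat_multiple_sequence. A hypothetical standalone check (not in the repo) that mirrors these constraints using the same keys and defaults as construct_dataset:

def validate_sequence_config(dataset_config):
    # Mirrors the assertions added in construct_dataset; the key names and
    # defaults match the diff, the helper itself is illustrative only.
    pad_to_max = dataset_config.get("pad_to_max", True)
    sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
    concat_multiple_sequence = dataset_config.get("concat_multiple_sequence", False)
    if sequence_sample_mode == "sample":
        assert pad_to_max or concat_multiple_sequence
    elif sequence_sample_mode == "split":
        assert not concat_multiple_sequence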
@@ -246,10 +244,10 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
         )

     # concat multiple sequence
-    if dataset_config.get("concat_multiple_sequence", False):
+    if concat_multiple_sequence:
         num_sequences = dataset_config["num_sequences"]
         full_dataset = full_dataset.map(
-            concat_multiple_sequence_gen(seq_length),
+            concat_multiple_sequence_gen(seq_length, tokenizer.pad_token_id),
             batched=True,
             batch_size=num_sequences,
             drop_last_batch=True,
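Because concat_multiple_sequence_gen is now called with tokenizer.pad_token_id, the tokenizer has to define a pad token; with a pad_token_id of None, the torch.ones(...) * pad_token_id line would raise a TypeError. A common fallback pattern for Llama-style tokenizers, shown here only as a hypothetical safeguard, not something this commit adds:

# Hypothetical safeguard, not part of this commit: reuse the EOS token as the
# padding token when the tokenizer does not define one.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token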
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:06:55
+LastEditTime: 2023-04-27 20:34:47
 FilePath: /Open-Llama/solver/trainer.py
 Description:

@@ -2,8 +2,8 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-26 23:05:47
-FilePath: /Open-Llama/pretrain.py
+LastEditTime: 2023-04-27 23:08:47
+FilePath: /Open-Llama/train_lm.py
 Description:

 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
@@ -50,6 +50,8 @@ def main(argv):
         train_dataset,
         batch_size=config["train"]["train_batch_size"],
         num_workers=config["train"]["train_num_workers"],
+        prefetch_factor=config["train"].get("prefetch_factor", 2),
+        pin_memory=True,
     )
     # smaller initializer_range make training more stable
     # add stabel embedding to token embedding
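The training script now forwards prefetch_factor from the config (falling back to 2, PyTorch's own default) and pins host memory. A minimal sketch of the resulting call, assuming the surrounding call is torch.utils.data.DataLoader and that config and train_dataset are set up as elsewhere in train_lm.py; note that prefetch_factor only applies when num_workers > 0, which holds here with train_num_workers: 16:

from torch.utils.data import DataLoader

train_loader = DataLoader(  # variable name assumed; the hunk only shows the arguments
    train_dataset,
    batch_size=config["train"]["train_batch_size"],
    num_workers=config["train"]["train_num_workers"],
    prefetch_factor=config["train"].get("prefetch_factor", 2),  # 100 in the updated configs
    pin_memory=True,
)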