unified pre-training and instruction-tuning both use train_lm and dataset

LiangSong 2023-04-27 19:42:06 +08:00
parent 97aff0e051
commit db6cdb51d0
11 changed files with 225 additions and 443 deletions

View File

@@ -102,8 +102,7 @@ with gr.Blocks() as demo:
context = context.cuda()
pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
pred = pred[:, inputs_len:]
pred = tokenizer.decode(pred.cpu()[0])
pred = pred.strip()
pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
print(pred)
bot_message = parse_codeblock(pred)
history[-1][1] = bot_message
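The decode change above replaces the manual strip() with skip_special_tokens=True, so markers such as <s>, </s> and <pad> no longer leak into the chat reply. A minimal sketch of the difference, assuming the extended tokenizer shipped under configs/ (exact output depends on the tokenizer):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer("configs/llama_tokenizer_extended.model", pad_token="<pad>")
ids = tokenizer("hello world")["input_ids"]
print(tokenizer.decode(ids))                            # may include "<s> hello world"
print(tokenizer.decode(ids, skip_special_tokens=True))  # "hello world"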

View File

@@ -1,7 +1,6 @@
compute_environment: LOCAL_MACHINE
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 12
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
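gradient_accumulation_steps disappears from the accelerate/deepspeed config because train_lm.py (see its diff at the end of this commit) now reads the value from the training YAML and passes it to Accelerator directly. A minimal sketch of that wiring, with the YAML file name assumed from the config contents shown below:

import yaml
from accelerate import Accelerator

with open("configs/pretrain_config.yaml", "r", encoding="utf-8") as fp:  # file name assumed
    config = yaml.load(fp, Loader=yaml.FullLoader)
accelerator = Accelerator(
    gradient_accumulation_steps=config["train"].get("gradient_accumulation_steps", 1)
)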

View File

@@ -0,0 +1,32 @@
data:
mode: "instruct"
data:
mixed: "data/instruction_data/part-*.jsonl.zst"
pad_to_max: False
sequence_sample_mode: "sample"
concat_multiple_sequence: True
num_sequences: 50
seq_length: 2048
tokenizer_model_path: "configs/llama_tokenizer_extended.model"
model:
initializer_range: 1.0e-2
hidden_dropout_prob: 0.1
attention_dropout_prob: 0.1
use_stable_embedding: False
shared_input_output_embedding: False
train:
train_batch_size: 2
num_training_steps: 1000000
num_warmup_steps: 2000
initializer_range: 1.0e-2
lr: 2.0e-4
weight_decay: 1.0e-1
ckpt: "data/llama_raw_ckpt/7B/extended.pth"
train_num_workers: 16
gradient_accumulation_steps: 1
# global step
log_interval: 50
eval_interval: 500
save_interval: 1000
work_dir: "data/saved_ckpt/7B"
project_name: "Llama Instruction"

View File

@@ -1,25 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:37:23
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 40000
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
weight_decay = 1e-1
tokenizer_model_path = "configs/10w_vocab_wudao5_pile10.model"
patterns = ["data/instruction_data/part-*.jsonl.zst"]
# global step
log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/83200.pt"

View File

@@ -2,16 +2,18 @@ data:
mode: "pretrain"
data:
mixed: "data/pretrain_data/part-*.jsonl.zst"
pad_to_max: False
sequence_sample_mode: "sample"
concat_multiple_sequence: True
num_sequences: 10
num_sequences: 20
seq_length: 2048
tokenizer_model_path: "configs/llama_tokenizer_extended.model"
model:
initializer_range: 1.0e-2
hidden_dropout_prob: 0.1
attention_dropout_prob: 0.1
use_stable_embedding: True
shared_input_output_embedding: True
use_stable_embedding: False
shared_input_output_embedding: False
train:
train_batch_size: 2
num_training_steps: 1000000
@@ -19,11 +21,12 @@ train:
initializer_range: 1.0e-2
lr: 2.0e-4
weight_decay: 1.0e-1
ckpt: null
ckpt: "data/llama_raw_ckpt/7B/extended.pth"
train_num_workers: 16
gradient_accumulation_steps: 12
# global step
log_interval: 5
eval_interval: 200
save_interval: 800
work_dir: "data/saved_ckpt/"
work_dir: "data/saved_ckpt/7B"
project_name: "Llama Pretrain"

View File

@@ -14,15 +14,100 @@ from glob import glob
from datasets import load_dataset, interleave_datasets
def pretrain_transform(line):
if "title" in line and "text" not in line:
line["text"] = line["title"] + "\n" + line["content"]
return line
def pretrain_transform(batch):
# wudao preprocess
if "title" in batch and "content" in batch:
assert len(batch["title"]) == 1
batch["text"] = [batch["title"][0] + "\n" + batch["content"][0]]
elif "text" in batch:
pass
else:
raise Exception("Unrecognized pretrain dataset format.")
return batch
def instruct_transform(batch):
# self instruct preprocess
if "prompt" in batch and "completion" in batch:
prompt = batch["prompt"][0]
completion = batch["completion"][0]
if prompt.endswith("Output:"):
prompt = prompt[:-7]
text = "user:{}\nsystem:{}".format(prompt.strip(), completion.strip())
texts = [text]
# belle preprocess
elif "instruction" in batch and "output" in batch:
prompt = batch["instruction"][0].replace("\\n", "")
prompt = prompt.strip("")
completion = batch["output"][0].replace("\\n", "")
completion = completion.strip("")
# multi turn chat
if "Human:" in prompt:
texts = []
chats = prompt + completion
chats = chats.split("Human:")
for chat in chats:
if chat.strip() == "":
continue
res = chat.split("Assistant:")
if len(res) != 2:
continue
prompt, completion = res
prompt = prompt.strip()
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
texts.append(chat)
texts = ["[multiturn_sep]".join(texts)]
else:
text = "user:{}\nsystem:{}".format(prompt, completion)
texts = [text]
# instruct code preprocess
elif "instruction" in batch and "answer" in batch:
prompt = batch["instruction"][0].replace("\\n", "")
prompt = prompt.strip("")
completion = batch["answer"][0].replace("\\n", "")
completion = completion.strip("")
text = "user:{}\nsystem:{}".format(prompt, completion)
texts = [text]
# share gpt preprocess
elif "conversations" in batch:
chats = batch["conversations"][0]
if chats[0]["from"] != "human":
chats = chats[1:]
texts = []
for i in range(len(chats) // 2):
prompt = chats[2 * i]
completion = chats[2 * i + 1]
if not (prompt["from"] == "human" and completion["from"] == "gpt"):
continue
prompt = prompt["value"]
prompt = prompt.strip()
completion = completion["value"]
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
texts.append(chat)
texts = ["[multiturn_sep]".join(texts)]
else:
raise Exception("Unrecognized instruct dataset format.")
return {"text": texts}
def split_multiturn(batch):
return {"text": batch["text"][0].split("[multiturn_sep]")}
def truncation_gen(seq_length):
def truncation(line):
return {"input_ids": line["input_ids"][:seq_length]}
return truncation
def sample_sequence_gen(seq_length, eos_token_id):
def sample_sequence(line):
doc_length = line["input_ids"].shape[1]
doc_length = line["input_ids"].shape[0]
if doc_length <= seq_length:
start = 0
else:
@@ -30,7 +115,7 @@ def sample_sequence_gen(seq_length, eos_token_id):
start = 0
else:
start = random.randint(0, doc_length - seq_length)
input_ids = line["input_ids"][0, start : start + seq_length]
input_ids = line["input_ids"][start : start + seq_length]
if input_ids[-1] != eos_token_id:
input_ids[-1] = eos_token_id
return {"input_ids": input_ids}
@@ -38,18 +123,28 @@ def sample_sequence_gen(seq_length, eos_token_id):
return sample_sequence
def split_sequence_gen(seq_length):
def split_sequence(batch):
input_ids = batch["input_ids"][0]
out = []
while len(input_ids) >= (1 + len(out)) * seq_length:
out.append(input_ids[len(out) * seq_length : (1 + len(out)) * seq_length])
return {"input_ids": out}
return split_sequence
def concat_multiple_sequence_gen(seq_length):
def concat_multiple_sequence(batch):
concat_input_ids = torch.cat(batch["input_ids"], dim=0)
input_ids = []
while len(concat_input_ids) > (1 + len(input_ids)) * seq_length:
while len(concat_input_ids) >= (1 + len(input_ids)) * seq_length:
input_ids.append(
concat_input_ids[
len(input_ids) * seq_length : (1 + len(input_ids)) * seq_length
]
)
out = {"input_ids": input_ids}
return out
return {"input_ids": input_ids}
return concat_multiple_sequence
@@ -75,8 +170,13 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
dataset = load_dataset(
"json", data_files=data_files, split="train", streaming=True
)
# Preprocess the raw text into a unified format
if dataset_config["mode"] == "pretrain":
dataset = dataset.map(pretrain_transform)
dataset = dataset.map(pretrain_transform, batched=True, batch_size=1)
elif dataset_config["mode"] == "instruct":
dataset = dataset.map(instruct_transform, batched=True, batch_size=1)
dataset = dataset.select_columns("text")
dataset = dataset.map(split_multiturn, batched=True, batch_size=1)
else:
raise Exception(
"Dataset mode: {} not found.".format(dataset_config["mode"])
@@ -84,6 +184,7 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
datasets.append(dataset)
probabilities.append(dataset.n_shards)
probabilities_sum = sum(probabilities)
# Sample across the multiple data sources according to these probabilities
probabilities = [p / probabilities_sum for p in probabilities]
if len(datasets) > 1:
full_dataset = interleave_datasets(
@@ -91,26 +192,16 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
)
else:
full_dataset = datasets[0]
# to visualize
if return_raw_text:
return full_dataset
seq_length = dataset_config["seq_length"]
if dataset_config.get("concat_multiple_sequence", False):
num_sequences = dataset_config["num_sequences"]
full_dataset = full_dataset.map(
lambda x: tokenizer(
x["text"], return_tensors="pt", return_attention_mask=False
)
)
full_dataset = full_dataset.map(
sample_sequence_gen(seq_length, tokenizer.eos_token_id)
)
full_dataset = full_dataset.select_columns("input_ids")
full_dataset = full_dataset.map(
concat_multiple_sequence_gen(seq_length),
batched=True,
batch_size=num_sequences,
)
else:
sequence_sample_mode = dataset_config.get("sequence_sample_mode", "truncation")
truncation = sequence_sample_mode == "truncation"
# tokenize
if dataset_config.get("pad_to_max", True):
full_dataset = full_dataset.map(
lambda x: tokenizer(
x["text"],
@@ -118,12 +209,57 @@ def construct_dataset(dataset_config, tokenizer, return_raw_text=False):
return_attention_mask=False,
padding="max_length",
max_length=seq_length,
truncation=True,
truncation=truncation,
)
)
else:
full_dataset = full_dataset.map(
lambda x: tokenizer(
x["text"],
return_tensors="pt",
return_attention_mask=False,
truncation=truncation,
)
)
# format
full_dataset = full_dataset.map(lambda x: {"input_ids": x["input_ids"][0]})
full_dataset = full_dataset.select_columns("input_ids")
# sequence_sample
if sequence_sample_mode == "truncation":
pass
elif sequence_sample_mode == "none":
pass
elif sequence_sample_mode == "sample":
full_dataset = full_dataset.map(
sample_sequence_gen(seq_length, tokenizer.eos_token_id)
)
elif sequence_sample_mode == "split":
assert not dataset_config.get("concat_multiple_sequence", False)
full_dataset = full_dataset.map(
split_sequence_gen(seq_length), batched=True, batch_size=1
)
else:
raise Exception(
"Unknown sequence_sample mode: {}.".format(sequence_sample_mode)
)
# concat multiple sequence
if dataset_config.get("concat_multiple_sequence", False):
num_sequences = dataset_config["num_sequences"]
full_dataset = full_dataset.map(
concat_multiple_sequence_gen(seq_length),
batched=True,
batch_size=num_sequences,
drop_last_batch=True,
)
# add label
full_dataset = full_dataset.map(get_labels_gen(tokenizer.pad_token_id))
# shuffle
full_dataset = full_dataset.shuffle()
return full_dataset
@@ -135,7 +271,9 @@ if __name__ == "__main__":
data_config = {
"mode": "pretrain",
"data": {"wudao": "data/pretrain_data/part-wudao*.jsonl.zst"},
"data": {"mixed": "data/pretrain_data/part-*.jsonl.zst"},
"pad_to_max": False,
"sequence_sample_mode": "sample",
"concat_multiple_sequence": True,
"num_sequences": 10,
"seq_length": 2048,
@@ -153,8 +291,8 @@ if __name__ == "__main__":
# raw_text = normalize("NFKC", raw_text)
input_ids = tokenizer(
line["text"], return_tensors="pt", return_attention_mask=False
)["input_ids"][0, :-1]
decode_text = tokenizer.decode(input_ids)
)["input_ids"][0]
decode_text = tokenizer.decode(input_ids, skip_special_tokens=True)
if raw_text != decode_text and "" not in raw_text:
print(raw_text, "\n", decode_text)
if i == 3000:
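To make the new instruct path concrete, here is a toy walk-through of instruct_transform followed by split_multiturn on a belle-style multi-turn record; the record is made up and the import path dataset.dataset is assumed:

from dataset.dataset import instruct_transform, split_multiturn  # module path assumed

batch = {
    "instruction": ["Human: hi Assistant: hello Human: what is 1+1? Assistant: 2"],
    "output": [""],
}
out = instruct_transform(batch)
# {'text': ['user:hi\nsystem:hello[multiturn_sep]user:what is 1+1?\nsystem:2']}
print(split_multiturn(out))
# {'text': ['user:hi\nsystem:hello', 'user:what is 1+1?\nsystem:2']}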

View File

@@ -1,178 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:33:27
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
def preprocess_self_instruction(line):
"""
The format of the data is roughly as follows.
{'prompt': 'Explain the origin of life on earth. Output:', 'completion': 'Life on Earth is believed to have'}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["prompt"]
if prompt.endswith("Output:"):
prompt = prompt[:-7]
total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total)
input_ids = out["input_ids"]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_self_instruction
def preprocess_belle_gen(tokenizer, segment_max_length=1024):
def preprocess_belle(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["output"].replace("\\n", "")
completion = completion.strip("")
total = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_belle
def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
def preprocess_belle_multiturn_chat(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["output"].replace("\\n", "")
completion = completion.strip("")
chats = prompt + completion
chats = chats.split("Human:")
input_ids = []
for chat in chats:
if chat.strip() == "":
continue
res = chat.split("Assistant:")
if len(res) != 2:
continue
prompt, completion = res
prompt = prompt.strip()
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(chat)
input_ids.extend(out["input_ids"])
if len(input_ids) == 0:
return None
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_belle_multiturn_chat
def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
def preprocess_sharegpt(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
chats = line["conversations"]
if chats[0]["from"] != "human":
chats = chats[1:]
input_ids = []
for i in range(len(chats) // 2):
prompt = chats[2 * i]
completion = chats[2 * i + 1]
if not (prompt["from"] == "human" and completion["from"] == "gpt"):
continue
prompt = prompt["value"]
prompt = prompt.strip()
completion = completion["value"]
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(chat)
input_ids.extend(out["input_ids"])
if input_ids == []:
return None
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_sharegpt
def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
def preprocess_instruct_code(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["answer"].replace("\\n", "")
completion = completion.strip("")
total = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_instruct_code
if __name__ == "__main__":
import sentencepiece as spm
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
"belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
"belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
"instruct_to_code": preprocess_instruct_code_gen(tokenizer),
"sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
data_set = DataIter(
paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
)
for i, sample in enumerate(data_set):
print(sp_model.decode(sample))
if i == 1:
break

View File

@@ -1,189 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-06 03:35:31
FilePath: /Open-Llama/inctruction_tuning.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import time
import wandb
import torch
import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import (
OpenLlamaForCausalLM,
OpenLlamaConfig,
get_cosine_schedule_with_warmup,
)
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.instruction_dataset import (
preprocess_belle_gen,
preprocess_self_instruction_gen,
preprocess_belle_multiturn_chat_gen,
preprocess_instruct_code_gen,
preprocess_sharegpt_gen,
)
from configs.instruction_tuning_config import *
accelerator = Accelerator()
if accelerator.is_main_process:
wandb.init(project="LLAMA Instruction")
log_interval *= accelerator.gradient_accumulation_steps
eval_interval *= accelerator.gradient_accumulation_steps
save_interval *= accelerator.gradient_accumulation_steps
sp_model = spm.SentencePieceProcessor(model_file=tokenizer_model_path)
tokenizer = Tokenizer(sp_model)
paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
"self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
"belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
"belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
"instruct_to_code": preprocess_instruct_code_gen(tokenizer),
"sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
data_set = DataIter(
paths,
transform_dict=transform_dict,
concat_docs=True,
max_length=max_length,
process_index=accelerator.process_index,
num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
collate_fn=collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# a smaller initializer_range makes training more stable
# add stable embedding to the token embedding
raw_model = OpenLlamaForCausalLM(
OpenLlamaConfig(
vocab_size=tokenizer.vocab_size,
initializer_range=initializer_range,
pad_token_id=tokenizer.pad_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
)
)
ckpt = torch.load(ckpt_path, map_location="cpu")
raw_model.load_state_dict(ckpt)
raw_model.eval()
with torch.no_grad():
summary(raw_model.cuda(), input_data=torch.ones(1, 64, dtype=torch.int64).cuda())
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in raw_model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": weight_decay,
},
{
"params": [
p
for n, p in raw_model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
optim = FusedAdam(optimizer_grouped_parameters, lr=lr, betas=(0.9, 0.95))
optim.zero_grad()
factor = accelerator.num_processes / accelerator.gradient_accumulation_steps
scheduler = get_cosine_schedule_with_warmup(
optim,
num_warmup_steps=num_warmup_steps * factor,
num_training_steps=num_training_steps * factor,
)
_, model, optim, scheduler = accelerator.prepare(
train_loader, raw_model, optim, scheduler
)
print("start training...")
train_loader_iter = iter(train_loader)
global_step = 0
start_time = time.time()
for data_step in range(num_training_steps):
model.train()
with accelerator.accumulate(model):
batch = next(train_loader_iter)
for k, v in batch.items():
batch[k] = v.to(accelerator.device, non_blocking=True)
out = model(**batch, labels=batch["input_ids"])
total_loss = out.loss
losses = {"total_loss": total_loss}
accelerator.backward(total_loss)
optim.step()
scheduler.step()
optim.zero_grad()
if accelerator.sync_gradients:
global_step += 1
if data_step % log_interval == 0 and data_step > 0 and accelerator.is_main_process:
cost_time = time.time() - start_time
start_time = time.time()
tokens = train_batch_size * log_interval * max_length
wandb.log({"Training/Token per second per gpu": tokens / cost_time})
for k, v in losses.items():
wandb.log({"Losses/{}".format(k): v})
current_lr = optim.param_groups[0]["lr"]
wandb.log({"Training/LR": current_lr})
if optim.scaler is not None:
wandb.log({"Training/Loss Scale": optim.scaler.get_scale()})
wandb.log({"Training/Data Step": data_step})
wandb.log({"Training/Global Step": global_step})
accelerator.print(
"Global Step: {}, Data Step: {}, Loss: {}, Token per second per gpu: {}".format(
global_step, data_step, losses["total_loss"], tokens / cost_time
)
)
if data_step % eval_interval == 0 and accelerator.is_main_process:
text_table = wandb.Table(columns=["question", "pred"])
model.eval()
with torch.no_grad():
for data in val_set:
raw_inputs = data
inputs_len = len(raw_inputs)
inputs = tokenizer(
raw_inputs, return_tensors=True, add_special_tokens=False
)
for k, v in inputs.items():
inputs[k] = v.to(accelerator.device)
pred = model.generate(
**inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
)
pred = tokenizer.decode(pred.cpu())[0]
pred = pred[inputs_len:]
text_table.add_data(raw_inputs, pred)
wandb.log({"Predictions on {}".format(global_step): text_table})
if data_step % save_interval == 0 and data_step > 0 and accelerator.is_main_process:
if not os.path.isdir(work_dir):
os.mkdir(work_dir)
torch.save(raw_model.state_dict(), "{}/{}.pt".format(work_dir, global_step))
wandb.finish()

View File

@@ -191,6 +191,6 @@ class Trainer:
**inputs, max_new_tokens=256, do_sample=True, repetition_penalty=2.0
)
pred = pred[0, input_length:]
pred = self.tokenizer.decode(pred.cpu())
pred = self.tokenizer.decode(pred.cpu(), skip_special_tokens=True)
text_table.add_data(raw_inputs, pred)
wandb.log({"Predictions on {}".format(self.global_step): text_table})

View File

@@ -8,7 +8,6 @@ Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import os
import yaml
import torch
from absl import app
@@ -26,10 +25,14 @@ flags.DEFINE_string("config", None, "Training config path")
def main(argv):
accelerator = Accelerator()
with open(FLAGS.config, "r", encoding="utf-8") as fp:
config = yaml.load(fp, Loader=yaml.FullLoader)
accelerator = Accelerator(
gradient_accumulation_steps=config["train"].get(
"gradient_accumulation_steps", 1
)
)
tokenizer = LlamaTokenizer(
config["data"]["tokenizer_model_path"],
pad_token="<pad>",
@@ -37,14 +40,14 @@ def main(argv):
add_eos_token=True,
)
data_config = config["data"]
pretrain_dataset = construct_dataset(data_config, tokenizer)
pretrain_dataset = split_dataset_by_node(
pretrain_dataset,
rank=int(os.environ["RANK"]),
world_size=int(os.environ["WORLD_SIZE"]),
train_dataset = construct_dataset(data_config, tokenizer)
train_dataset = split_dataset_by_node(
train_dataset,
rank=accelerator.process_index,
world_size=accelerator.num_processes,
)
train_loader = DataLoader(
pretrain_dataset,
train_dataset,
batch_size=config["train"]["train_batch_size"],
num_workers=config["train"]["train_num_workers"],
)
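The last hunk also swaps the RANK/WORLD_SIZE environment variables for accelerator.process_index and accelerator.num_processes, so sharding no longer depends on the launcher exporting those variables and a plain single-process run still sees the whole dataset. A small sanity-check sketch on a toy in-memory dataset rather than the real streaming one:

from accelerate import Accelerator
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

accelerator = Accelerator()
toy = Dataset.from_dict({"input_ids": [[i] for i in range(8)]})
shard = split_dataset_by_node(
    toy, rank=accelerator.process_index, world_size=accelerator.num_processes
)
print(accelerator.process_index, len(shard))  # with a single process, rank 0 sees all 8 samples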