Merge pull request #1 from Bayes-Song/dev

update instruct-tuning
This commit is contained in:
S 2023-04-07 23:21:06 +08:00 committed by GitHub
commit 56f71e24df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 580 additions and 314 deletions

README.md

@ -2,7 +2,7 @@
* @Author: LiangSong(sl12160010@gmail.com)
* @Date: 2023-03-10 21:18:35
* @LastEditors: LiangSong(sl12160010@gmail.com)
* @LastEditTime: 2023-04-02 21:32:26
* @LastEditTime: 2023-04-07 23:19:21
* @FilePath: /Open-Llama/README.md
* @Description:
*
@ -16,7 +16,8 @@ Open-Llama是一个开源项目提供了一整套用于构建大型语言模
## Progress
Although the full pre-training run has not finished, we first ran Instruction-tuning on the model pre-trained for 40K steps; it can follow simple commands but does not yet support multi-turn dialogue.
We have completed pre-training on 300B tokens, 80K steps in total, with a Global Batch Size of 4M, the same as in Llama.
The Instruction-tuning data is drawn from seven sources in total; the model now shows some coding, math, and multi-turn dialogue ability. See the Instruction-Tuning section for details on the data.
[Demo](http://home.ustc.edu.cn/~sl9292/)
@ -25,6 +26,9 @@ Open-Llama是一个开源项目提供了一整套用于构建大型语言模
The model's results are shown in the images below; more results await further testing. Due to network issues within China, requests to the Demo above may be dropped; if there is no response for a long time, refresh and try again.
![image1](assets/image1.png)![image2](assets/image2.png)![image3](assets/image3.png)
Below is a demonstration of multi-turn dialogue about code:
![image4](assets/multiturn_chat.jpeg)
A rough estimate of the cost of reaching the results above: pre-training ran for 40K steps over 150M pre-training samples, roughly 110B tokens, taking 76 hours in total, which comes to about $19,152 at Google Cloud's A100 pricing. The subsequent Instruction-tuning ran for 12K steps over 1.6M samples, taking 3.4 hours and costing about $342. Training such a model from scratch therefore costs less than $20,000 in total.
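As a rough sanity check on these totals (a sketch using only the numbers quoted above; the implied hourly rate of the A100 cluster is derived from them, not taken from a price list):

```python
# Back-of-the-envelope check of the quoted training cost (all figures from the paragraph above).
pretrain_cost_usd, pretrain_hours = 19152, 76    # 40K steps, ~110B tokens
instruct_cost_usd, instruct_hours = 342, 3.4     # 12K steps, 1.6M samples

implied_rate = pretrain_cost_usd / pretrain_hours   # ~252 USD per hour for the whole cluster
total_cost = pretrain_cost_usd + instruct_cost_usd  # 19494 USD, i.e. under 20,000 USD
print(f"implied cluster rate: {implied_rate:.0f} USD/h, total cost: {total_cost} USD")
```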
The model currently performs noticeably worse on math and code. This is partly due to the training data and, I believe, partly due to the model size; since this kind of logical reasoning is essential for a usable model, later updates will focus on improving these abilities.
@ -166,12 +170,17 @@ Total mult-adds (G): 6.89
We use the open-source datasets listed below for Instruction-tuning; more tasks, as well as datasets we construct ourselves, will be added later.
- [yizhongw/self_instruct](https://huggingface.co/datasets/yizhongw/self_instruct)
- [BelleGroup/generated_train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_0.5M_CN)
- [BelleGroup/generated_train_1M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_1M_CN)
- [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [BelleGroup/school_math_0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
- [Graverman/Instruct-to-Code](https://huggingface.co/datasets/Graverman/Instruct-to-Code)
The ShareGPT52K data has some processing issues in its `datasets` version, so we downloaded the raw data and re-processed it ourselves.
We apply some preprocessing to the raw data; the format is as follows:
```
user: {prompt}<s>system: {completion}</s>
user: {prompt}\nsystem: {completion}</s>
```
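A minimal illustration of this template (a sketch only: `format_sample` is not a function from this repo, and in the real pipeline the trailing `</s>` is appended by the tokenizer as the EOS token rather than as literal text):

```python
# Illustrative sketch of the instruction-tuning sample format shown above.
def format_sample(prompt: str, completion: str) -> str:
    # one training sample: user turn, newline, system turn, EOS marker
    return "user: {}\nsystem: {}</s>".format(prompt.strip(), completion.strip())

print(format_sample("用一句话介绍一下Llama", "Llama是Meta开源的大型语言模型。"))
```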
The training code is basically the same as the pre-training code; the code can be found at
@ -195,7 +204,12 @@ accelerate launch --config_file configs/default_config.yaml instruction_tuning.p
The loss during training is shown below; it mostly fluctuates and does not decrease much.
![loss](assets/instruct_loss.png)
### RLHF
Not yet available.
### Server
Use server.py for single-turn dialogue and chat_server.py for multi-turn dialogue.
Built with Gradio.
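For multi-turn dialogue the same template is simply concatenated round by round, with the latest round left open after `system:` so the model continues it; chat_server.py (shown later in this diff) assembles its context this way. A sketch of the convention (`build_context` is illustrative, and the literal `</s>` stands in for the tokenizer's EOS token):

```python
# Illustrative sketch: flatten a multi-turn history into one model context string.
def build_context(history):
    parts = []
    for prompt, completion in history:
        if completion is None:
            # the pending round stays open so the model generates the answer
            parts.append("user: {}\nsystem: ".format(prompt))
        else:
            # finished rounds end with the EOS marker, mirroring the training format
            parts.append("user: {}\nsystem: {}</s>".format(prompt, completion))
    return "".join(parts)

print(build_context([("你好", "你好,请问有什么可以帮你?"), ("写一个冒泡排序", None)]))
```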
## Performance Comparison
### Training Framework

assets/multiturn_chat.jpeg (new binary file, 810 KiB)

chat_server.py (new file, 115 lines)

@ -0,0 +1,115 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-07 23:03:31
FilePath: /Open-Llama/chat_server.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
import gradio as gr
import sentencepiece as spm
from dataset.tokenizer import Tokenizer
from transformers import LlamaForCausalLM, LlamaConfig
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
raw_model = LlamaForCausalLM(
LlamaConfig(
vocab_size=tokenizer.vocab_size,
initializer_range=0.01,
pad_token_id=tokenizer.pad_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
)
)
ckpt = torch.load(
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
map_location="cpu",
)
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()
print("ready")
def parse_codeblock(text):
lines = text.split("\n")
for i, line in enumerate(lines):
if "```" in line:
if line != "```":
lines[i] = f'<pre><code class="{lines[i][3:]}">'
else:
lines[i] = "</code></pre>"
else:
if i > 0:
lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
return "".join(lines)
with gr.Blocks() as demo:
gr.Markdown(
"""
# [Open-Llama](https://github.com/Bayes-Song/Open-Llama)
完全使用Open-Llama项目从0开始训练的Instruct-GPT模型当长时间无响应如20s以上可刷新重试
Instruct-GPT model is trained from scratch using the Open-Llama project without relying on any other pre-trained models. If there is no response for a long time (such as more than 20 seconds), please refresh and try again.
"""
)
chatbot = gr.Chatbot()
msg = gr.Textbox()
clear = gr.Button("Clear")
def user(user_message, history):
print(user_message)
return "", history + [[user_message, None]]
def bot(history):
context = []
round = 0
for prompt, completion in history:
round += 1
if completion is None:
inputs = "user:{}\nsystem:".format(prompt)
inputs = tokenizer(
inputs, return_tensors=True, add_special_tokens=False
)
context.append(inputs["input_ids"])
else:
inputs = "user:{}\nsystem:{}".format(prompt, completion)
inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
context.append(inputs["input_ids"])
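# Concatenate all rounds into one token sequence and keep only the most recent 1024 tokens as model context.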
context = torch.cat(context, dim=-1)
context = context[:, -1024:]
inputs_len = context.shape[1]
context = context.cuda()
pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
pred = pred[:, inputs_len:]
pred = tokenizer.decode(pred.cpu())[0]
print(pred)
bot_message = parse_codeblock(pred)
history[-1][1] = bot_message
return history
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
bot, chatbot, chatbot
)
clear.click(lambda: None, None, chatbot, queue=False)
gr.Markdown(
"""
当前体验服务生成的所有内容都是由人工智能模型生成我们对其生成内容的准确性完整性和功能性不做任何保证并且其生成的内容不代表我们的态度或观点
联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我.
Contact information: sl12160010@gmail.com. Any opinions or suggestions regarding the project are welcome to be addressed to me through this email.
"""
)
demo.launch(share=True)

configs/instruction_tuning_config.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:39:40
LastEditTime: 2023-04-06 03:37:23
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:
@ -10,7 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
num_training_steps = 37500
num_training_steps = 40000
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
@ -22,4 +22,4 @@ log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
ckpt_path = "data/saved_ckpt/40000.pt"
ckpt_path = "data/saved_ckpt/83200.pt"

data/download_instruct.sh (new file, 15 lines)

@ -0,0 +1,15 @@
#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-04-05 23:18:10
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-04-05 23:34:30
# @FilePath: /Open-Llama/data/download_instruct.sh
# @Description:
#
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
mkdir data/instruction_data
curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part1.json' -o data/sg_90k_part1.json
curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part2.json' -o data/sg_90k_part2.json
python3 data/preprocess_instruction.py

data/preprocess_instruction.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:52:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 20:52:12
LastEditTime: 2023-04-05 23:51:16
FilePath: /Open-Llama/data/preprocess_instruction.py
Description:
@ -12,8 +12,11 @@ import json
import zstandard as zstd
from datasets import load_dataset
root_dir = "data"
dataset = load_dataset("yizhongw/self_instruct")
write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@ -27,9 +30,14 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"yizhongw/self_instruct preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
dataset = load_dataset("BelleGroup/train_0.5M_CN")
write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@ -43,9 +51,14 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"BelleGroup/train_0.5M_CN preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
dataset = load_dataset("BelleGroup/generated_train_1M_CN")
write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
dataset = load_dataset("BelleGroup/train_1M_CN")
write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@ -59,3 +72,96 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"BelleGroup/train_1M_CN preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
dataset = load_dataset("BelleGroup/school_math_0.25M")
write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
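# Dump the dataset as zstd-compressed jsonl shards, rotating to a new shard file every 1024 lines.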
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"BelleGroup/school_math_0.25M preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
dataset = load_dataset("BelleGroup/multiturn_chat_0.8M")
write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"BelleGroup/multiturn_chat_0.8M preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
dataset = load_dataset("Graverman/Instruct-to-Code")
write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
for line in dataset["train"]:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"Graverman/Instruct-to-Code preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)
write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
with open("data/sg_90k_part1.json", "r") as fp:
data1 = json.load(fp)
with open("data/sg_90k_part2.json", "r") as fp:
data2 = json.load(fp)
data = data1 + data2
for line in data:
line = json.dumps(line)
if total_num % 1024 == 0 and total_num > 0:
file_num += 1
wfp.close()
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
wfp.write(line.encode("utf-8"))
wfp.write(b"\n")
total_num += 1
wfp.close()
print(
"RyokoAI/ShareGPT52K preprocess done. Total line: {}, Total file: {}".format(
total_num, file_num
)
)

dataset/collate_fn.py (new file, 69 lines)

@ -0,0 +1,69 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:11:03
FilePath: /Open-Llama/dataset/collate_fn.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
"""
Organize data into tensors by padding based on the preset maximum length.
"""
pad_id = tokenizer.pad_id
def collate_fn(batch):
if padding == "longest":
max_length = max([len(i) for i in batch])
elif padding == "max_length":
max_length = segment_max_length
else:
raise Exception("Invalid argumet for padding: {}".format(padding))
input_ids = []
for i in batch:
input_len = len(i)
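# right-pad each sequence with pad_id up to the chosen max_length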
input_ids.append(i + [pad_id] * (max_length - input_len))
inputs = {
"input_ids": torch.tensor(input_ids, dtype=torch.int64),
}
return inputs
return collate_fn
if __name__ == "__main__":
import sentencepiece as spm
from torch.utils.data import DataLoader
from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"wudao": preprocess_wudao_gen(tokenizer),
"pile": preprocess_the_pile_gen(tokenizer),
}
data_set = DataIter(paths, transform_dict=transform_dict)
train_loader = DataLoader(
data_set,
batch_size=8,
num_workers=4,
collate_fn=collate_fn_gen(tokenizer),
drop_last=True,
)
for batch in train_loader:
for k, v in batch.items():
print(k, v.shape)
break

dataset/data_iter.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:03:32
LastEditTime: 2023-04-06 03:37:55
FilePath: /Open-Llama/dataset/data_iter.py
Description:
@ -11,67 +11,84 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
import json
from glob import glob
import zstandard as zstd
from torch.utils.data import IterableDataset
def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
class DataIter(IterableDataset):
"""
Currently, the allowed storage formats are jsonl and jsonl.zst.
Currently, the only supported storage format is jsonl.zst.
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
Currently, only a single worker is supported.
"""
past = None
for i, path in paths:
dataset_name = path.split("-")[-2]
if num_processes > 1 and i % num_processes != process_index:
continue
if past != dataset_name:
print("Loading data from {}".format(path))
past = path
if path.endswith("jsonl.zst"):
def __init__(
self,
paths_with_index,
transform_dict=None,
max_length=None,
concat_docs=False,
process_index=0,
num_processes=1,
):
super().__init__()
self.paths_with_index = paths_with_index
self.max_length = max_length
self.transform_dict = transform_dict
self.concat_docs = concat_docs
self.process_index = process_index
self.num_processes = num_processes
if self.concat_docs:
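# buffer token ids across documents so fixed-length sequences of max_length tokens can be cut from the concatenation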
self.cache = []
def __iter__(self):
past = None
for i, path in self.paths_with_index:
# part-dataset_name-01.jsonl.zst
dataset_name = path.split("-")[-2]
# shard across multiple devices (processes)
if self.num_processes > 1 and i % self.num_processes != self.process_index:
continue
# Log the file name when encountering a new file.
if past != dataset_name:
print("Loading data from {}".format(path))
past = path
# Currently, the only supported storage format is jsonl.zst.
assert path.endswith("jsonl.zst")
with zstd.open(path, "r", encoding="utf-8") as fp:
for line in fp:
# Once the cache holds at least max_length tokens, emit one max_length-token sequence.
if self.concat_docs and len(self.cache) >= self.max_length:
seq = self.cache[: self.max_length]
self.cache = self.cache[self.max_length :]
yield seq
if isinstance(line, bytes):
line = line.decode("utf-8")
line = json.loads(line)
line["dataset"] = dataset_name
if transform_dict:
line = transform_dict[dataset_name](line)
if isinstance(line, str):
# Transformation, including sampling, tokenization, etc.
if self.transform_dict:
line = self.transform_dict[dataset_name](line)
# skip bad doc
if line is None:
continue
elif isinstance(line, str):
yield line
elif isinstance(line, list):
for i in line:
yield i
# must be a list of lists
elif isinstance(line, list) and isinstance(line[0], list):
for seq in line:
if self.concat_docs:
# concat seq from multiple docs
self.cache += seq
else:
yield seq
else:
raise Exception(
"Unsupported type in Transformation: {}".format(
transform_dict[dataset_name]
self.transform_dict[dataset_name]
)
)
else:
yield line
elif path.endswith("jsonl"):
with open(path, "r") as fp:
for line in fp:
if isinstance(line, bytes):
line = line.decode("utf-8")
line = json.loads(line)
line["dataset"] = dataset_name
if transform_dict:
line = transform_dict[dataset_name](line)
if isinstance(line, str):
yield line
elif isinstance(line, list):
for i in line:
yield i
else:
raise Exception(
"Unsupported type in Transformation: {}".format(
transform_dict[dataset_name]
)
)
else:
yield line
else:
raise Exception("File format of {} is not supported yet.".format(path))
def create_shard_kwargs(patterns, repeat=1):
@ -90,7 +107,9 @@ if __name__ == "__main__":
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
data_iter = create_data_iter(paths, transform_dict=transform_dict)
data_iter = DataIter(
paths, transform_dict=transform_dict, max_length=16, concat_docs=True
)
for i, data in enumerate(data_iter):
print(i, data)
if i == 20:

dataset/data_loader.py (deleted)

@ -1,192 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch
def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
"""
Organize data into tensors by padding based on the preset maximum length.
"""
pad_id = tokenizer.pad_id
def pretrain_collate_fn(batch):
if padding == "longest":
max_length = max([len(i) for i in batch])
elif padding == "max_length":
max_length = segment_max_length
else:
raise Exception("Invalid argumet for padding: {}".format(padding))
input_ids = []
for i in batch:
input_len = len(i)
input_ids.append(i + [pad_id] * (max_length - input_len))
inputs = {
"input_ids": torch.tensor(input_ids, dtype=torch.int64),
}
return inputs
return pretrain_collate_fn
class BySequenceLengthDataset(torch.utils.data.IterableDataset):
"""
experimental
"""
def __init__(
self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
):
super().__init__()
self.generator = generator
self.batch_size = batch_size
self.bucket_size = bucket_size
self.bucket_num = math.ceil(max_length / bucket_size)
self.buckets = [[] for _ in range(self.bucket_num)]
self.bucket_idx = None
self.accelerator = accelerator
if self.accelerator is not None:
self.buckets_ele_num = torch.tensor(
[0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
)
self.buckets_indexes = torch.arange(
self.bucket_num, device=accelerator.device
)
self.finished = False
self.has_no_same_bucket = False
self.rest = None
def __iter__(self):
if self.batch_size <= 1:
return self.generator
def bucket_iter():
while True:
if self.bucket_idx is not None:
sample = self.buckets[self.bucket_idx].pop()
if len(self.buckets[self.bucket_idx]) == 0:
self.bucket_idx = None
yield sample
try:
sample = next(self.generator)
except StopIteration:
break
sample_len = len(sample) - 1
bucket_idx = sample_len // self.bucket_size
if len(self.buckets[bucket_idx]) == self.batch_size - 1:
self.bucket_idx = bucket_idx
yield sample
else:
self.buckets[bucket_idx].append(sample)
def parallel_bucket_iter():
while True:
if self.bucket_idx is not None:
sample = self.buckets[self.bucket_idx].pop()
self.buckets_ele_num[self.bucket_idx] -= 1
buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
buckets_ele_num = buckets_ele_num.reshape(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
if min_buckets_ele_num[self.bucket_idx] <= 0:
self.bucket_idx = None
yield sample
else:
if self.finished:
if self.has_no_same_bucket:
if self.rest is None:
self.rest = []
for bucket in self.buckets:
for i in bucket:
self.rest.append(i)
elif len(self.rest) > 0:
yield self.rest.pop()
else:
raise StopIteration()
else:
buckets_ele_num = self.accelerator.gather(
self.buckets_ele_num
)
buckets_ele_num = buckets_ele_num.view(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
valid_bucket_idx = self.buckets_indexes[
min_buckets_ele_num >= self.batch_size
]
if len(valid_bucket_idx) > 0:
self.bucket_idx = valid_bucket_idx[0].cpu().item()
else:
self.has_no_same_bucket = True
else:
try:
sample = next(self.generator)
except StopIteration:
self.finished = True
continue
sample_len = len(sample) - 1
bucket_idx = sample_len // self.bucket_size
self.buckets[bucket_idx].append(sample)
self.buckets_ele_num[bucket_idx] += 1
buckets_ele_num = self.accelerator.gather(
self.buckets_ele_num
).cpu()
buckets_ele_num = buckets_ele_num.view(
self.accelerator.num_processes, self.bucket_num
)
min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
valid_bucket_idx = self.buckets_indexes[
min_buckets_ele_num >= self.batch_size
]
if len(valid_bucket_idx) > 0:
self.bucket_idx = valid_bucket_idx[0].cpu().item()
if self.accelerator:
return parallel_bucket_iter()
else:
return bucket_iter()
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from torch.utils.data import DataLoader
from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/pretrain_data/part-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"wudao": preprocess_wudao_gen(tokenizer),
"pile": preprocess_the_pile_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
)
train_loader = DataLoader(
data_set,
batch_size=8,
num_workers=4,
collate_fn=pretrain_collate_fn_gen(tokenizer),
drop_last=True,
)
for batch in train_loader:
for k, v in batch.items():
print(k, v.shape)
break

dataset/instruction_dataset.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:02:06
LastEditTime: 2023-04-06 03:33:27
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:
@ -21,7 +21,7 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
prompt = line["prompt"]
if prompt.endswith("Output:"):
prompt = prompt[:-7]
total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total)
input_ids = out["input_ids"]
return [
@ -39,12 +39,12 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["input"].replace("\\n", "")
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["target"].replace("\\n", "")
completion = line["output"].replace("\\n", "")
completion = completion.strip("")
total = "user:{}<s>system:{}".format(prompt, completion)
total = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [
@ -55,28 +55,124 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
return preprocess_belle
def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
def preprocess_belle_multiturn_chat(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["output"].replace("\\n", "")
completion = completion.strip("")
chats = prompt + completion
chats = chats.split("Human:")
input_ids = []
for chat in chats:
if chat.strip() == "":
continue
res = chat.split("Assistant:")
if len(res) != 2:
continue
prompt, completion = res
prompt = prompt.strip()
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(chat)
input_ids.extend(out["input_ids"])
if len(input_ids) == 0:
return None
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_belle_multiturn_chat
def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
def preprocess_sharegpt(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
chats = line["conversations"]
if chats[0]["from"] != "human":
chats = chats[1:]
input_ids = []
for i in range(len(chats) // 2):
prompt = chats[2 * i]
completion = chats[2 * i + 1]
if not (prompt["from"] == "human" and completion["from"] == "gpt"):
continue
prompt = prompt["value"]
prompt = prompt.strip()
completion = completion["value"]
completion = completion.strip()
chat = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(chat)
input_ids.extend(out["input_ids"])
if input_ids == []:
return None
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_sharegpt
def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
def preprocess_instruct_code(line):
"""
The format of the data is roughly as follows.
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
prompt = line["instruction"].replace("\\n", "")
prompt = prompt.strip("")
completion = line["answer"].replace("\\n", "")
completion = completion.strip("")
total = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [
input_ids[i * segment_max_length : (i + 1) * segment_max_length]
for i in range(math.ceil(len(input_ids) / segment_max_length))
]
return preprocess_instruct_code
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
"self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
"self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
"belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
"instruct_to_code": preprocess_instruct_code_gen(tokenizer),
"sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
data_set = DataIter(
paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
)
for i, sample in enumerate(data_set):
print(sample, sp_model.Decode(sample))
if i == 20:
print(sp_model.decode(sample))
if i == 1:
break

dataset/pretrain_dataset.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 20:41:25
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:07:56
LastEditTime: 2023-04-05 22:32:39
FilePath: /Open-Llama/dataset/pretrain_dataset.py
Description:
@ -49,10 +49,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
if __name__ == "__main__":
import sentencepiece as spm
from datasets import IterableDataset
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
@ -64,8 +63,8 @@ if __name__ == "__main__":
"wudao": preprocess_wudao_gen(tokenizer),
"pile": preprocess_the_pile_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
data_set = DataIter(
paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
)
for sample in data_set:
print(sample)

dataset/tokenizer.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:09:39
LastEditTime: 2023-04-06 23:01:50
FilePath: /Open-Llama/dataset/tokenizer.py
Description:
@ -145,14 +145,34 @@ class Tokenizer:
out["attention_mask"] = attention_mask
return out
def decode(self, inputs):
def decode(self, inputs, max_rounds=None):
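# max_rounds controls where decoding truncates: None stops at the first EOS,
# an int N keeps N EOS-delimited rounds (truncating at the N+1-th EOS),
# and a list of ints gives a per-sequence round budget.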
inputs = inputs.tolist()
out = []
for i in inputs:
if self.eos_id in i:
eos_idx = i.index(self.eos_id)
i = i[:eos_idx]
out.append(i)
for i, ids in enumerate(inputs):
count = 0
flag = False
for j, token in enumerate(ids):
if token == self.eos_id:
if max_rounds is None:
flag = True
break
elif isinstance(max_rounds, int):
if count < max_rounds:
count += 1
else:
flag = True
break
elif isinstance(max_rounds, list):
if count < max_rounds[i]:
count += 1
else:
flag = True
break
if flag:
ids = ids[:j]
else:
ids = ids
out.append(ids)
out = self.sp_model.Decode(out)
return out
@ -183,11 +203,11 @@ if __name__ == "__main__":
for i, j in zip(tmp, out):
assert normalize("NFKC", i) == j
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_iter import create_shard_kwargs, DataIter
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
data_iter = create_data_iter(paths)
data_iter = DataIter(paths)
for i, data in enumerate(data_iter):
assert (
normalize("NFKC", data["content"])

dataset/train_tokenizer.py

@ -2,14 +2,14 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-24 20:49:03
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:43:59
LastEditTime: 2023-04-05 22:40:29
FilePath: /Open-Llama/dataset/train_tokenizer.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import random
from dataset.data_iter import create_data_iter, create_shard_kwargs
from dataset.data_iter import DataIter, create_shard_kwargs
wudao_patterns = [
"data/pretrain_data/part-wudao-*.jsonl.zst",
@ -24,10 +24,10 @@ pile_paths = create_shard_kwargs(pile_patterns)
random.shuffle(pile_paths)
paths = wudao_paths[:5] + pile_paths[:10]
transform_dict = {
"wudao": lambda line: [(line["title"] + "\n" + line["content"])],
"pile": lambda line: [line["text"]],
"wudao": lambda line: line["title"] + "\n" + line["content"],
"pile": lambda line: line["text"],
}
data_iter = create_data_iter(paths, transform_dict)
data_iter = iter(DataIter(paths, transform_dict))
import io
import sentencepiece as spm

inctruction_tuning.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:40:03
LastEditTime: 2023-04-06 03:35:31
FilePath: /Open-Llama/inctruction_tuning.py
Description:
@ -16,18 +16,20 @@ import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.instruction_dataset import (
preprocess_belle_gen,
preprocess_self_instruction_gen,
preprocess_belle_multiturn_chat_gen,
preprocess_instruct_code_gen,
preprocess_sharegpt_gen,
)
from configs.instruction_tuning_config import *
@ -46,25 +48,28 @@ tokenizer = Tokenizer(sp_model)
paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
"belle_1M": preprocess_belle_gen(tokenizer, max_length),
"belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
"self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
"self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
"belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
"belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
"instruct_to_code": preprocess_instruct_code_gen(tokenizer),
"sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
data_set = IterableDataset.from_generator(
create_data_iter,
gen_kwargs={
"paths": paths,
"transform_dict": transform_dict,
"process_index": accelerator.process_index,
"num_processes": accelerator.num_processes,
},
data_set = DataIter(
paths,
transform_dict=transform_dict,
concat_docs=True,
max_length=max_length,
process_index=accelerator.process_index,
num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
collate_fn=collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# smaller initializer_range make training more stable

pretrain_llama.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-27 01:07:25
LastEditTime: 2023-04-05 22:46:31
FilePath: /Open-Llama/pretrain_llama.py
Description:
pretrain GPT
@ -16,15 +16,14 @@ import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
from dataset.data_iter import create_shard_kwargs, create_data_iter
from dataset.data_loader import pretrain_collate_fn_gen
from dataset.data_iter import create_shard_kwargs, DataIter
from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
preprocess_the_pile_gen,
preprocess_wudao_gen,
@ -49,21 +48,20 @@ transform_dict = {
"wudao": preprocess_wudao_gen(tokenizer, max_length),
"pile": preprocess_the_pile_gen(tokenizer, max_length),
}
data_set = IterableDataset.from_generator(
create_data_iter,
gen_kwargs={
"paths": paths,
"transform_dict": transform_dict,
"process_index": accelerator.process_index,
"num_processes": accelerator.num_processes,
},
data_set = DataIter(
paths,
transform_dict=transform_dict,
concat_docs=True,
max_length=max_length,
process_index=accelerator.process_index,
num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
collate_fn=collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# smaller initializer_range make training more stable

server.py

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-31 13:26:15
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-31 14:05:35
LastEditTime: 2023-04-06 03:45:44
FilePath: /Open-Llama/server.py
Description:
@ -32,7 +32,9 @@ raw_model = LlamaForCausalLM(
shared_input_output_embedding=True,
)
)
ckpt = torch.load("data/saved_ckpt/instruction_tuning_3_epochs/23001.pt", map_location="cpu")
ckpt = torch.load(
"data/saved_ckpt/instruction_tuning_3_epochs/37001.pt", map_location="cpu"
)
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()
@ -41,7 +43,7 @@ print("ready")
def question_answer(prompt):
print(prompt)
raw_inputs = "user:{}<s>system:".format(prompt)
raw_inputs = "user:{}\nsystem:".format(prompt)
inputs_len = len(raw_inputs)
inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
for k, v in inputs.items():