commit 56f71e24df

README.md (24 lines changed)

@@ -2,7 +2,7 @@
  * @Author: LiangSong(sl12160010@gmail.com)
  * @Date: 2023-03-10 21:18:35
  * @LastEditors: LiangSong(sl12160010@gmail.com)
- * @LastEditTime: 2023-04-02 21:32:26
+ * @LastEditTime: 2023-04-07 23:19:21
  * @FilePath: /Open-Llama/README.md
  * @Description:
  *
@@ -16,7 +16,8 @@ Open-Llama is an open-source project that provides a complete set of tools for building large language models
 
 ## Progress
 
-Although pretraining has not fully finished yet, we first used the 40K-step pretrained model for instruction-tuning; the model can follow simple commands but has no multi-turn dialogue ability so far.
+We have completed pretraining on 300B tokens, 80K steps in total, with a global batch size of 4M, the same as in Llama.
+The instruction-tuning data is built from 7 sources in total; the model shows some coding ability, math ability, and multi-turn dialogue ability. See the Instruction-Tuning section for the specific data.
 
 [Demo](http://home.ustc.edu.cn/~sl9292/)
 
@@ -25,6 +26,9 @@ Open-Llama is an open-source project that provides a complete set of tools for building large language models
 The model's results are shown below; more results await further testing. Due to network issues in mainland China, requests to the Demo above may be dropped; if there is no response for a long time, refresh and retry
 ![image1](assets/image1.png)![image2](assets/image2.png)![image3](assets/image3.png)
 
+Below is a demonstration of its multi-turn dialogue ability around code
+
+![multiturn_chat](assets/multiturn_chat.jpeg)
 A rough cost estimate for reaching the results above: pretraining for 40K steps used 150M pretraining samples, about 110B tokens, and took 76h in total, roughly $19,152 at Google Cloud A100 pricing. The subsequent instruction-tuning ran 12K steps on 1.6M samples, took 3.4h, and cost about $342. Training such a model from scratch therefore costs under $20,000 in total.
 
 The model is still noticeably weak at math and code. This is partly a matter of the training data and, in my view, partly of model size; since such reasoning ability is essential for a usable model, later updates will focus on improving it.
@@ -166,12 +170,17 @@ Total mult-adds (G): 6.89
 
 We use currently open-source datasets for instruction-tuning; more tasks as well as datasets of our own construction will be added later.
 - [yizhongw/self_instruct](https://huggingface.co/datasets/yizhongw/self_instruct)
-- [BelleGroup/generated_train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_0.5M_CN)
+- [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
-- [BelleGroup/generated_train_1M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_1M_CN)
+- [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [BelleGroup/school_math_0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
+- [Graverman/Instruct-to-Code](https://huggingface.co/datasets/Graverman/Instruct-to-Code)
 
+The ShareGPT52K data has some issues when processed via datasets, so we downloaded the raw data and reprocessed it ourselves.
 We apply some preprocessing to the raw data; the format is as follows
 ```
-user: {prompt}<s>system: {completion}</s>
+user: {prompt}\nsystem: {completion}</s>
 ```
 
 The training code is essentially the same as for pretraining; see
@@ -195,7 +204,12 @@ accelerate launch --config_file configs/default_config.yaml instruction_tuning.p
 The loss over this process is shown below; it mostly fluctuates without decreasing much
 ![loss](assets/instruct_loss.png)
 ### RLHF
+None yet.
+### Server
+
+Use server.py for single-turn dialogue and chat_server.py for multi-turn dialogue.
+
+Built on Gradio.
 ## Performance Comparison
 
 ### Training Framework
 
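For concreteness, a minimal sketch of how one record is rendered into the instruction-tuning format above. The render helper and the sample record are illustrative only; the real pipeline appends EOS via the tokenizer rather than as a literal string:

```
def render(prompt, completion):
    # "\n" separates the user turn from the system turn; "</s>" marks EOS.
    return "user:{}\nsystem:{}</s>".format(prompt.strip(), completion.strip())

record = {"prompt": "Write a haiku about spring.", "completion": "Petals drift; ponds wake."}
print(render(record["prompt"], record["completion"]))
# user:Write a haiku about spring.
# system:Petals drift; ponds wake.</s>
```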
assets/multiturn_chat.jpeg (new binary file, 810 KiB; binary content not shown)
chat_server.py (new file, 115 lines)

"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-07 23:03:31
FilePath: /Open-Llama/chat_server.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
import gradio as gr
import sentencepiece as spm
from dataset.tokenizer import Tokenizer
from transformers import LlamaForCausalLM, LlamaConfig


sp_model = spm.SentencePieceProcessor(
    model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
raw_model = LlamaForCausalLM(
    LlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=0.01,
        pad_token_id=tokenizer.pad_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
ckpt = torch.load(
    "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
    map_location="cpu",
)
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()
print("ready")


def parse_codeblock(text):
    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                lines[i] = f'<pre><code class="{lines[i][3:]}">'
            else:
                lines[i] = "</code></pre>"
        else:
            if i > 0:
                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
    return "".join(lines)


with gr.Blocks() as demo:
    gr.Markdown(
        """
# [Open-Llama](https://github.com/Bayes-Song/Open-Llama)
完全使用Open-Llama项目从0开始训练的Instruct-GPT模型,当长时间无响应(如20s以上)可刷新重试。

Instruct-GPT model is trained from scratch using the Open-Llama project without relying on any other pre-trained models. If there is no response for a long time (such as more than 20 seconds), please refresh and try again.
"""
    )
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        print(user_message)
        return "", history + [[user_message, None]]

    def bot(history):
        context = []
        round = 0
        for prompt, completion in history:
            round += 1
            if completion is None:
                inputs = "user:{}\nsystem:".format(prompt)
                inputs = tokenizer(
                    inputs, return_tensors=True, add_special_tokens=False
                )
                context.append(inputs["input_ids"])
            else:
                inputs = "user:{}\nsystem:{}".format(prompt, completion)
                inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
                context.append(inputs["input_ids"])
        context = torch.cat(context, dim=-1)
        context = context[:, -1024:]
        inputs_len = context.shape[1]
        context = context.cuda()
        pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
        pred = pred[:, inputs_len:]
        pred = tokenizer.decode(pred.cpu())[0]
        print(pred)
        bot_message = parse_codeblock(pred)
        history[-1][1] = bot_message
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    gr.Markdown(
        """
当前体验服务生成的所有内容都是由人工智能模型生成,我们对其生成内容的准确性、完整性和功能性不做任何保证,并且其生成的内容不代表我们的态度或观点。

联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我.
Contact information: sl12160010@gmail.com. Any opinions or suggestions regarding the project are welcome to be addressed to me through this email.
"""
    )

demo.launch(share=True)
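One design point in bot() above: the whole dialogue history is re-tokenized on every turn, concatenated along the sequence dimension, and clipped to the most recent 1024 tokens before generation. A standalone sketch of that clipping, with made-up tensor sizes:

```
import torch

# Three fake tokenized turns, shape (1, n) each, like the entries appended to `context`.
turns = [
    torch.randint(0, 100, (1, 500)),
    torch.randint(0, 100, (1, 400)),
    torch.randint(0, 100, (1, 300)),
]
context = torch.cat(turns, dim=-1)  # (1, 1200): the full history
context = context[:, -1024:]        # keep only the most recent 1024 tokens
print(context.shape)                # torch.Size([1, 1024])
```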
configs/instruction_tuning_config.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:38:07
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:39:40
+LastEditTime: 2023-04-06 03:37:23
 FilePath: /Open-Llama/configs/instruction_tuning_config.py
 Description:
 
@@ -10,7 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 max_length = 1024
 train_batch_size = 2
-num_training_steps = 37500
+num_training_steps = 40000
 num_warmup_steps = 100
 initializer_range = 1e-2
 lr = 2e-4
@@ -22,4 +22,4 @@ log_interval = 50
 eval_interval = 500
 save_interval = 1000
 work_dir = "data/saved_ckpt/"
-ckpt_path = "data/saved_ckpt/40000.pt"
+ckpt_path = "data/saved_ckpt/83200.pt"
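For orientation, a minimal sketch of how these values typically drive the learning-rate schedule in the training script. Plain AdamW stands in for the repo's FusedAdam, and the tiny linear model is a placeholder; both are assumptions for a self-contained example:

```
import torch
from transformers import get_cosine_schedule_with_warmup

model = torch.nn.Linear(8, 8)  # placeholder model
optim = torch.optim.AdamW(model.parameters(), lr=2e-4)
# LR warms up linearly for num_warmup_steps, then follows a cosine decay
# until num_training_steps.
scheduler = get_cosine_schedule_with_warmup(
    optim, num_warmup_steps=100, num_training_steps=40000
)
for step in range(3):
    optim.step()       # one optimizer step per batch
    scheduler.step()   # advance the schedule once per step
```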
data/download_instruct.sh (new file, 15 lines)

#!/bin/bash
###
# @Author: LiangSong(sl12160010@gmail.com)
# @Date: 2023-04-05 23:18:10
# @LastEditors: LiangSong(sl12160010@gmail.com)
# @LastEditTime: 2023-04-05 23:34:30
# @FilePath: /Open-Llama/data/download_instruct.sh
# @Description:
#
# Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
###
mkdir data/instruction_data
curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part1.json' -o data/sg_90k_part1.json
curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part2.json' -o data/sg_90k_part2.json
python3 data/preprocess_instruction.py
data/preprocess_instruction.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 20:52:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 20:52:12
+LastEditTime: 2023-04-05 23:51:16
 FilePath: /Open-Llama/data/preprocess_instruction.py
 Description:
 
@@ -12,8 +12,11 @@ import json
 import zstandard as zstd
 from datasets import load_dataset
 
+
+root_dir = "data"
+
 dataset = load_dataset("yizhongw/self_instruct")
-write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
+write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
 total_num = 0
 file_num = 0
 wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -27,9 +30,14 @@ for line in dataset["train"]:
     wfp.write(b"\n")
     total_num += 1
 wfp.close()
+print(
+    "yizhongw/self_instruct preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
 
-dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
+dataset = load_dataset("BelleGroup/train_0.5M_CN")
-write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
+write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
 total_num = 0
 file_num = 0
 wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -43,9 +51,14 @@ for line in dataset["train"]:
     wfp.write(b"\n")
     total_num += 1
 wfp.close()
+print(
+    "BelleGroup/train_0.5M_CN preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
 
-dataset = load_dataset("BelleGroup/generated_train_1M_CN")
+dataset = load_dataset("BelleGroup/train_1M_CN")
-write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
+write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
 total_num = 0
 file_num = 0
 wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -59,3 +72,96 @@ for line in dataset["train"]:
     wfp.write(b"\n")
     total_num += 1
 wfp.close()
+print(
+    "BelleGroup/train_1M_CN preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
+
+dataset = load_dataset("BelleGroup/school_math_0.25M")
+write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+for line in dataset["train"]:
+    line = json.dumps(line)
+    if total_num % 1024 == 0 and total_num > 0:
+        file_num += 1
+        wfp.close()
+        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+    wfp.write(line.encode("utf-8"))
+    wfp.write(b"\n")
+    total_num += 1
+wfp.close()
+print(
+    "BelleGroup/school_math_0.25M preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
+
+dataset = load_dataset("BelleGroup/multiturn_chat_0.8M")
+write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+for line in dataset["train"]:
+    line = json.dumps(line)
+    if total_num % 1024 == 0 and total_num > 0:
+        file_num += 1
+        wfp.close()
+        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+    wfp.write(line.encode("utf-8"))
+    wfp.write(b"\n")
+    total_num += 1
+wfp.close()
+print(
+    "BelleGroup/multiturn_chat_0.8M preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
+
+dataset = load_dataset("Graverman/Instruct-to-Code")
+write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+for line in dataset["train"]:
+    line = json.dumps(line)
+    if total_num % 1024 == 0 and total_num > 0:
+        file_num += 1
+        wfp.close()
+        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+    wfp.write(line.encode("utf-8"))
+    wfp.write(b"\n")
+    total_num += 1
+wfp.close()
+print(
+    "Graverman/Instruct-to-Code preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
+
+write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+with open("data/sg_90k_part1.json", "r") as fp:
+    data1 = json.load(fp)
+with open("data/sg_90k_part2.json", "r") as fp:
+    data2 = json.load(fp)
+data = data1 + data2
+for line in data:
+    line = json.dumps(line)
+    if total_num % 1024 == 0 and total_num > 0:
+        file_num += 1
+        wfp.close()
+        wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+    wfp.write(line.encode("utf-8"))
+    wfp.write(b"\n")
+    total_num += 1
+wfp.close()
+print(
+    "RyokoAI/ShareGPT52K preprocess done. Total line: {}, Total file: {}".format(
+        total_num, file_num
+    )
+)
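The load/shard/write loop above repeats verbatim for each dataset; as a sketch, it could be factored into a helper like the hypothetical write_shards below (not part of the repository):

```
import json
import zstandard as zstd

def write_shards(records, write_path, shard_size=1024):
    # Hypothetical helper mirroring the repeated loop above: write an iterable
    # of dicts as zstd-compressed JSONL, starting a new shard file every
    # `shard_size` lines. Returns (total_lines, last_shard_index).
    total_num = 0
    file_num = 0
    wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
    for record in records:
        line = json.dumps(record)
        if total_num % shard_size == 0 and total_num > 0:
            file_num += 1
            wfp.close()
            wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
        wfp.write(line.encode("utf-8"))
        wfp.write(b"\n")
        total_num += 1
    wfp.close()
    return total_num, file_num
```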
dataset/collate_fn.py (new file, 69 lines)

"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-05 22:11:03
FilePath: /Open-Llama/dataset/collate_fn.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import torch


def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argumet for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return collate_fn


if __name__ == "__main__":
    import sentencepiece as spm
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, DataIter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = DataIter(paths, transform_dict=transform_dict)
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
dataset/data_iter.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-17 19:32:20
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:03:32
+LastEditTime: 2023-04-06 03:37:55
 FilePath: /Open-Llama/dataset/data_iter.py
 Description:
 
@@ -11,67 +11,84 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 import json
 from glob import glob
 import zstandard as zstd
+from torch.utils.data import IterableDataset
 
 
-def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
-    """
-    Currently, the allowed storage formats are jsonl and jsonl.zst.
-    Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
-    """
-    past = None
-    for i, path in paths:
-        dataset_name = path.split("-")[-2]
-        if num_processes > 1 and i % num_processes != process_index:
-            continue
-        if past != dataset_name:
-            print("Loading data from {}".format(path))
-            past = path
-        if path.endswith("jsonl.zst"):
-            with zstd.open(path, "r", encoding="utf-8") as fp:
-                for line in fp:
-                    if isinstance(line, bytes):
-                        line = line.decode("utf-8")
-                    line = json.loads(line)
-                    line["dataset"] = dataset_name
-                    if transform_dict:
-                        line = transform_dict[dataset_name](line)
-                        if isinstance(line, str):
-                            yield line
-                        elif isinstance(line, list):
-                            for i in line:
-                                yield i
-                        else:
-                            raise Exception(
-                                "Unsupported type in Transformation: {}".format(
-                                    transform_dict[dataset_name]
-                                )
-                            )
-                    else:
-                        yield line
-        elif path.endswith("jsonl"):
-            with open(path, "r") as fp:
-                for line in fp:
-                    if isinstance(line, bytes):
-                        line = line.decode("utf-8")
-                    line = json.loads(line)
-                    line["dataset"] = dataset_name
-                    if transform_dict:
-                        line = transform_dict[dataset_name](line)
-                        if isinstance(line, str):
-                            yield line
-                        elif isinstance(line, list):
-                            for i in line:
-                                yield i
-                        else:
-                            raise Exception(
-                                "Unsupported type in Transformation: {}".format(
-                                    transform_dict[dataset_name]
-                                )
-                            )
-                    else:
-                        yield line
-        else:
-            raise Exception("File format of {} is not supported yet.".format(path))
+class DataIter(IterableDataset):
+    """
+    Currently, the allowed storage formats are jsonl.zst.
+    Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
+    Currently, only single worker is supported.
+    """
+
+    def __init__(
+        self,
+        paths_with_index,
+        transform_dict=None,
+        max_length=None,
+        concat_docs=False,
+        process_index=0,
+        num_processes=1,
+    ):
+        super().__init__()
+        self.paths_with_index = paths_with_index
+        self.max_length = max_length
+        self.transform_dict = transform_dict
+        self.concat_docs = concat_docs
+        self.process_index = process_index
+        self.num_processes = num_processes
+        if self.concat_docs:
+            self.cache = []
+
+    def __iter__(self):
+        past = None
+        for i, path in self.paths_with_index:
+            # part-dataset_name-01.jsonl.zst
+            dataset_name = path.split("-")[-2]
+            # shard to multiple device
+            if self.num_processes > 1 and i % self.num_processes != self.process_index:
+                continue
+            # Log the file name when encountering a new file.
+            if past != dataset_name:
+                print("Loading data from {}".format(path))
+                past = path
+            # Currently, the allowed storage formats are jsonl.zst.
+            assert path.endswith("jsonl.zst")
+            with zstd.open(path, "r", encoding="utf-8") as fp:
+                for line in fp:
+                    # If the length of the cache is greater than max_length.
+                    if self.concat_docs and len(self.cache) >= self.max_length:
+                        seq = self.cache[: self.max_length]
+                        self.cache = self.cache[self.max_length :]
+                        yield seq
+                    if isinstance(line, bytes):
+                        line = line.decode("utf-8")
+                    line = json.loads(line)
+                    line["dataset"] = dataset_name
+                    # Transformation, including sample, tokenize, etc.
+                    if self.transform_dict:
+                        line = self.transform_dict[dataset_name](line)
+                        # skip bad doc
+                        if line is None:
+                            continue
+                        elif isinstance(line, str):
+                            yield line
+                        # must be list of list
+                        elif isinstance(line, list) and isinstance(line[0], list):
+                            for seq in line:
+                                if self.concat_docs:
+                                    # concat seq from multiple docs
+                                    self.cache += seq
+                                else:
+                                    yield seq
+                        else:
+                            raise Exception(
+                                "Unsupported type in Transformation: {}".format(
+                                    self.transform_dict[dataset_name]
+                                )
+                            )
+                    else:
+                        yield line
 
 
 def create_shard_kwargs(patterns, repeat=1):
@@ -90,7 +107,9 @@ if __name__ == "__main__":
     patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
     paths = create_shard_kwargs(patterns)
     transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
-    data_iter = create_data_iter(paths, transform_dict=transform_dict)
+    data_iter = DataIter(
+        paths, transform_dict=transform_dict, max_length=16, concat_docs=True
+    )
     for i, data in enumerate(data_iter):
         print(i, data)
         if i == 20:
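A toy illustration of the concat_docs packing semantics introduced above (slightly simplified: DataIter performs the length check as it reads each new line, but the effect is the same — short docs are packed into fixed-length sequences instead of being padded individually):

```
def pack(docs, max_length=8):
    cache = []
    for doc in docs:  # each doc is a list of token ids
        cache += doc
        while len(cache) >= max_length:
            yield cache[:max_length]
            cache = cache[max_length:]

docs = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12]]
print(list(pack(docs)))  # [[1, 2, 3, 4, 5, 6, 7, 8]] — the 4-token remainder stays cached
```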
dataset/data_loader.py (deleted file, 192 lines)

@@ -1,192 +0,0 @@
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:58:16
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-30 21:00:49
FilePath: /Open-Llama/dataset/data_loader.py
Description:

Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import math
import torch


def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
    """
    Organize data into tensors by padding based on the preset maximum length.
    """
    pad_id = tokenizer.pad_id

    def pretrain_collate_fn(batch):
        if padding == "longest":
            max_length = max([len(i) for i in batch])
        elif padding == "max_length":
            max_length = segment_max_length
        else:
            raise Exception("Invalid argumet for padding: {}".format(padding))
        input_ids = []
        for i in batch:
            input_len = len(i)
            input_ids.append(i + [pad_id] * (max_length - input_len))
        inputs = {
            "input_ids": torch.tensor(input_ids, dtype=torch.int64),
        }
        return inputs

    return pretrain_collate_fn


class BySequenceLengthDataset(torch.utils.data.IterableDataset):
    """
    experimental
    """

    def __init__(
        self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
    ):
        super().__init__()
        self.generator = generator
        self.batch_size = batch_size
        self.bucket_size = bucket_size
        self.bucket_num = math.ceil(max_length / bucket_size)
        self.buckets = [[] for _ in range(self.bucket_num)]
        self.bucket_idx = None
        self.accelerator = accelerator
        if self.accelerator is not None:
            self.buckets_ele_num = torch.tensor(
                [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
            )
            self.buckets_indexes = torch.arange(
                self.bucket_num, device=accelerator.device
            )
        self.finished = False
        self.has_no_same_bucket = False
        self.rest = None

    def __iter__(self):
        if self.batch_size <= 1:
            return self.generator

        def bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    if len(self.buckets[self.bucket_idx]) == 0:
                        self.bucket_idx = None
                    yield sample
                try:
                    sample = next(self.generator)
                except StopIteration:
                    break
                sample_len = len(sample) - 1
                bucket_idx = sample_len // self.bucket_size
                if len(self.buckets[bucket_idx]) == self.batch_size - 1:
                    self.bucket_idx = bucket_idx
                    yield sample
                else:
                    self.buckets[bucket_idx].append(sample)

        def parallel_bucket_iter():
            while True:
                if self.bucket_idx is not None:
                    sample = self.buckets[self.bucket_idx].pop()
                    self.buckets_ele_num[self.bucket_idx] -= 1
                    buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
                    buckets_ele_num = buckets_ele_num.reshape(
                        self.accelerator.num_processes, self.bucket_num
                    )
                    min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                    if min_buckets_ele_num[self.bucket_idx] <= 0:
                        self.bucket_idx = None
                    yield sample
                else:
                    if self.finished:
                        if self.has_no_same_bucket:
                            if self.rest is None:
                                self.rest = []
                                for bucket in self.buckets:
                                    for i in bucket:
                                        self.rest.append(i)
                            elif len(self.rest) > 0:
                                yield self.rest.pop()
                            else:
                                raise StopIteration()
                        else:
                            buckets_ele_num = self.accelerator.gather(
                                self.buckets_ele_num
                            )
                            buckets_ele_num = buckets_ele_num.view(
                                self.accelerator.num_processes, self.bucket_num
                            )
                            min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                            valid_bucket_idx = self.buckets_indexes[
                                min_buckets_ele_num >= self.batch_size
                            ]
                            if len(valid_bucket_idx) > 0:
                                self.bucket_idx = valid_bucket_idx[0].cpu().item()
                            else:
                                self.has_no_same_bucket = True
                    else:
                        try:
                            sample = next(self.generator)
                        except StopIteration:
                            self.finished = True
                            continue
                        sample_len = len(sample) - 1
                        bucket_idx = sample_len // self.bucket_size
                        self.buckets[bucket_idx].append(sample)
                        self.buckets_ele_num[bucket_idx] += 1
                        buckets_ele_num = self.accelerator.gather(
                            self.buckets_ele_num
                        ).cpu()
                        buckets_ele_num = buckets_ele_num.view(
                            self.accelerator.num_processes, self.bucket_num
                        )
                        min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
                        valid_bucket_idx = self.buckets_indexes[
                            min_buckets_ele_num >= self.batch_size
                        ]
                        if len(valid_bucket_idx) > 0:
                            self.bucket_idx = valid_bucket_idx[0].cpu().item()

        if self.accelerator:
            return parallel_bucket_iter()
        else:
            return bucket_iter()


if __name__ == "__main__":
    import sentencepiece as spm
    from datasets import IterableDataset
    from torch.utils.data import DataLoader

    from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen

    from dataset.tokenizer import Tokenizer
    from dataset.data_iter import create_shard_kwargs, create_data_iter

    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    patterns = ["data/pretrain_data/part-*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    transform_dict = {
        "wudao": preprocess_wudao_gen(tokenizer),
        "pile": preprocess_the_pile_gen(tokenizer),
    }
    data_set = IterableDataset.from_generator(
        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
    )
    train_loader = DataLoader(
        data_set,
        batch_size=8,
        num_workers=4,
        collate_fn=pretrain_collate_fn_gen(tokenizer),
        drop_last=True,
    )
    for batch in train_loader:
        for k, v in batch.items():
            print(k, v.shape)
        break
dataset/instruction_dataset.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:02:00
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:02:06
+LastEditTime: 2023-04-06 03:33:27
 FilePath: /Open-Llama/dataset/instruction_dataset.py
 Description:
 
@@ -21,7 +21,7 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
         prompt = line["prompt"]
         if prompt.endswith("Output:"):
             prompt = prompt[:-7]
-        total = "user:{}<s>system:{}".format(prompt.strip(), line["completion"].strip())
+        total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
         out = tokenizer(total)
         input_ids = out["input_ids"]
         return [
@@ -39,12 +39,12 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
         {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
         Split the data based on the tokenized length according to the maximum length.
         """
-        prompt = line["input"].replace("\\n", "")
+        prompt = line["instruction"].replace("\\n", "")
         prompt = prompt.strip("")
-        completion = line["target"].replace("\\n", "")
+        completion = line["output"].replace("\\n", "")
         completion = completion.strip("")
-        total = "user:{}<s>system:{}".format(prompt, completion)
+        total = "user:{}\nsystem:{}".format(prompt, completion)
         out = tokenizer(total)
         input_ids = out["input_ids"]
         return [
@@ -55,28 +55,124 @@
     return preprocess_belle
 
 
+def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
+    def preprocess_belle_multiturn_chat(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip("")
+        completion = line["output"].replace("\\n", "")
+        completion = completion.strip("")
+        chats = prompt + completion
+        chats = chats.split("Human:")
+        input_ids = []
+        for chat in chats:
+            if chat.strip() == "":
+                continue
+            res = chat.split("Assistant:")
+            if len(res) != 2:
+                continue
+            prompt, completion = res
+            prompt = prompt.strip()
+            completion = completion.strip()
+            chat = "user:{}\nsystem:{}".format(prompt, completion)
+            out = tokenizer(chat)
+            input_ids.extend(out["input_ids"])
+        if len(input_ids) == 0:
+            return None
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_belle_multiturn_chat
+
+
+def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
+    def preprocess_sharegpt(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        chats = line["conversations"]
+        if chats[0]["from"] != "human":
+            chats = chats[1:]
+        input_ids = []
+        for i in range(len(chats) // 2):
+            prompt = chats[2 * i]
+            completion = chats[2 * i + 1]
+            if not (prompt["from"] == "human" and completion["from"] == "gpt"):
+                continue
+            prompt = prompt["value"]
+            prompt = prompt.strip()
+            completion = completion["value"]
+            completion = completion.strip()
+            chat = "user:{}\nsystem:{}".format(prompt, completion)
+            out = tokenizer(chat)
+            input_ids.extend(out["input_ids"])
+        if input_ids == []:
+            return None
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_sharegpt
+
+
+def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
+    def preprocess_instruct_code(line):
+        """
+        The format of the data is roughly as follows.
+        {'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
+        Split the data based on the tokenized length according to the maximum length.
+        """
+        prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip("")
+        completion = line["answer"].replace("\\n", "")
+        completion = completion.strip("")
+        total = "user:{}\nsystem:{}".format(prompt, completion)
+        out = tokenizer(total)
+        input_ids = out["input_ids"]
+        return [
+            input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+            for i in range(math.ceil(len(input_ids) / segment_max_length))
+        ]
+
+    return preprocess_instruct_code
+
+
 if __name__ == "__main__":
     import sentencepiece as spm
-    from datasets import IterableDataset
 
     from dataset.tokenizer import Tokenizer
-    from dataset.data_iter import create_shard_kwargs, create_data_iter
+    from dataset.data_iter import create_shard_kwargs, DataIter
 
     sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
     )
     tokenizer = Tokenizer(sp_model)
-    patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
+    patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
     paths = create_shard_kwargs(patterns)
     transform_dict = {
+        "self_instruct": preprocess_self_instruction_gen(tokenizer),
         "belle_1M": preprocess_belle_gen(tokenizer),
         "belle_0.5M": preprocess_belle_gen(tokenizer),
-        "self_instruct": preprocess_self_instruction_gen(tokenizer),
+        "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+        "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+        "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+        "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
     }
-    data_set = IterableDataset.from_generator(
-        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
+    data_set = DataIter(
+        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
     )
     for i, sample in enumerate(data_set):
-        print(sample, sp_model.Decode(sample))
+        print(sp_model.decode(sample))
-        if i == 20:
+        if i == 1:
             break
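A toy run of the Human:/Assistant: splitting done by preprocess_belle_multiturn_chat_gen. The stub tokenizer and the sample record are made up; the function only needs a callable returning a dict with "input_ids":

```
from dataset.instruction_dataset import preprocess_belle_multiturn_chat_gen

def stub_tokenizer(text):
    # Fake "tokenization": one id per whitespace-separated token.
    return {"input_ids": list(range(len(text.split())))}

preprocess = preprocess_belle_multiturn_chat_gen(stub_tokenizer, segment_max_length=16)
line = {
    "instruction": "Human: What is 2+2? Assistant: It is 4. Human: And 3+3?",
    "output": " Assistant: It is 6.",
}
# Yields the token ids of two "user:...\nsystem:..." rounds, split into
# segments of at most 16 ids each.
print(preprocess(line))
```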
dataset/pretrain_dataset.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-17 20:41:25
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:07:56
+LastEditTime: 2023-04-05 22:32:39
 FilePath: /Open-Llama/dataset/pretrain_dataset.py
 Description:
 
@@ -49,10 +49,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
 
 if __name__ == "__main__":
     import sentencepiece as spm
-    from datasets import IterableDataset
 
     from dataset.tokenizer import Tokenizer
-    from dataset.data_iter import create_shard_kwargs, create_data_iter
+    from dataset.data_iter import create_shard_kwargs, DataIter
 
     sp_model = spm.SentencePieceProcessor(
         model_file="configs/10w_vocab_wudao5_pile10.model"
@@ -64,8 +63,8 @@ if __name__ == "__main__":
         "wudao": preprocess_wudao_gen(tokenizer),
         "pile": preprocess_the_pile_gen(tokenizer),
     }
-    data_set = IterableDataset.from_generator(
-        create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
+    data_set = DataIter(
+        paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
     )
     for sample in data_set:
         print(sample)
dataset/tokenizer.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-20 21:39:47
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:09:39
+LastEditTime: 2023-04-06 23:01:50
 FilePath: /Open-Llama/dataset/tokenizer.py
 Description:
 
@@ -145,14 +145,34 @@ class Tokenizer:
         out["attention_mask"] = attention_mask
         return out
 
-    def decode(self, inputs):
+    def decode(self, inputs, max_rounds=None):
         inputs = inputs.tolist()
         out = []
-        for i in inputs:
-            if self.eos_id in i:
-                eos_idx = i.index(self.eos_id)
-                i = i[:eos_idx]
-            out.append(i)
+        for i, ids in enumerate(inputs):
+            count = 0
+            flag = False
+            for j, token in enumerate(ids):
+                if token == self.eos_id:
+                    if max_rounds is None:
+                        flag = True
+                        break
+                    elif isinstance(max_rounds, int):
+                        if count < max_rounds:
+                            count += 1
+                        else:
+                            flag = True
+                            break
+                    elif isinstance(max_rounds, list):
+                        if count < max_rounds[i]:
+                            count += 1
+                        else:
+                            flag = True
+                            break
+            if flag:
+                ids = ids[:j]
+            else:
+                ids = ids
+            out.append(ids)
         out = self.sp_model.Decode(out)
         return out
 
@@ -183,11 +203,11 @@ if __name__ == "__main__":
     for i, j in zip(tmp, out):
         assert normalize("NFKC", i) == j
 
-    from dataset.data_iter import create_shard_kwargs, create_data_iter
+    from dataset.data_iter import create_shard_kwargs, DataIter
 
     patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
     paths = create_shard_kwargs(patterns)
-    data_iter = create_data_iter(paths)
+    data_iter = DataIter(paths)
     for i, data in enumerate(data_iter):
         assert (
             normalize("NFKC", data["content"])
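The new max_rounds argument lets decode keep a chosen number of EOS-terminated rounds instead of always cutting at the first EOS, which matters for multi-turn transcripts where every round ends in EOS. A standalone mirror of the truncation rule:

```
def truncate(ids, eos_id, max_rounds=None):
    # Keep tokens up to the first EOS (max_rounds=None), or let max_rounds
    # EOS tokens pass through before cutting at the next one.
    count = 0
    for j, token in enumerate(ids):
        if token == eos_id:
            if max_rounds is None or count >= max_rounds:
                return ids[:j]
            count += 1
    return ids

ids = [5, 2, 6, 7, 2, 8]  # eos_id = 2
print(truncate(ids, 2))                # [5]
print(truncate(ids, 2, max_rounds=1))  # [5, 2, 6, 7]
```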
dataset/train_tokenizer.py

@@ -2,14 +2,14 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-24 20:49:03
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:43:59
+LastEditTime: 2023-04-05 22:40:29
 FilePath: /Open-Llama/dataset/train_tokenizer.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import random
-from dataset.data_iter import create_data_iter, create_shard_kwargs
+from dataset.data_iter import DataIter, create_shard_kwargs
 
 wudao_patterns = [
     "data/pretrain_data/part-wudao-*.jsonl.zst",
@@ -24,10 +24,10 @@ pile_paths = create_shard_kwargs(pile_patterns)
 random.shuffle(pile_paths)
 paths = wudao_paths[:5] + pile_paths[:10]
 transform_dict = {
-    "wudao": lambda line: [(line["title"] + "\n" + line["content"])],
-    "pile": lambda line: [line["text"]],
+    "wudao": lambda line: line["title"] + "\n" + line["content"],
+    "pile": lambda line: line["text"],
 }
-data_iter = create_data_iter(paths, transform_dict)
+data_iter = iter(DataIter(paths, transform_dict))
 
 import io
 import sentencepiece as spm
inctruction_tuning.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-30 21:35:01
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:40:03
+LastEditTime: 2023-04-06 03:35:31
 FilePath: /Open-Llama/inctruction_tuning.py
 Description:
 
@@ -16,18 +16,20 @@ import random
 import sentencepiece as spm
 from torchinfo import summary
 from accelerate import Accelerator
-from datasets import IterableDataset
 from torch.utils.data import DataLoader
 from deepspeed.ops.adam import FusedAdam
 from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
 
 from dataset.validation import val_set
 from dataset.tokenizer import Tokenizer
-from dataset.data_iter import create_shard_kwargs, create_data_iter
-from dataset.data_loader import pretrain_collate_fn_gen
+from dataset.data_iter import create_shard_kwargs, DataIter
+from dataset.collate_fn import collate_fn_gen
 from dataset.instruction_dataset import (
     preprocess_belle_gen,
     preprocess_self_instruction_gen,
+    preprocess_belle_multiturn_chat_gen,
+    preprocess_instruct_code_gen,
+    preprocess_sharegpt_gen,
 )
 from configs.instruction_tuning_config import *
 
@@ -46,25 +48,28 @@ tokenizer = Tokenizer(sp_model)
 paths = create_shard_kwargs(patterns, repeat=3)
 random.shuffle(paths)
 transform_dict = {
-    "belle_1M": preprocess_belle_gen(tokenizer, max_length),
-    "belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
-    "self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
+    "self_instruct": preprocess_self_instruction_gen(tokenizer),
+    "belle_1M": preprocess_belle_gen(tokenizer),
+    "belle_0.5M": preprocess_belle_gen(tokenizer),
+    "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+    "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+    "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+    "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
 }
-data_set = IterableDataset.from_generator(
-    create_data_iter,
-    gen_kwargs={
-        "paths": paths,
-        "transform_dict": transform_dict,
-        "process_index": accelerator.process_index,
-        "num_processes": accelerator.num_processes,
-    },
+data_set = DataIter(
+    paths,
+    transform_dict=transform_dict,
+    concat_docs=True,
+    max_length=max_length,
+    process_index=accelerator.process_index,
+    num_processes=accelerator.num_processes,
 )
 train_loader = DataLoader(
     data_set,
     batch_size=train_batch_size,
     # If num_workers is greater than 1, duplicate data may occur.
     num_workers=0,
-    collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
+    collate_fn=collate_fn_gen(tokenizer, max_length),
     drop_last=True,
 )
 # smaller initializer_range make training more stable
@ -2,7 +2,7 @@
|
||||||
Author: LiangSong(sl12160010@gmail.com)
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
Date: 2023-03-17 14:27:28
|
Date: 2023-03-17 14:27:28
|
||||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
LastEditTime: 2023-03-27 01:07:25
|
LastEditTime: 2023-04-05 22:46:31
|
||||||
FilePath: /Open-Llama/pretrain_llama.py
|
FilePath: /Open-Llama/pretrain_llama.py
|
||||||
Description:
|
Description:
|
||||||
pretrain GPT
|
pretrain GPT
|
||||||
|
@ -16,15 +16,14 @@ import random
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from datasets import IterableDataset
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from deepspeed.ops.adam import FusedAdam
|
from deepspeed.ops.adam import FusedAdam
|
||||||
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
|
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
|
||||||
|
|
||||||
from dataset.validation import val_set
|
from dataset.validation import val_set
|
||||||
from dataset.tokenizer import Tokenizer
|
from dataset.tokenizer import Tokenizer
|
||||||
from dataset.data_iter import create_shard_kwargs, create_data_iter
|
from dataset.data_iter import create_shard_kwargs, DataIter
|
||||||
from dataset.data_loader import pretrain_collate_fn_gen
|
from dataset.collate_fn import collate_fn_gen
|
||||||
from dataset.pretrain_dataset import (
|
from dataset.pretrain_dataset import (
|
||||||
preprocess_the_pile_gen,
|
preprocess_the_pile_gen,
|
||||||
preprocess_wudao_gen,
|
preprocess_wudao_gen,
|
||||||
|
@ -49,21 +48,20 @@ transform_dict = {
|
||||||
"wudao": preprocess_wudao_gen(tokenizer, max_length),
|
"wudao": preprocess_wudao_gen(tokenizer, max_length),
|
||||||
"pile": preprocess_the_pile_gen(tokenizer, max_length),
|
"pile": preprocess_the_pile_gen(tokenizer, max_length),
|
||||||
}
|
}
|
||||||
data_set = IterableDataset.from_generator(
|
data_set = DataIter(
|
||||||
create_data_iter,
|
paths,
|
||||||
gen_kwargs={
|
transform_dict=transform_dict,
|
||||||
"paths": paths,
|
concat_docs=True,
|
||||||
"transform_dict": transform_dict,
|
max_length=max_length,
|
||||||
"process_index": accelerator.process_index,
|
process_index=accelerator.process_index,
|
||||||
"num_processes": accelerator.num_processes,
|
num_processes=accelerator.num_processes,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
data_set,
|
data_set,
|
||||||
batch_size=train_batch_size,
|
batch_size=train_batch_size,
|
||||||
# If num_workers is greater than 1, duplicate data may occur.
|
# If num_workers is greater than 1, duplicate data may occur.
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
|
collate_fn=collate_fn_gen(tokenizer, max_length),
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
# smaller initializer_range make training more stable
|
# smaller initializer_range make training more stable
|
||||||
|
|
server.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-03-31 13:26:15
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-31 14:05:35
+LastEditTime: 2023-04-06 03:45:44
 FilePath: /Open-Llama/server.py
 Description:
 
@@ -32,7 +32,9 @@ raw_model = LlamaForCausalLM(
         shared_input_output_embedding=True,
     )
 )
-ckpt = torch.load("data/saved_ckpt/instruction_tuning_3_epochs/23001.pt", map_location="cpu")
+ckpt = torch.load(
+    "data/saved_ckpt/instruction_tuning_3_epochs/37001.pt", map_location="cpu"
+)
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
@@ -41,7 +43,7 @@ print("ready")
 
 def question_answer(prompt):
     print(prompt)
-    raw_inputs = "user:{}<s>system:".format(prompt)
+    raw_inputs = "user:{}\nsystem:".format(prompt)
     inputs_len = len(raw_inputs)
     inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
     for k, v in inputs.items():