diff --git a/README.md b/README.md
index f0a1fa8..1a67fe9 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
* @Author: LiangSong(sl12160010@gmail.com)
* @Date: 2023-03-10 21:18:35
* @LastEditors: LiangSong(sl12160010@gmail.com)
- * @LastEditTime: 2023-04-02 21:32:26
+ * @LastEditTime: 2023-04-07 23:19:21
* @FilePath: /Open-Llama/README.md
* @Description:
*
@@ -16,7 +16,8 @@ Open-Llama是一个开源项目,提供了一整套用于构建大型语言模
## Progress
-Although the full pre-training is not yet complete, we used the model pre-trained for 40K steps for Instruction-tuning; the model can follow simple commands, but has no multi-turn dialogue ability yet.
+We have completed pre-training on 300B tokens, 80K steps in total, with a Global Batch Size of 4M, the same as in Llama.
+The Instruction-tuning data are built from seven sources in total; the model shows some programming, math, and multi-turn dialogue ability. See the Instruction-Tuning section for details on the data.
[Demo](http://home.ustc.edu.cn/~sl9292/)
@@ -25,6 +26,9 @@ Open-Llama是一个开源项目,提供了一整套用于构建大型语言模
The model's results are shown in the figure below; more results are still pending further testing. Due to network issues within China, requests to the Demo above may be dropped; if there is no response for a long time, refresh and retry.

+Below is a demonstration of the model's multi-turn dialogue ability on code.
+
+
Here is a rough estimate of the cost of reaching the results above. Pre-training for 40K steps used 150 million pre-training samples, about 110B tokens, and took 76 hours in total; at Google Cloud's A100 pricing this comes to roughly $19,152. The subsequent Instruction-tuning ran for 12K steps on 1.6M samples, took 3.4 hours, and cost about $342. Training such a model from scratch therefore costs less than $20,000 in total.
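
As a quick arithmetic check on those totals (using only the figures quoted above):

```python
# Cost figures quoted above, in USD.
pretrain_cost = 19152  # 40K steps, ~110B tokens, 76 h at Google Cloud A100 pricing
instruct_cost = 342    # 12K steps, 1.6M samples, 3.4 h
print(pretrain_cost + instruct_cost)  # 19494, i.e. under the quoted $20,000
```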
The model currently performs noticeably worse on math and code. This is partly due to the training data and partly, I believe, to the model size; logical reasoning of this kind is essential for a usable model, so future updates will focus on improving these capabilities.
@@ -166,12 +170,17 @@ Total mult-adds (G): 6.89
We use the currently open-source datasets listed below for Instruction-tuning; more tasks, as well as datasets of our own construction, will be added later.
- [yizhongw/self_instruct](https://huggingface.co/datasets/yizhongw/self_instruct)
-- [BelleGroup/generated_train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_0.5M_CN)
-- [BelleGroup/generated_train_1M_CN](https://huggingface.co/datasets/BelleGroup/generated_train_1M_CN)
+- [BelleGroup/train_0.5M_CN](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+- [BelleGroup/train_1M_CN](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+- [BelleGroup/multiturn_chat_0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+- [BelleGroup/school_math_0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+- [RyokoAI/ShareGPT52K](https://huggingface.co/datasets/RyokoAI/ShareGPT52K)
+- [Graverman/Instruct-to-Code](https://huggingface.co/datasets/Graverman/Instruct-to-Code)
+There are some issues with how the ShareGPT52K data is processed by the `datasets` library, so we download the raw data and reprocess it ourselves (see data/download_instruct.sh).
We apply some preprocessing to the raw data; the resulting format is as follows:
```
-user: {prompt}system: {completion}
+user: {prompt}\nsystem: {completion}
```
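
For illustration, here is a minimal sketch of how one raw sample is mapped into this format and split into fixed-length chunks (the helper name is hypothetical; the `instruction`/`output` keys follow the BelleGroup-style preprocessing in dataset/instruction_dataset.py):

```python
import math

def format_instruction_sample(line, tokenizer, segment_max_length=1024):
    # Concatenate prompt and completion with the "user:...\nsystem:..." template.
    total = "user:{}\nsystem:{}".format(
        line["instruction"].strip(), line["output"].strip()
    )
    input_ids = tokenizer(total)["input_ids"]
    # Split long samples into chunks of at most segment_max_length tokens.
    return [
        input_ids[i * segment_max_length : (i + 1) * segment_max_length]
        for i in range(math.ceil(len(input_ids) / segment_max_length))
    ]
```

Multi-turn sources (BELLE multiturn chat, ShareGPT) concatenate several such user/system rounds before segmenting.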
The training code is basically the same as for pre-training; the code can be found at
@@ -195,7 +204,12 @@ accelerate launch --config_file configs/default_config.yaml instruction_tuning.p
The loss during training is shown below; it mostly fluctuates without decreasing much.

### RLHF
+Not yet available.
+### Server
+Use server.py for single-turn dialogue and chat_server.py for multi-turn dialogue.
+
+Both are built on Gradio.
## Performance Comparison
### Training Framework
diff --git a/assets/multiturn_chat.jpeg b/assets/multiturn_chat.jpeg
new file mode 100644
index 0000000..7655050
Binary files /dev/null and b/assets/multiturn_chat.jpeg differ
diff --git a/chat_server.py b/chat_server.py
new file mode 100644
index 0000000..1b0f043
--- /dev/null
+++ b/chat_server.py
@@ -0,0 +1,115 @@
+"""
+Author: LiangSong(sl12160010@gmail.com)
+Date: 2023-04-06 22:30:10
+LastEditors: LiangSong(sl12160010@gmail.com)
+LastEditTime: 2023-04-07 23:03:31
+FilePath: /Open-Llama/chat_server.py
+Description:
+
+Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
+"""
+import torch
+import gradio as gr
+import sentencepiece as spm
+from dataset.tokenizer import Tokenizer
+from transformers import LlamaForCausalLM, LlamaConfig
+
+
+sp_model = spm.SentencePieceProcessor(
+ model_file="configs/10w_vocab_wudao5_pile10.model"
+)
+tokenizer = Tokenizer(sp_model)
+raw_model = LlamaForCausalLM(
+ LlamaConfig(
+ vocab_size=tokenizer.vocab_size,
+ initializer_range=0.01,
+ pad_token_id=tokenizer.pad_id,
+ rms_norm_eps=1e-5,
+ hidden_dropout_prob=0.1,
+ attention_dropout_prob=0.1,
+ use_stable_embedding=True,
+ shared_input_output_embedding=True,
+ )
+)
+ckpt = torch.load(
+ "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
+ map_location="cpu",
+)
+raw_model.load_state_dict(ckpt)
+raw_model.eval()
+model = raw_model.cuda()
+print("ready")
+
+
+def parse_codeblock(text):
+    # Convert markdown ``` fences in the model output into HTML <pre><code>
+    # blocks so Gradio's Chatbot renders them as code.
+    lines = text.split("\n")
+    for i, line in enumerate(lines):
+        if "```" in line:
+            if line != "```":
+                # Opening fence with a language tag, e.g. ```python
+                lines[i] = f'<pre><code class="{lines[i][3:]}">'
+            else:
+                lines[i] = "</code></pre>"
+        else:
+            if i > 0:
+                # Escape raw angle brackets and keep line breaks.
+                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
+    return "".join(lines)
+
+
+with gr.Blocks() as demo:
+ gr.Markdown(
+ """
+ # [Open-Llama](https://github.com/Bayes-Song/Open-Llama)
+ 完全使用Open-Llama项目从0开始训练的Instruct-GPT模型,当长时间无响应(如20s以上)可刷新重试。
+
+ Instruct-GPT model is trained from scratch using the Open-Llama project without relying on any other pre-trained models. If there is no response for a long time (such as more than 20 seconds), please refresh and try again.
+ """
+ )
+ chatbot = gr.Chatbot()
+ msg = gr.Textbox()
+ clear = gr.Button("Clear")
+
+ def user(user_message, history):
+ print(user_message)
+ return "", history + [[user_message, None]]
+
+ def bot(history):
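+        # Rebuild the model input from the whole dialogue history: finished
+        # rounds are encoded as "user:{prompt}\nsystem:{completion}" (with
+        # special tokens), the current round as "user:{prompt}\nsystem:".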
+ context = []
+        for prompt, completion in history:
+ if completion is None:
+ inputs = "user:{}\nsystem:".format(prompt)
+ inputs = tokenizer(
+ inputs, return_tensors=True, add_special_tokens=False
+ )
+ context.append(inputs["input_ids"])
+ else:
+ inputs = "user:{}\nsystem:{}".format(prompt, completion)
+ inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
+ context.append(inputs["input_ids"])
+ context = torch.cat(context, dim=-1)
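+        # Keep only the most recent 1024 tokens so the prompt fits the
+        # model's context window.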
+ context = context[:, -1024:]
+ inputs_len = context.shape[1]
+ context = context.cuda()
+ pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
+ pred = pred[:, inputs_len:]
+ pred = tokenizer.decode(pred.cpu())[0]
+ print(pred)
+ bot_message = parse_codeblock(pred)
+ history[-1][1] = bot_message
+ return history
+
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+ bot, chatbot, chatbot
+ )
+ clear.click(lambda: None, None, chatbot, queue=False)
+ gr.Markdown(
+ """
+    All content generated by this demo is produced by an AI model. We make no guarantees about the accuracy, completeness, or functionality of the generated content, and it does not represent our attitudes or opinions.
+
+ 联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我.
+ Contact information: sl12160010@gmail.com. Any opinions or suggestions regarding the project are welcome to be addressed to me through this email.
+ """
+ )
+
+demo.launch(share=True)
diff --git a/configs/instruction_tuning_config.py b/configs/instruction_tuning_config.py
index 54a2eb2..d9684b2 100644
--- a/configs/instruction_tuning_config.py
+++ b/configs/instruction_tuning_config.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:38:07
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:39:40
+LastEditTime: 2023-04-06 03:37:23
FilePath: /Open-Llama/configs/instruction_tuning_config.py
Description:
@@ -10,7 +10,7 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
max_length = 1024
train_batch_size = 2
-num_training_steps = 37500
+num_training_steps = 40000
num_warmup_steps = 100
initializer_range = 1e-2
lr = 2e-4
@@ -22,4 +22,4 @@ log_interval = 50
eval_interval = 500
save_interval = 1000
work_dir = "data/saved_ckpt/"
-ckpt_path = "data/saved_ckpt/40000.pt"
+ckpt_path = "data/saved_ckpt/83200.pt"
diff --git a/data/download_instruct.sh b/data/download_instruct.sh
new file mode 100644
index 0000000..b916bd6
--- /dev/null
+++ b/data/download_instruct.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+###
+ # @Author: LiangSong(sl12160010@gmail.com)
+ # @Date: 2023-04-05 23:18:10
+ # @LastEditors: LiangSong(sl12160010@gmail.com)
+ # @LastEditTime: 2023-04-05 23:34:30
+ # @FilePath: /Open-Llama/data/download_instruct.sh
+ # @Description:
+ #
+ # Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
+###
+mkdir -p data/instruction_data
+curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part1.json' -o data/sg_90k_part1.json
+curl -C - --retry 3 'https://huggingface.co/datasets/RyokoAI/ShareGPT52K/resolve/main/sg_90k_part2.json' -o data/sg_90k_part2.json
+python3 data/preprocess_instruction.py
\ No newline at end of file
diff --git a/data/preprocess_instruction.py b/data/preprocess_instruction.py
index 16cf8d1..6f6a2c4 100644
--- a/data/preprocess_instruction.py
+++ b/data/preprocess_instruction.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 20:52:10
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 20:52:12
+LastEditTime: 2023-04-05 23:51:16
FilePath: /Open-Llama/data/preprocess_instruction.py
Description:
@@ -12,8 +12,11 @@ import json
import zstandard as zstd
from datasets import load_dataset
+
+root_dir = "data"
+
dataset = load_dataset("yizhongw/self_instruct")
-write_path = "data/instruction_data/part-self_instruct-{}.jsonl.zst"
+write_path = root_dir + "/instruction_data/part-self_instruct-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -27,9 +30,14 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
+print(
+ "yizhongw/self_instruct preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
-dataset = load_dataset("BelleGroup/generated_train_0.5M_CN")
-write_path = "data/instruction_data/part-belle_0.5M-{}.jsonl.zst"
+dataset = load_dataset("BelleGroup/train_0.5M_CN")
+write_path = root_dir + "/instruction_data/part-belle_0.5M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -43,9 +51,14 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
+print(
+ "BelleGroup/train_0.5M_CN preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
-dataset = load_dataset("BelleGroup/generated_train_1M_CN")
-write_path = "data/instruction_data/part-belle_1M-{}.jsonl.zst"
+dataset = load_dataset("BelleGroup/train_1M_CN")
+write_path = root_dir + "/instruction_data/part-belle_1M-{}.jsonl.zst"
total_num = 0
file_num = 0
wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
@@ -59,3 +72,96 @@ for line in dataset["train"]:
wfp.write(b"\n")
total_num += 1
wfp.close()
+print(
+ "BelleGroup/train_1M_CN preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
+
+dataset = load_dataset("BelleGroup/school_math_0.25M")
+write_path = root_dir + "/instruction_data/part-belle_school_math_0.25M-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
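+# Write 1024 samples per shard, rotating to a new zstd-compressed file.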
+for line in dataset["train"]:
+ line = json.dumps(line)
+ if total_num % 1024 == 0 and total_num > 0:
+ file_num += 1
+ wfp.close()
+ wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+ wfp.write(line.encode("utf-8"))
+ wfp.write(b"\n")
+ total_num += 1
+wfp.close()
+print(
+ "BelleGroup/school_math_0.25M preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
+
+dataset = load_dataset("BelleGroup/multiturn_chat_0.8M")
+write_path = root_dir + "/instruction_data/part-belle_multiturn_chat_0.8M-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+for line in dataset["train"]:
+ line = json.dumps(line)
+ if total_num % 1024 == 0 and total_num > 0:
+ file_num += 1
+ wfp.close()
+ wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+ wfp.write(line.encode("utf-8"))
+ wfp.write(b"\n")
+ total_num += 1
+wfp.close()
+print(
+ "BelleGroup/multiturn_chat_0.8M preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
+
+dataset = load_dataset("Graverman/Instruct-to-Code")
+write_path = root_dir + "/instruction_data/part-instruct_to_code-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+for line in dataset["train"]:
+ line = json.dumps(line)
+ if total_num % 1024 == 0 and total_num > 0:
+ file_num += 1
+ wfp.close()
+ wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+ wfp.write(line.encode("utf-8"))
+ wfp.write(b"\n")
+ total_num += 1
+wfp.close()
+print(
+ "Graverman/Instruct-to-Code preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
+
+write_path = root_dir + "/instruction_data/part-sharegpt_90K-{}.jsonl.zst"
+total_num = 0
+file_num = 0
+wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+with open("data/sg_90k_part1.json", "r") as fp:
+ data1 = json.load(fp)
+with open("data/sg_90k_part2.json", "r") as fp:
+ data2 = json.load(fp)
+data = data1 + data2
+for line in data:
+ line = json.dumps(line)
+ if total_num % 1024 == 0 and total_num > 0:
+ file_num += 1
+ wfp.close()
+ wfp = zstd.open(write_path.format(file_num), "wb", encoding="utf-8")
+ wfp.write(line.encode("utf-8"))
+ wfp.write(b"\n")
+ total_num += 1
+wfp.close()
+print(
+ "RyokoAI/ShareGPT52K preprocess done. Total line: {}, Total file: {}".format(
+ total_num, file_num
+ )
+)
diff --git a/dataset/collate_fn.py b/dataset/collate_fn.py
new file mode 100644
index 0000000..8e9648b
--- /dev/null
+++ b/dataset/collate_fn.py
@@ -0,0 +1,69 @@
+"""
+Author: LiangSong(sl12160010@gmail.com)
+Date: 2023-03-30 20:58:16
+LastEditors: LiangSong(sl12160010@gmail.com)
+LastEditTime: 2023-04-05 22:11:03
+FilePath: /Open-Llama/dataset/collate_fn.py
+Description:
+
+Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
+"""
+import torch
+
+
+def collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
+ """
+ Organize data into tensors by padding based on the preset maximum length.
+ """
+ pad_id = tokenizer.pad_id
+
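+    # padding="longest" pads each batch to its longest sequence;
+    # padding="max_length" always pads to segment_max_length tokens.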
+ def collate_fn(batch):
+ if padding == "longest":
+ max_length = max([len(i) for i in batch])
+ elif padding == "max_length":
+ max_length = segment_max_length
+ else:
+            raise Exception("Invalid argument for padding: {}".format(padding))
+ input_ids = []
+ for i in batch:
+ input_len = len(i)
+ input_ids.append(i + [pad_id] * (max_length - input_len))
+ inputs = {
+ "input_ids": torch.tensor(input_ids, dtype=torch.int64),
+ }
+ return inputs
+
+ return collate_fn
+
+
+if __name__ == "__main__":
+ import sentencepiece as spm
+ from torch.utils.data import DataLoader
+
+ from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
+
+ from dataset.tokenizer import Tokenizer
+ from dataset.data_iter import create_shard_kwargs, DataIter
+
+ sp_model = spm.SentencePieceProcessor(
+ model_file="configs/10w_vocab_wudao5_pile10.model"
+ )
+ tokenizer = Tokenizer(sp_model)
+ patterns = ["data/pretrain_data/part-*.jsonl.zst"]
+ paths = create_shard_kwargs(patterns)
+ transform_dict = {
+ "wudao": preprocess_wudao_gen(tokenizer),
+ "pile": preprocess_the_pile_gen(tokenizer),
+ }
+ data_set = DataIter(paths, transform_dict=transform_dict)
+ train_loader = DataLoader(
+ data_set,
+ batch_size=8,
+ num_workers=4,
+ collate_fn=collate_fn_gen(tokenizer),
+ drop_last=True,
+ )
+ for batch in train_loader:
+ for k, v in batch.items():
+ print(k, v.shape)
+ break
diff --git a/dataset/data_iter.py b/dataset/data_iter.py
index de83cb5..48e21bc 100644
--- a/dataset/data_iter.py
+++ b/dataset/data_iter.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 19:32:20
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:03:32
+LastEditTime: 2023-04-06 03:37:55
FilePath: /Open-Llama/dataset/data_iter.py
Description:
@@ -11,67 +11,84 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
import json
from glob import glob
import zstandard as zstd
+from torch.utils.data import IterableDataset
-def create_data_iter(paths, transform_dict=None, process_index=0, num_processes=1):
+class DataIter(IterableDataset):
"""
- Currently, the allowed storage formats are jsonl and jsonl.zst.
+    Currently, the only allowed storage format is jsonl.zst.
Each line of the data is a dictionary, which can be parsed as JSON for subsequent processing after reading.
+    Currently, only a single worker is supported.
"""
- past = None
- for i, path in paths:
- dataset_name = path.split("-")[-2]
- if num_processes > 1 and i % num_processes != process_index:
- continue
- if past != dataset_name:
- print("Loading data from {}".format(path))
- past = path
- if path.endswith("jsonl.zst"):
+
+ def __init__(
+ self,
+ paths_with_index,
+ transform_dict=None,
+ max_length=None,
+ concat_docs=False,
+ process_index=0,
+ num_processes=1,
+ ):
+ super().__init__()
+ self.paths_with_index = paths_with_index
+ self.max_length = max_length
+ self.transform_dict = transform_dict
+ self.concat_docs = concat_docs
+ self.process_index = process_index
+ self.num_processes = num_processes
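+        # With concat_docs, token ids from consecutive documents are packed
+        # into a cache and yielded as fixed-length segments of max_length.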
+ if self.concat_docs:
+ self.cache = []
+
+ def __iter__(self):
+ past = None
+ for i, path in self.paths_with_index:
+ # part-dataset_name-01.jsonl.zst
+ dataset_name = path.split("-")[-2]
+            # Shard files across multiple processes/devices.
+ if self.num_processes > 1 and i % self.num_processes != self.process_index:
+ continue
+ # Log the file name when encountering a new file.
+ if past != dataset_name:
+ print("Loading data from {}".format(path))
+ past = path
+            # Currently, the only allowed storage format is jsonl.zst.
+ assert path.endswith("jsonl.zst")
with zstd.open(path, "r", encoding="utf-8") as fp:
for line in fp:
+                    # Once the cache holds at least max_length tokens, yield one segment.
+ if self.concat_docs and len(self.cache) >= self.max_length:
+ seq = self.cache[: self.max_length]
+ self.cache = self.cache[self.max_length :]
+ yield seq
if isinstance(line, bytes):
line = line.decode("utf-8")
line = json.loads(line)
line["dataset"] = dataset_name
- if transform_dict:
- line = transform_dict[dataset_name](line)
- if isinstance(line, str):
+ # Transformation, including sample, tokenize, etc.
+ if self.transform_dict:
+ line = self.transform_dict[dataset_name](line)
+ # skip bad doc
+ if line is None:
+ continue
+ elif isinstance(line, str):
yield line
- elif isinstance(line, list):
- for i in line:
- yield i
+                    # Must be a list of lists (token-id segments).
+ elif isinstance(line, list) and isinstance(line[0], list):
+ for seq in line:
+ if self.concat_docs:
+ # concat seq from multiple docs
+ self.cache += seq
+ else:
+ yield seq
else:
raise Exception(
"Unsupported type in Transformation: {}".format(
- transform_dict[dataset_name]
+ self.transform_dict[dataset_name]
)
)
else:
yield line
- elif path.endswith("jsonl"):
- with open(path, "r") as fp:
- for line in fp:
- if isinstance(line, bytes):
- line = line.decode("utf-8")
- line = json.loads(line)
- line["dataset"] = dataset_name
- if transform_dict:
- line = transform_dict[dataset_name](line)
- if isinstance(line, str):
- yield line
- elif isinstance(line, list):
- for i in line:
- yield i
- else:
- raise Exception(
- "Unsupported type in Transformation: {}".format(
- transform_dict[dataset_name]
- )
- )
- else:
- yield line
- else:
- raise Exception("File format of {} is not supported yet.".format(path))
def create_shard_kwargs(patterns, repeat=1):
@@ -90,7 +107,9 @@ if __name__ == "__main__":
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {"wudao": lambda x: x["title"], "pile": lambda x: [x["text"]]}
- data_iter = create_data_iter(paths, transform_dict=transform_dict)
+ data_iter = DataIter(
+ paths, transform_dict=transform_dict, max_length=16, concat_docs=True
+ )
for i, data in enumerate(data_iter):
print(i, data)
if i == 20:
diff --git a/dataset/data_loader.py b/dataset/data_loader.py
deleted file mode 100644
index 69f0fba..0000000
--- a/dataset/data_loader.py
+++ /dev/null
@@ -1,192 +0,0 @@
-"""
-Author: LiangSong(sl12160010@gmail.com)
-Date: 2023-03-30 20:58:16
-LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:00:49
-FilePath: /Open-Llama/dataset/data_loader.py
-Description:
-
-Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
-"""
-import math
-import torch
-
-
-def pretrain_collate_fn_gen(tokenizer, segment_max_length=1024, padding="longest"):
- """
- Organize data into tensors by padding based on the preset maximum length.
- """
- pad_id = tokenizer.pad_id
-
- def pretrain_collate_fn(batch):
- if padding == "longest":
- max_length = max([len(i) for i in batch])
- elif padding == "max_length":
- max_length = segment_max_length
- else:
- raise Exception("Invalid argumet for padding: {}".format(padding))
- input_ids = []
- for i in batch:
- input_len = len(i)
- input_ids.append(i + [pad_id] * (max_length - input_len))
- inputs = {
- "input_ids": torch.tensor(input_ids, dtype=torch.int64),
- }
- return inputs
-
- return pretrain_collate_fn
-
-
-class BySequenceLengthDataset(torch.utils.data.IterableDataset):
- """
- experimental
- """
-
- def __init__(
- self, generator, batch_size, accelerator=None, bucket_size=16, max_length=1024
- ):
- super().__init__()
- self.generator = generator
- self.batch_size = batch_size
- self.bucket_size = bucket_size
- self.bucket_num = math.ceil(max_length / bucket_size)
- self.buckets = [[] for _ in range(self.bucket_num)]
- self.bucket_idx = None
- self.accelerator = accelerator
- if self.accelerator is not None:
- self.buckets_ele_num = torch.tensor(
- [0] * self.bucket_num, dtype=torch.int64, device=accelerator.device
- )
- self.buckets_indexes = torch.arange(
- self.bucket_num, device=accelerator.device
- )
- self.finished = False
- self.has_no_same_bucket = False
- self.rest = None
-
- def __iter__(self):
- if self.batch_size <= 1:
- return self.generator
-
- def bucket_iter():
- while True:
- if self.bucket_idx is not None:
- sample = self.buckets[self.bucket_idx].pop()
- if len(self.buckets[self.bucket_idx]) == 0:
- self.bucket_idx = None
- yield sample
- try:
- sample = next(self.generator)
- except StopIteration:
- break
- sample_len = len(sample) - 1
- bucket_idx = sample_len // self.bucket_size
- if len(self.buckets[bucket_idx]) == self.batch_size - 1:
- self.bucket_idx = bucket_idx
- yield sample
- else:
- self.buckets[bucket_idx].append(sample)
-
- def parallel_bucket_iter():
- while True:
- if self.bucket_idx is not None:
- sample = self.buckets[self.bucket_idx].pop()
- self.buckets_ele_num[self.bucket_idx] -= 1
- buckets_ele_num = self.accelerator.gather(self.buckets_ele_num)
- buckets_ele_num = buckets_ele_num.reshape(
- self.accelerator.num_processes, self.bucket_num
- )
- min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
- if min_buckets_ele_num[self.bucket_idx] <= 0:
- self.bucket_idx = None
- yield sample
- else:
- if self.finished:
- if self.has_no_same_bucket:
- if self.rest is None:
- self.rest = []
- for bucket in self.buckets:
- for i in bucket:
- self.rest.append(i)
- elif len(self.rest) > 0:
- yield self.rest.pop()
- else:
- raise StopIteration()
- else:
- buckets_ele_num = self.accelerator.gather(
- self.buckets_ele_num
- )
- buckets_ele_num = buckets_ele_num.view(
- self.accelerator.num_processes, self.bucket_num
- )
- min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
- valid_bucket_idx = self.buckets_indexes[
- min_buckets_ele_num >= self.batch_size
- ]
- if len(valid_bucket_idx) > 0:
- self.bucket_idx = valid_bucket_idx[0].cpu().item()
- else:
- self.has_no_same_bucket = True
- else:
- try:
- sample = next(self.generator)
- except StopIteration:
- self.finished = True
- continue
- sample_len = len(sample) - 1
- bucket_idx = sample_len // self.bucket_size
- self.buckets[bucket_idx].append(sample)
- self.buckets_ele_num[bucket_idx] += 1
- buckets_ele_num = self.accelerator.gather(
- self.buckets_ele_num
- ).cpu()
- buckets_ele_num = buckets_ele_num.view(
- self.accelerator.num_processes, self.bucket_num
- )
- min_buckets_ele_num = buckets_ele_num.min(dim=0)[0]
- valid_bucket_idx = self.buckets_indexes[
- min_buckets_ele_num >= self.batch_size
- ]
- if len(valid_bucket_idx) > 0:
- self.bucket_idx = valid_bucket_idx[0].cpu().item()
-
- if self.accelerator:
- return parallel_bucket_iter()
- else:
- return bucket_iter()
-
-
-if __name__ == "__main__":
- import sentencepiece as spm
- from datasets import IterableDataset
- from torch.utils.data import DataLoader
-
- from dataset.pretrain_dataset import preprocess_wudao_gen, preprocess_the_pile_gen
-
- from dataset.tokenizer import Tokenizer
- from dataset.data_iter import create_shard_kwargs, create_data_iter
-
- sp_model = spm.SentencePieceProcessor(
- model_file="configs/10w_vocab_wudao5_pile10.model"
- )
- tokenizer = Tokenizer(sp_model)
- patterns = ["data/pretrain_data/part-*.jsonl.zst"]
- paths = create_shard_kwargs(patterns)
- transform_dict = {
- "wudao": preprocess_wudao_gen(tokenizer),
- "pile": preprocess_the_pile_gen(tokenizer),
- }
- data_set = IterableDataset.from_generator(
- create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
- )
- train_loader = DataLoader(
- data_set,
- batch_size=8,
- num_workers=4,
- collate_fn=pretrain_collate_fn_gen(tokenizer),
- drop_last=True,
- )
- for batch in train_loader:
- for k, v in batch.items():
- print(k, v.shape)
- break
diff --git a/dataset/instruction_dataset.py b/dataset/instruction_dataset.py
index 3b37aaa..9262023 100644
--- a/dataset/instruction_dataset.py
+++ b/dataset/instruction_dataset.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:02:00
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:02:06
+LastEditTime: 2023-04-06 03:33:27
FilePath: /Open-Llama/dataset/instruction_dataset.py
Description:
@@ -21,7 +21,7 @@ def preprocess_self_instruction_gen(tokenizer, segment_max_length=1024):
prompt = line["prompt"]
if prompt.endswith("Output:"):
prompt = prompt[:-7]
- total = "user:{}system:{}".format(prompt.strip(), line["completion"].strip())
+ total = "user:{}\nsystem:{}".format(prompt.strip(), line["completion"].strip())
out = tokenizer(total)
input_ids = out["input_ids"]
return [
@@ -39,12 +39,12 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
{'text': 'some text', 'meta': {'pile_set_name': 'Github'}}
Split the data based on the tokenized length according to the maximum length.
"""
- prompt = line["input"].replace("\\n", "")
+ prompt = line["instruction"].replace("\\n", "")
    prompt = prompt.strip()
- completion = line["target"].replace("\\n", "")
+ completion = line["output"].replace("\\n", "")
    completion = completion.strip()
- total = "user:{}system:{}".format(prompt, completion)
+ total = "user:{}\nsystem:{}".format(prompt, completion)
out = tokenizer(total)
input_ids = out["input_ids"]
return [
@@ -55,28 +55,124 @@ def preprocess_belle_gen(tokenizer, segment_max_length=1024):
return preprocess_belle
+def preprocess_belle_multiturn_chat_gen(tokenizer, segment_max_length=1024):
+ def preprocess_belle_multiturn_chat(line):
+ """
+        The format of the data is roughly as follows:
+        {'instruction': 'Human: ... Assistant: ... Human: ...', 'output': '...'}
+        Rounds are recovered by splitting on "Human:"/"Assistant:".
+ Split the data based on the tokenized length according to the maximum length.
+ """
+ prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip()
+
+        completion = line["output"].replace("\\n", "")
+        completion = completion.strip()
+ chats = prompt + completion
+ chats = chats.split("Human:")
+ input_ids = []
+ for chat in chats:
+ if chat.strip() == "":
+ continue
+ res = chat.split("Assistant:")
+ if len(res) != 2:
+ continue
+ prompt, completion = res
+ prompt = prompt.strip()
+ completion = completion.strip()
+ chat = "user:{}\nsystem:{}".format(prompt, completion)
+ out = tokenizer(chat)
+ input_ids.extend(out["input_ids"])
+ if len(input_ids) == 0:
+ return None
+ return [
+ input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+ for i in range(math.ceil(len(input_ids) / segment_max_length))
+ ]
+
+ return preprocess_belle_multiturn_chat
+
+
+def preprocess_sharegpt_gen(tokenizer, segment_max_length=1024):
+ def preprocess_sharegpt(line):
+ """
+        The format of the data is roughly as follows:
+        {'conversations': [{'from': 'human', 'value': '...'},
+                           {'from': 'gpt', 'value': '...'}, ...]}
+ Split the data based on the tokenized length according to the maximum length.
+ """
+ chats = line["conversations"]
+ if chats[0]["from"] != "human":
+ chats = chats[1:]
+ input_ids = []
+ for i in range(len(chats) // 2):
+ prompt = chats[2 * i]
+ completion = chats[2 * i + 1]
+ if not (prompt["from"] == "human" and completion["from"] == "gpt"):
+ continue
+ prompt = prompt["value"]
+ prompt = prompt.strip()
+ completion = completion["value"]
+ completion = completion.strip()
+ chat = "user:{}\nsystem:{}".format(prompt, completion)
+ out = tokenizer(chat)
+ input_ids.extend(out["input_ids"])
+ if input_ids == []:
+ return None
+ return [
+ input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+ for i in range(math.ceil(len(input_ids) / segment_max_length))
+ ]
+
+ return preprocess_sharegpt
+
+
+def preprocess_instruct_code_gen(tokenizer, segment_max_length=1024):
+ def preprocess_instruct_code(line):
+ """
+        The format of the data is roughly as follows:
+        {'instruction': 'some prompt', 'answer': 'some completion'}
+ Split the data based on the tokenized length according to the maximum length.
+ """
+ prompt = line["instruction"].replace("\\n", "")
+        prompt = prompt.strip()
+
+        completion = line["answer"].replace("\\n", "")
+        completion = completion.strip()
+ total = "user:{}\nsystem:{}".format(prompt, completion)
+ out = tokenizer(total)
+ input_ids = out["input_ids"]
+ return [
+ input_ids[i * segment_max_length : (i + 1) * segment_max_length]
+ for i in range(math.ceil(len(input_ids) / segment_max_length))
+ ]
+
+ return preprocess_instruct_code
+
+
if __name__ == "__main__":
import sentencepiece as spm
- from datasets import IterableDataset
from dataset.tokenizer import Tokenizer
- from dataset.data_iter import create_shard_kwargs, create_data_iter
+ from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
)
tokenizer = Tokenizer(sp_model)
- patterns = ["data/instruction_data/part-belle_1M*.jsonl.zst"]
+ patterns = ["data/instruction_data/part-belle_multiturn_chat_0.8M-*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
transform_dict = {
+ "self_instruct": preprocess_self_instruction_gen(tokenizer),
"belle_1M": preprocess_belle_gen(tokenizer),
"belle_0.5M": preprocess_belle_gen(tokenizer),
- "self_instruct": preprocess_self_instruction_gen(tokenizer),
+ "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+ "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+ "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+ "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
- data_set = IterableDataset.from_generator(
- create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
+ data_set = DataIter(
+ paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
)
for i, sample in enumerate(data_set):
- print(sample, sp_model.Decode(sample))
- if i == 20:
+ print(sp_model.decode(sample))
+ if i == 1:
break
diff --git a/dataset/pretrain_dataset.py b/dataset/pretrain_dataset.py
index 1b1885c..812d9ba 100644
--- a/dataset/pretrain_dataset.py
+++ b/dataset/pretrain_dataset.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 20:41:25
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:07:56
+LastEditTime: 2023-04-05 22:32:39
FilePath: /Open-Llama/dataset/pretrain_dataset.py
Description:
@@ -49,10 +49,9 @@ def preprocess_the_pile_gen(tokenizer, segment_max_length=1024):
if __name__ == "__main__":
import sentencepiece as spm
- from datasets import IterableDataset
from dataset.tokenizer import Tokenizer
- from dataset.data_iter import create_shard_kwargs, create_data_iter
+ from dataset.data_iter import create_shard_kwargs, DataIter
sp_model = spm.SentencePieceProcessor(
model_file="configs/10w_vocab_wudao5_pile10.model"
@@ -64,8 +63,8 @@ if __name__ == "__main__":
"wudao": preprocess_wudao_gen(tokenizer),
"pile": preprocess_the_pile_gen(tokenizer),
}
- data_set = IterableDataset.from_generator(
- create_data_iter, gen_kwargs={"paths": paths, "transform_dict": transform_dict}
+ data_set = DataIter(
+ paths, transform_dict=transform_dict, concat_docs=True, max_length=1024
)
for sample in data_set:
print(sample)
diff --git a/dataset/tokenizer.py b/dataset/tokenizer.py
index 044a973..4a11aab 100644
--- a/dataset/tokenizer.py
+++ b/dataset/tokenizer.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:09:39
+LastEditTime: 2023-04-06 23:01:50
FilePath: /Open-Llama/dataset/tokenizer.py
Description:
@@ -145,14 +145,34 @@ class Tokenizer:
out["attention_mask"] = attention_mask
return out
- def decode(self, inputs):
+ def decode(self, inputs, max_rounds=None):
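+        # Truncate each sequence at an EOS token. For multi-turn outputs,
+        # max_rounds (an int, or a per-sample list of ints) keeps that many
+        # EOS-delimited rounds before truncating; None truncates at the first EOS.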
inputs = inputs.tolist()
out = []
- for i in inputs:
- if self.eos_id in i:
- eos_idx = i.index(self.eos_id)
- i = i[:eos_idx]
- out.append(i)
+ for i, ids in enumerate(inputs):
+ count = 0
+ flag = False
+ for j, token in enumerate(ids):
+ if token == self.eos_id:
+ if max_rounds is None:
+ flag = True
+ break
+ elif isinstance(max_rounds, int):
+ if count < max_rounds:
+ count += 1
+ else:
+ flag = True
+ break
+ elif isinstance(max_rounds, list):
+ if count < max_rounds[i]:
+ count += 1
+ else:
+ flag = True
+ break
+            if flag:
+                ids = ids[:j]
+            out.append(ids)
out = self.sp_model.Decode(out)
return out
@@ -183,11 +203,11 @@ if __name__ == "__main__":
for i, j in zip(tmp, out):
assert normalize("NFKC", i) == j
- from dataset.data_iter import create_shard_kwargs, create_data_iter
+ from dataset.data_iter import create_shard_kwargs, DataIter
patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
paths = create_shard_kwargs(patterns)
- data_iter = create_data_iter(paths)
+ data_iter = DataIter(paths)
for i, data in enumerate(data_iter):
assert (
normalize("NFKC", data["content"])
diff --git a/dataset/train_tokenizer.py b/dataset/train_tokenizer.py
index 609c8d9..7d4b6c8 100644
--- a/dataset/train_tokenizer.py
+++ b/dataset/train_tokenizer.py
@@ -2,14 +2,14 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-24 20:49:03
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-26 23:43:59
+LastEditTime: 2023-04-05 22:40:29
FilePath: /Open-Llama/dataset/train_tokenizer.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import random
-from dataset.data_iter import create_data_iter, create_shard_kwargs
+from dataset.data_iter import DataIter, create_shard_kwargs
wudao_patterns = [
"data/pretrain_data/part-wudao-*.jsonl.zst",
@@ -24,10 +24,10 @@ pile_paths = create_shard_kwargs(pile_patterns)
random.shuffle(pile_paths)
paths = wudao_paths[:5] + pile_paths[:10]
transform_dict = {
- "wudao": lambda line: [(line["title"] + "\n" + line["content"])],
- "pile": lambda line: [line["text"]],
+ "wudao": lambda line: line["title"] + "\n" + line["content"],
+ "pile": lambda line: line["text"],
}
-data_iter = create_data_iter(paths, transform_dict)
+data_iter = iter(DataIter(paths, transform_dict))
import io
import sentencepiece as spm
diff --git a/inctruction_tuning.py b/inctruction_tuning.py
index 4e7ff6c..9ef01aa 100644
--- a/inctruction_tuning.py
+++ b/inctruction_tuning.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-30 21:35:01
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-30 21:40:03
+LastEditTime: 2023-04-06 03:35:31
FilePath: /Open-Llama/inctruction_tuning.py
Description:
@@ -16,18 +16,20 @@ import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
-from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
-from dataset.data_iter import create_shard_kwargs, create_data_iter
-from dataset.data_loader import pretrain_collate_fn_gen
+from dataset.data_iter import create_shard_kwargs, DataIter
+from dataset.collate_fn import collate_fn_gen
from dataset.instruction_dataset import (
preprocess_belle_gen,
preprocess_self_instruction_gen,
+ preprocess_belle_multiturn_chat_gen,
+ preprocess_instruct_code_gen,
+ preprocess_sharegpt_gen,
)
from configs.instruction_tuning_config import *
@@ -46,25 +48,28 @@ tokenizer = Tokenizer(sp_model)
paths = create_shard_kwargs(patterns, repeat=3)
random.shuffle(paths)
transform_dict = {
- "belle_1M": preprocess_belle_gen(tokenizer, max_length),
- "belle_0.5M": preprocess_belle_gen(tokenizer, max_length),
- "self_instruct": preprocess_self_instruction_gen(tokenizer, max_length),
+ "self_instruct": preprocess_self_instruction_gen(tokenizer),
+ "belle_1M": preprocess_belle_gen(tokenizer),
+ "belle_0.5M": preprocess_belle_gen(tokenizer),
+ "belle_school_math_0.25M": preprocess_belle_gen(tokenizer),
+ "belle_multiturn_chat_0.8M": preprocess_belle_multiturn_chat_gen(tokenizer),
+ "instruct_to_code": preprocess_instruct_code_gen(tokenizer),
+ "sharegpt_90K": preprocess_sharegpt_gen(tokenizer),
}
-data_set = IterableDataset.from_generator(
- create_data_iter,
- gen_kwargs={
- "paths": paths,
- "transform_dict": transform_dict,
- "process_index": accelerator.process_index,
- "num_processes": accelerator.num_processes,
- },
+data_set = DataIter(
+ paths,
+ transform_dict=transform_dict,
+ concat_docs=True,
+ max_length=max_length,
+ process_index=accelerator.process_index,
+ num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
- collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
+ collate_fn=collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# smaller initializer_range make training more stable
diff --git a/pretrain_llama.py b/pretrain_llama.py
index 17a48f8..18b92fb 100644
--- a/pretrain_llama.py
+++ b/pretrain_llama.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-17 14:27:28
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-27 01:07:25
+LastEditTime: 2023-04-05 22:46:31
FilePath: /Open-Llama/pretrain_llama.py
Description:
pretrain GPT
@@ -16,15 +16,14 @@ import random
import sentencepiece as spm
from torchinfo import summary
from accelerate import Accelerator
-from datasets import IterableDataset
from torch.utils.data import DataLoader
from deepspeed.ops.adam import FusedAdam
from transformers import LlamaForCausalLM, LlamaConfig, get_cosine_schedule_with_warmup
from dataset.validation import val_set
from dataset.tokenizer import Tokenizer
-from dataset.data_iter import create_shard_kwargs, create_data_iter
-from dataset.data_loader import pretrain_collate_fn_gen
+from dataset.data_iter import create_shard_kwargs, DataIter
+from dataset.collate_fn import collate_fn_gen
from dataset.pretrain_dataset import (
preprocess_the_pile_gen,
preprocess_wudao_gen,
@@ -49,21 +48,20 @@ transform_dict = {
"wudao": preprocess_wudao_gen(tokenizer, max_length),
"pile": preprocess_the_pile_gen(tokenizer, max_length),
}
-data_set = IterableDataset.from_generator(
- create_data_iter,
- gen_kwargs={
- "paths": paths,
- "transform_dict": transform_dict,
- "process_index": accelerator.process_index,
- "num_processes": accelerator.num_processes,
- },
+data_set = DataIter(
+ paths,
+ transform_dict=transform_dict,
+ concat_docs=True,
+ max_length=max_length,
+ process_index=accelerator.process_index,
+ num_processes=accelerator.num_processes,
)
train_loader = DataLoader(
data_set,
batch_size=train_batch_size,
# If num_workers is greater than 1, duplicate data may occur.
num_workers=0,
- collate_fn=pretrain_collate_fn_gen(tokenizer, max_length),
+ collate_fn=collate_fn_gen(tokenizer, max_length),
drop_last=True,
)
# smaller initializer_range make training more stable
diff --git a/server.py b/server.py
index 77510ef..5d4ea82 100644
--- a/server.py
+++ b/server.py
@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-31 13:26:15
LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-03-31 14:05:35
+LastEditTime: 2023-04-06 03:45:44
FilePath: /Open-Llama/server.py
Description:
@@ -32,7 +32,9 @@ raw_model = LlamaForCausalLM(
shared_input_output_embedding=True,
)
)
-ckpt = torch.load("data/saved_ckpt/instruction_tuning_3_epochs/23001.pt", map_location="cpu")
+ckpt = torch.load(
+ "data/saved_ckpt/instruction_tuning_3_epochs/37001.pt", map_location="cpu"
+)
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()
@@ -41,7 +43,7 @@ print("ready")
def question_answer(prompt):
print(prompt)
- raw_inputs = "user:{}system:".format(prompt)
+ raw_inputs = "user:{}\nsystem:".format(prompt)
inputs_len = len(raw_inputs)
inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
for k, v in inputs.items():