# Open-Llama/chat_server.py (126 lines, 4.2 KiB, Python)
# Recovered from a web "blame" view; commit timestamp 2023-04-07 02:04:05 +00:00.
"""
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-26 23:58:23
FilePath: /Open-Llama/chat_server.py
Description: Gradio web chat demo for an instruction-tuned Open-Llama checkpoint.
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import html

import gradio as gr
import torch
from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
# SentencePiece tokenizer for the project's vocabulary. It is configured with
# add_bos_token=False / add_eos_token=True, which bot() relies on when it
# tokenizes finished conversation turns.
tokenizer = LlamaTokenizer(
    "configs/10w_vocab_wudao5_pile10.model",
    pad_token="<pad>",
    add_bos_token=False,
    add_eos_token=True,
)
raw_model = OpenLlamaForCausalLM(
    OpenLlamaConfig(
        vocab_size=tokenizer.vocab_size,
        initializer_range=0.01,
        pad_token_id=tokenizer.pad_token_id,
        rms_norm_eps=1e-5,
        hidden_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        use_stable_embedding=True,
        shared_input_output_embedding=True,
    )
)
# Load the weights on CPU, copy them into the model, then move it to the GPU.
# NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
ckpt = torch.load(
    "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
    map_location="cpu",
)
raw_model.load_state_dict(ckpt)
# load_state_dict copies the tensors into the model's own parameters, so the
# state dict is no longer needed; dropping the module-level reference frees
# this extra CPU copy of the weights instead of keeping it for the whole
# lifetime of the server.
del ckpt
raw_model.eval()
model = raw_model.cuda()
print("ready")
def parse_codeblock(text):
    """Render raw model output as HTML for the Gradio chatbot.

    The text is processed line by line and the pieces are concatenated
    without newlines:

    * A line whose stripped form starts with a ``` fence toggles a code
      block. An opening fence becomes ``<pre><code class="LANG">``, or
      ``<pre><code>`` when no language tag follows the fence. A bare ```
      while a block is open becomes ``</code></pre>``.
    * Every other line is HTML-escaped (``&``, ``<``, ``>``) and, except for
      the very first line, prefixed with ``<br/>``.

    A code block still open at the end of the text is closed, so truncated
    generations do not leave dangling tags.

    Args:
        text: Raw decoded model output.

    Returns:
        The HTML string.
    """
    pieces = []
    in_code = False
    for i, line in enumerate(text.split("\n")):
        stripped = line.strip()
        if stripped.startswith("```"):
            lang = stripped[3:].strip()
            if in_code and not lang:
                pieces.append("</code></pre>")
                in_code = False
            else:
                # The language tag comes from model output; escape it so it
                # cannot break out of the class attribute.
                if lang:
                    pieces.append(f'<pre><code class="{html.escape(lang)}">')
                else:
                    pieces.append("<pre><code>")
                in_code = True
        else:
            # Escape every text line, including the first one; quote=False
            # keeps quotes as-is, which is fine in element content.
            escaped = html.escape(line, quote=False)
            pieces.append("<br/>" + escaped if i > 0 else escaped)
    if in_code:
        pieces.append("</code></pre>")
    return "".join(pieces)
with gr.Blocks() as demo:
    gr.Markdown(
        """
# [Open-Llama](https://github.com/Bayes-Song/Open-Llama)
完全使用Open-Llama项目从0开始训练的Instruct-GPT模型，当长时间无响应（如20s以上）可刷新重试。
Instruct-GPT model is trained from scratch using the Open-Llama project without relying on any other pre-trained models. If there is no response for a long time (such as more than 20 seconds), please refresh and try again.
"""
    )
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    # Per-session conversation as [prompt, raw_completion] pairs. The Chatbot
    # only holds the HTML produced by parse_codeblock, which must not be fed
    # back to the model as context, so the raw text is kept here.
    raw_state = gr.State([])

    def user(user_message, history, raw_history):
        """Queue a new user turn whose reply is still pending.

        Returns an empty string (clears the textbox) and both the displayed
        and the raw history extended with ``[user_message, None]``.
        """
        print(user_message)
        return (
            "",
            (history or []) + [[user_message, None]],
            (raw_history or []) + [[user_message, None]],
        )

    def bot(history, raw_history):
        """Generate the reply for the pending (last) turn.

        The prompt is rebuilt from the raw transcript as a sequence of
        "user:...\\nsystem:..." segments. Finished turns are tokenized with
        special tokens (the tokenizer is configured with add_eos_token=True);
        the pending turn is tokenized without them so the model continues
        right after "system:". Only the last 1024 prompt tokens are kept.

        Returns the displayed history with the HTML-rendered reply filled in,
        and the raw history with the plain reply filled in.
        """
        if not raw_history:
            return history, raw_history

        context = []
        for prompt, completion in raw_history:
            if completion is None:
                text = "user:{}\nsystem:".format(prompt)
                add_special_tokens = False
            else:
                text = "user:{}\nsystem:{}".format(prompt, completion)
                add_special_tokens = True
            encoded = tokenizer(
                text,
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
                return_attention_mask=False,
            )
            context.append(encoded["input_ids"])
        context = torch.cat(context, dim=-1)
        context = context[:, -1024:]
        inputs_len = context.shape[1]
        context = context.cuda()
        pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
        # Keep only the newly generated tokens. skip_special_tokens drops the
        # EOS/pad markers so they are neither displayed nor duplicated when
        # this reply is re-tokenized (with EOS appended) on the next turn.
        reply = tokenizer.decode(pred[0, inputs_len:].cpu(), skip_special_tokens=True)
        reply = reply.strip()
        print(reply)

        raw_history = raw_history[:-1] + [[raw_history[-1][0], reply]]
        history = history[:-1] + [[history[-1][0], parse_codeblock(reply)]]
        return history, raw_history

    msg.submit(
        user,
        [msg, chatbot, raw_state],
        [msg, chatbot, raw_state],
        queue=False,
    ).then(bot, [chatbot, raw_state], [chatbot, raw_state])
    # Reset both the visible chat and the raw transcript.
    clear.click(lambda: (None, []), None, [chatbot, raw_state], queue=False)
    gr.Markdown(
        """
当前体验服务生成的所有内容都是由人工智能模型生成，我们对其生成内容的准确性、完整性和功能性不做任何保证，并且其生成的内容不代表我们的态度或观点。
联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我.
Contact information: sl12160010@gmail.com. Any opinions or suggestions regarding the project are welcome to be addressed to me through this email.
"""
    )

demo.launch(share=True)