"""
Author: s-JoL(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: s-JoL(sl12160010@gmail.com)
LastEditTime: 2023-05-12 15:07:36
FilePath: /Open-Llama/chat_server.py
Description: Gradio-based chat server for the Open-Llama Instruct model.

Copyright (c) 2023 by s-JoL(sl12160010@gmail.com), All Rights Reserved.
"""
|
|
|
|
|
import torch
|
2023-04-29 12:28:39 +00:00
|
|
|
|
import logging
|
2023-04-07 02:04:05 +00:00
|
|
|
|
import gradio as gr
|
2023-05-12 07:07:46 +00:00
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2023-04-07 02:04:05 +00:00
|
|
|
|
|
|
|
|
|
|
2023-05-12 07:07:46 +00:00
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("s-JoL/Open-Llama-V2", use_fast=False)
|
2023-05-17 15:21:46 +00:00
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
"s-JoL/Open-Llama-V2", torch_dtype=torch.bfloat16, device_map="auto"
|
|
|
|
|
)
|
2023-05-06 15:37:17 +00:00
|
|
|
|
logging.warning("ready")
|
2023-04-07 02:04:05 +00:00
|
|
|
|
|
2023-04-07 15:20:20 +00:00
|
|
|
|
|
2023-04-07 02:04:05 +00:00
|
|
|
|
with gr.Blocks() as demo:
|
2023-04-07 15:19:42 +00:00
|
|
|
|
gr.Markdown(
|
|
|
|
|
"""
|
2023-05-17 15:21:46 +00:00
|
|
|
|
# [Open-Llama](https://github.com/s-JoL/Open-Llama)
|
2023-04-07 15:19:42 +00:00
|
|
|
|
完全使用Open-Llama项目从0开始训练的Instruct-GPT模型,当长时间无响应(如20s以上)可刷新重试。
|
|
|
|
|
|
|
|
|
|
Instruct-GPT model is trained from scratch using the Open-Llama project without relying on any other pre-trained models. If there is no response for a long time (such as more than 20 seconds), please refresh and try again.
|
|
|
|
|
"""
|
|
|
|
|
)
|
2023-04-07 02:04:05 +00:00
|
|
|
|
chatbot = gr.Chatbot()
|
|
|
|
|
msg = gr.Textbox()
|
|
|
|
|
clear = gr.Button("Clear")
|
|
|
|
|
|
|
|
|
|
def user(user_message, history):
|
2023-05-06 15:37:17 +00:00
|
|
|
|
logging.warning(user_message)
|
2023-04-07 02:04:05 +00:00
|
|
|
|
return "", history + [[user_message, None]]
|
|
|
|
|
|
|
|
|
|
def bot(history):
|
|
|
|
|
context = []
|
|
|
|
|
round = 0
|
|
|
|
|
for prompt, completion in history:
|
|
|
|
|
round += 1
|
|
|
|
|
if completion is None:
|
2023-04-07 15:20:20 +00:00
|
|
|
|
inputs = "user:{}\nsystem:".format(prompt)
|
|
|
|
|
inputs = tokenizer(
|
2023-04-26 16:04:11 +00:00
|
|
|
|
inputs,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
add_special_tokens=False,
|
|
|
|
|
return_attention_mask=False,
|
2023-04-07 15:20:20 +00:00
|
|
|
|
)
|
|
|
|
|
context.append(inputs["input_ids"])
|
2023-04-07 02:04:05 +00:00
|
|
|
|
else:
|
2023-04-07 15:20:20 +00:00
|
|
|
|
inputs = "user:{}\nsystem:{}".format(prompt, completion)
|
2023-04-26 16:04:11 +00:00
|
|
|
|
inputs = tokenizer(
|
|
|
|
|
inputs,
|
|
|
|
|
return_tensors="pt",
|
|
|
|
|
add_special_tokens=True,
|
|
|
|
|
return_attention_mask=False,
|
|
|
|
|
)
|
2023-04-07 15:20:20 +00:00
|
|
|
|
context.append(inputs["input_ids"])
|
2023-04-07 02:04:05 +00:00
|
|
|
|
context = torch.cat(context, dim=-1)
|
2023-04-07 15:20:20 +00:00
|
|
|
|
context = context[:, -1024:]
|
2023-04-07 02:04:05 +00:00
|
|
|
|
inputs_len = context.shape[1]
|
2023-05-04 14:32:15 +00:00
|
|
|
|
context = context.cuda()
|
|
|
|
|
pred = model.generate(input_ids=context, max_new_tokens=1024, do_sample=True)
|
2023-04-07 02:04:05 +00:00
|
|
|
|
pred = pred[:, inputs_len:]
|
2023-04-27 11:42:06 +00:00
|
|
|
|
pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
|
2023-05-06 15:37:17 +00:00
|
|
|
|
logging.warning(pred)
|
2023-05-04 10:18:52 +00:00
|
|
|
|
bot_message = pred
|
2023-04-07 02:04:05 +00:00
|
|
|
|
history[-1][1] = bot_message
|
|
|
|
|
return history
|
|
|
|
|
|
|
|
|
|
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
|
|
|
|
bot, chatbot, chatbot
|
|
|
|
|
)
|
|
|
|
|
clear.click(lambda: None, None, chatbot, queue=False)
|
2023-04-07 15:19:42 +00:00
|
|
|
|
gr.Markdown(
|
2023-04-07 15:20:20 +00:00
|
|
|
|
"""
|
2023-04-07 15:19:42 +00:00
|
|
|
|
当前体验服务生成的所有内容都是由人工智能模型生成,我们对其生成内容的准确性、完整性和功能性不做任何保证,并且其生成的内容不代表我们的态度或观点。
|
|
|
|
|
|
|
|
|
|
联系方式: sl12160010@gmail.com 对于该项目有任何意见和建议都欢迎联系我.
|
2023-05-04 14:54:10 +00:00
|
|
|
|
|
2023-04-07 15:19:42 +00:00
|
|
|
|
Contact information: sl12160010@gmail.com. Any opinions or suggestions regarding the project are welcome to be addressed to me through this email.
|
|
|
|
|
"""
|
|
|
|
|
)
|
2023-04-07 02:04:05 +00:00
|
|
|
|
|
2023-05-12 07:07:46 +00:00
|
|
|
|
demo.launch()
|