add server
This commit is contained in:
parent
b9bc7eaf35
commit
5dc1e77c66
59
server.py
Normal file
59
server.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
"""
|
||||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-03-31 13:26:15
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-03-31 14:05:35
|
||||
FilePath: /Open-Llama/server.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import torch
|
||||
import gradio as gr
|
||||
import sentencepiece as spm
|
||||
from dataset.tokenizer import Tokenizer
|
||||
from transformers import LlamaForCausalLM, LlamaConfig
|
||||
|
||||
|
||||
sp_model = spm.SentencePieceProcessor(
|
||||
model_file="configs/10w_vocab_wudao5_pile10.model"
|
||||
)
|
||||
tokenizer = Tokenizer(sp_model)
|
||||
|
||||
raw_model = LlamaForCausalLM(
|
||||
LlamaConfig(
|
||||
vocab_size=tokenizer.vocab_size,
|
||||
initializer_range=0.01,
|
||||
pad_token_id=tokenizer.pad_id,
|
||||
rms_norm_eps=1e-5,
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_dropout_prob=0.1,
|
||||
use_stable_embedding=True,
|
||||
shared_input_output_embedding=True,
|
||||
)
|
||||
)
|
||||
ckpt = torch.load("data/saved_ckpt/instruction_tuning/12001.pt", map_location="cpu")
|
||||
raw_model.load_state_dict(ckpt)
|
||||
raw_model.eval()
|
||||
model = raw_model.cuda()
|
||||
print("ready")
|
||||
|
||||
|
||||
def question_answer(prompt):
|
||||
print(prompt)
|
||||
raw_inputs = "user:{}<s>system:".format(prompt)
|
||||
inputs_len = len(raw_inputs)
|
||||
inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.cuda()
|
||||
pred = model.generate(**inputs, max_new_tokens=512, do_sample=True)
|
||||
pred = tokenizer.decode(pred.cpu())[0]
|
||||
pred = pred[inputs_len:]
|
||||
print(pred)
|
||||
return pred
|
||||
|
||||
|
||||
demo = gr.Interface(fn=question_answer, inputs="text", outputs="text").queue(
|
||||
concurrency_count=1
|
||||
)
|
||||
demo.launch(share=True)
|
Loading…
Reference in New Issue
Block a user