diff --git a/server.py b/server.py
new file mode 100644
index 0000000..0a08bbe
--- /dev/null
+++ b/server.py
@@ -0,0 +1,59 @@
+"""
+Author: LiangSong(sl12160010@gmail.com)
+Date: 2023-03-31 13:26:15
+LastEditors: LiangSong(sl12160010@gmail.com)
+LastEditTime: 2023-03-31 14:05:35
+FilePath: /Open-Llama/server.py
+Description:
+
+Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
+"""
+import torch
+import gradio as gr
+import sentencepiece as spm
+from dataset.tokenizer import Tokenizer
+from transformers import LlamaForCausalLM, LlamaConfig
+
+
+sp_model = spm.SentencePieceProcessor(
+    model_file="configs/10w_vocab_wudao5_pile10.model"
+)
+tokenizer = Tokenizer(sp_model)
+
+raw_model = LlamaForCausalLM(
+    LlamaConfig(
+        vocab_size=tokenizer.vocab_size,
+        initializer_range=0.01,
+        pad_token_id=tokenizer.pad_id,
+        rms_norm_eps=1e-5,
+        hidden_dropout_prob=0.1,
+        attention_dropout_prob=0.1,
+        use_stable_embedding=True,
+        shared_input_output_embedding=True,
+    )
+)
+ckpt = torch.load("data/saved_ckpt/instruction_tuning/12001.pt", map_location="cpu")
+raw_model.load_state_dict(ckpt)
+raw_model.eval()
+model = raw_model.cuda()
+print("ready")
+
+
+def question_answer(prompt):
+    print(prompt)
+    raw_inputs = "user:{}system:".format(prompt)
+    inputs_len = len(raw_inputs)
+    inputs = tokenizer(raw_inputs, return_tensors=True, add_special_tokens=False)
+    for k, v in inputs.items():
+        inputs[k] = v.cuda()
+    pred = model.generate(**inputs, max_new_tokens=512, do_sample=True)
+    pred = tokenizer.decode(pred.cpu())[0]
+    pred = pred[inputs_len:]
+    print(pred)
+    return pred
+
+
+demo = gr.Interface(fn=question_answer, inputs="text", outputs="text").queue(
+    concurrency_count=1
+)
+demo.launch(share=True)
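
Not part of this change: a minimal client-side sketch for exercising the launched demo programmatically, assuming the separate `gradio_client` package is installed and the server is reachable at the local URL that `demo.launch()` prints (the temporary `share=True` link works the same way). The URL and the example prompt are placeholders; `/predict` is Gradio's default endpoint name for a single-function `Interface`.

```python
# Hypothetical usage sketch (not in the repo): query the running Gradio demo.
# Assumes `pip install gradio_client` and that the server printed its URL,
# e.g. http://127.0.0.1:7860 or the temporary share link.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # replace with the printed URL
answer = client.predict("What is the capital of France?", api_name="/predict")
print(answer)
```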