update server

This commit is contained in:
LiangSong 2023-05-12 15:07:46 +08:00
parent 7231d53ca4
commit e18ead00cc

View File

@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-06 23:30:57
+LastEditTime: 2023-05-12 15:07:36
 FilePath: /Open-Llama/chat_server.py
 Description:
@ -11,37 +11,11 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 import torch
 import logging
 import gradio as gr
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
-tokenizer = LlamaTokenizer(
+tokenizer = AutoTokenizer.from_pretrained("s-JoL/Open-Llama-V2", use_fast=False)
-    "configs/tokenizer_models/10w_vocab_wudao5_pile10.model",
+model = AutoModelForCausalLM.from_pretrained("s-JoL/Open-Llama-V2", torch_dtype=torch.bfloat16, device_map="auto")
pad_token="<pad>",
add_bos_token=False,
add_eos_token=True,
)
raw_model = OpenLlamaForCausalLM(
OpenLlamaConfig(
vocab_size=tokenizer.vocab_size,
initializer_range=0.01,
pad_token_id=tokenizer.pad_token_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
)
)
ckpt = torch.load(
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
map_location="cpu",
)
if "module" in ckpt:
ckpt = ckpt["module"]
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.half().cuda()
 logging.warning("ready")
@ -111,4 +85,4 @@ with gr.Blocks() as demo:
""" """
) )
demo.launch(share=True) demo.launch()