update server

This commit is contained in:
LiangSong 2023-05-12 15:07:46 +08:00
parent 7231d53ca4
commit e18ead00cc

View File

@@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-05-06 23:30:57
LastEditTime: 2023-05-12 15:07:36
FilePath: /Open-Llama/chat_server.py
Description:
@@ -11,37 +11,11 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
import torch
import logging
import gradio as gr
from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = LlamaTokenizer(
"configs/tokenizer_models/10w_vocab_wudao5_pile10.model",
pad_token="<pad>",
add_bos_token=False,
add_eos_token=True,
)
raw_model = OpenLlamaForCausalLM(
OpenLlamaConfig(
vocab_size=tokenizer.vocab_size,
initializer_range=0.01,
pad_token_id=tokenizer.pad_token_id,
rms_norm_eps=1e-5,
hidden_dropout_prob=0.1,
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
)
)
ckpt = torch.load(
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
map_location="cpu",
)
if "module" in ckpt:
ckpt = ckpt["module"]
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.half().cuda()
tokenizer = AutoTokenizer.from_pretrained("s-JoL/Open-Llama-V2", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("s-JoL/Open-Llama-V2", torch_dtype=torch.bfloat16, device_map="auto")
logging.warning("ready")
@@ -111,4 +85,4 @@ with gr.Blocks() as demo:
"""
)
demo.launch(share=True)
demo.launch()