diff --git a/chat_server.py b/chat_server.py index 96f551a..5efbcf5 100644 --- a/chat_server.py +++ b/chat_server.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-06 22:30:10 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-05-04 22:28:07 +LastEditTime: 2023-05-04 22:32:05 FilePath: /Open-Llama/chat_server.py Description: @@ -88,10 +88,8 @@ with gr.Blocks() as demo: context = torch.cat(context, dim=-1) context = context[:, -1024:] inputs_len = context.shape[1] - context = context.half().cuda() - pred = model.generate( - input_ids=context, max_new_tokens=1024, do_sample=True - ) + context = context.cuda() + pred = model.generate(input_ids=context, max_new_tokens=1024, do_sample=True) pred = pred[:, inputs_len:] pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True) logging.warn(pred)