LiangSong 2023-05-04 22:32:15 +08:00
parent 98ffab3a97
commit fbb7997607

chat_server.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-04 22:28:07
+LastEditTime: 2023-05-04 22:32:05
 FilePath: /Open-Llama/chat_server.py
 Description:
@@ -88,10 +88,8 @@ with gr.Blocks() as demo:
 context = torch.cat(context, dim=-1)
 context = context[:, -1024:]
 inputs_len = context.shape[1]
-context = context.half().cuda()
-pred = model.generate(
-    input_ids=context, max_new_tokens=1024, do_sample=True
-)
+context = context.cuda()
+pred = model.generate(input_ids=context, max_new_tokens=1024, do_sample=True)
 pred = pred[:, inputs_len:]
 pred = tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
 logging.warn(pred)
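
For context, the change drops the `.half()` cast on the prompt tensor: `context` holds integer token ids, and `model.generate` expects integer `input_ids`, so only the move to GPU is needed. Below is a minimal sketch of the updated generation path; the checkpoint name and loading code are placeholders (the repository loads its own model and tokenizer elsewhere), and a CUDA device is assumed.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; substitute the Open-Llama checkpoint in practice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").half().cuda()

prompt = "Hello, how are you?"
context = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Keep only the last 1024 tokens and remember the prompt length,
# mirroring the server's truncation logic.
context = context[:, -1024:]
inputs_len = context.shape[1]

# Token ids stay integer; only the model weights are half precision.
context = context.cuda()
pred = model.generate(input_ids=context, max_new_tokens=1024, do_sample=True)

# Strip the prompt and decode just the newly generated tokens.
pred = pred[:, inputs_len:]
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
```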