Support loading the model from an accelerate checkpoint

This commit is contained in:
LiangSong 2023-04-29 20:40:42 +08:00
parent 52cd09f664
commit 0466673f76
2 changed files with 6 additions and 2 deletions

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-06 22:30:10
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-29 19:38:54
LastEditTime: 2023-04-29 20:40:13
FilePath: /Open-Llama/chat_server.py
Description:
@ -37,6 +37,8 @@ ckpt = torch.load(
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
map_location="cpu",
)
if 'module' in ckpt:
ckpt = ckpt['module']
raw_model.load_state_dict(ckpt)
raw_model.eval()
model = raw_model.cuda()

View File

@ -2,7 +2,7 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-04-29 19:38:47
LastEditTime: 2023-04-29 20:37:52
FilePath: /Open-Llama/train_lm.py
Description:
@ -72,6 +72,8 @@ def main(argv):
)
if config["train"]["ckpt"] is not None:
ckpt = torch.load(config["train"]["ckpt"])
if 'module' in ckpt:
ckpt = ckpt['module']
raw_model.load_state_dict(ckpt)
logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)