support load model from accelerate ckpt
This commit is contained in:
parent
52cd09f664
commit
0466673f76
|
@ -2,7 +2,7 @@
|
|||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-04-06 22:30:10
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-29 19:38:54
|
||||
LastEditTime: 2023-04-29 20:40:13
|
||||
FilePath: /Open-Llama/chat_server.py
|
||||
Description:
|
||||
|
||||
|
@ -37,6 +37,8 @@ ckpt = torch.load(
|
|||
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
|
||||
map_location="cpu",
|
||||
)
|
||||
if 'module' in ckpt:
|
||||
ckpt = ckpt['module']
|
||||
raw_model.load_state_dict(ckpt)
|
||||
raw_model.eval()
|
||||
model = raw_model.cuda()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-04-12 19:12:42
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-04-29 19:38:47
|
||||
LastEditTime: 2023-04-29 20:37:52
|
||||
FilePath: /Open-Llama/train_lm.py
|
||||
Description:
|
||||
|
||||
|
@ -72,6 +72,8 @@ def main(argv):
|
|||
)
|
||||
if config["train"]["ckpt"] is not None:
|
||||
ckpt = torch.load(config["train"]["ckpt"])
|
||||
if 'module' in ckpt:
|
||||
ckpt = ckpt['module']
|
||||
raw_model.load_state_dict(ckpt)
|
||||
logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
||||
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
||||
|
|
Loading…
Reference in New Issue
Block a user