Support loading the model from an Accelerate checkpoint
This commit is contained in:
parent
52cd09f664
commit
0466673f76
|
@ -2,7 +2,7 @@
|
||||||
Author: LiangSong(sl12160010@gmail.com)
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
Date: 2023-04-06 22:30:10
|
Date: 2023-04-06 22:30:10
|
||||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
LastEditTime: 2023-04-29 19:38:54
|
LastEditTime: 2023-04-29 20:40:13
|
||||||
FilePath: /Open-Llama/chat_server.py
|
FilePath: /Open-Llama/chat_server.py
|
||||||
Description:
|
Description:
|
||||||
|
|
||||||
|
@ -37,6 +37,8 @@ ckpt = torch.load(
|
||||||
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
|
"data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
|
||||||
map_location="cpu",
|
map_location="cpu",
|
||||||
)
|
)
|
||||||
|
if 'module' in ckpt:
|
||||||
|
ckpt = ckpt['module']
|
||||||
raw_model.load_state_dict(ckpt)
|
raw_model.load_state_dict(ckpt)
|
||||||
raw_model.eval()
|
raw_model.eval()
|
||||||
model = raw_model.cuda()
|
model = raw_model.cuda()
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
Author: LiangSong(sl12160010@gmail.com)
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
Date: 2023-04-12 19:12:42
|
Date: 2023-04-12 19:12:42
|
||||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
LastEditTime: 2023-04-29 19:38:47
|
LastEditTime: 2023-04-29 20:37:52
|
||||||
FilePath: /Open-Llama/train_lm.py
|
FilePath: /Open-Llama/train_lm.py
|
||||||
Description:
|
Description:
|
||||||
|
|
||||||
|
@ -72,6 +72,8 @@ def main(argv):
|
||||||
)
|
)
|
||||||
if config["train"]["ckpt"] is not None:
|
if config["train"]["ckpt"] is not None:
|
||||||
ckpt = torch.load(config["train"]["ckpt"])
|
ckpt = torch.load(config["train"]["ckpt"])
|
||||||
|
if 'module' in ckpt:
|
||||||
|
ckpt = ckpt['module']
|
||||||
raw_model.load_state_dict(ckpt)
|
raw_model.load_state_dict(ckpt)
|
||||||
logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
||||||
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user