From 0466673f76c2dae67b5f391578b691e0f2c49454 Mon Sep 17 00:00:00 2001 From: LiangSong Date: Sat, 29 Apr 2023 20:40:42 +0800 Subject: [PATCH] support load model from accelerate ckpt --- chat_server.py | 4 +++- train_lm.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chat_server.py b/chat_server.py index 411760e..853dd58 100644 --- a/chat_server.py +++ b/chat_server.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-06 22:30:10 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-29 19:38:54 +LastEditTime: 2023-04-29 20:40:13 FilePath: /Open-Llama/chat_server.py Description: @@ -37,6 +37,8 @@ ckpt = torch.load( "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt", map_location="cpu", ) +if 'module' in ckpt: + ckpt = ckpt['module'] raw_model.load_state_dict(ckpt) raw_model.eval() model = raw_model.cuda() diff --git a/train_lm.py b/train_lm.py index 31ba9bf..6772841 100644 --- a/train_lm.py +++ b/train_lm.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-12 19:12:42 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-29 19:38:47 +LastEditTime: 2023-04-29 20:37:52 FilePath: /Open-Llama/train_lm.py Description: @@ -72,6 +72,8 @@ def main(argv): ) if config["train"]["ckpt"] is not None: ckpt = torch.load(config["train"]["ckpt"]) + if 'module' in ckpt: + ckpt = ckpt['module'] raw_model.load_state_dict(ckpt) logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"])) trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)