diff --git a/solver/trainer.py b/solver/trainer.py
index 6974a58..bd57cb9 100644
--- a/solver/trainer.py
+++ b/solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-06 23:04:14
+LastEditTime: 2023-05-08 22:51:42
 FilePath: /Open-Llama/solver/trainer.py
 Description: 

@@ -56,24 +56,27 @@ class Trainer:
     def get_optimizer(self):
         no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p
-                    for n, p in self.raw_model.named_parameters()
-                    if not any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": self.config["train"]["weight_decay"],
-            },
-            {
-                "params": [
-                    p
-                    for n, p in self.raw_model.named_parameters()
-                    if any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
+        if self.config["train"].get("use_lora", False):
+            optimizer_grouped_parameters = self.raw_model.parameters()
+        else:
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in self.raw_model.named_parameters()
+                        if not any(nd in n for nd in no_decay)
+                    ],
+                    "weight_decay": self.config["train"]["weight_decay"],
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in self.raw_model.named_parameters()
+                        if any(nd in n for nd in no_decay)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
         self.optim = FusedAdam(
             optimizer_grouped_parameters,
             lr=self.config["train"]["lr"],
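
For context, the non-LoRA branch keeps the usual two-group weight-decay split (decay on regular weights, none on biases and normalization weights), while the LoRA branch passes self.raw_model.parameters() straight through, since under a LoRA setup only the injected adapter weights are trainable and the grouping buys nothing. Below is a minimal, self-contained sketch of the same grouping logic; TinyModel and train_config are hypothetical stand-ins, and torch.optim.AdamW is used in place of FusedAdam, which is not assumed to be installed.

import torch
from torch import nn


class TinyModel(nn.Module):
    # Attribute names chosen so the no-decay filter below matches the same
    # patterns the trainer uses ("bias" and "layernorm.weight").
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.layernorm = nn.LayerNorm(16)

    def forward(self, x):
        return self.layernorm(self.proj(x))


model = TinyModel()

# Hypothetical stand-in for config["train"] in the trainer.
train_config = {"lr": 2e-4, "weight_decay": 0.1, "use_lora": False}

no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]

if train_config.get("use_lora", False):
    # With LoRA adapters only the low-rank weights require gradients,
    # so a flat parameter iterator is sufficient.
    optimizer_grouped_parameters = model.parameters()
else:
    optimizer_grouped_parameters = [
        {
            # Regular weights: weight decay applies.
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": train_config["weight_decay"],
        },
        {
            # Biases and normalization weights: no weight decay.
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

# AdamW stands in for FusedAdam; both accept the same parameter-group format.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=train_config["lr"])

Excluding biases and norm weights from decay is the conventional choice for transformer training; the diff preserves that behavior exactly in the else branch and only skips it when use_lora is set.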