update optimizer for lora

LiangSong 2023-05-08 22:56:37 +08:00
parent 58586112c1
commit 3ba0c77053

solver/trainer.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-06 23:04:14
+LastEditTime: 2023-05-08 22:51:42
 FilePath: /Open-Llama/solver/trainer.py
 Description:
@@ -56,24 +56,27 @@ class Trainer:
     def get_optimizer(self):
         no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [
-                    p
-                    for n, p in self.raw_model.named_parameters()
-                    if not any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": self.config["train"]["weight_decay"],
-            },
-            {
-                "params": [
-                    p
-                    for n, p in self.raw_model.named_parameters()
-                    if any(nd in n for nd in no_decay)
-                ],
-                "weight_decay": 0.0,
-            },
-        ]
+        if self.config["train"].get("use_lora", False):
+            optimizer_grouped_parameters = self.raw_model.parameters()
+        else:
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in self.raw_model.named_parameters()
+                        if not any(nd in n for nd in no_decay)
+                    ],
+                    "weight_decay": self.config["train"]["weight_decay"],
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in self.raw_model.named_parameters()
+                        if any(nd in n for nd in no_decay)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
         self.optim = FusedAdam(
             optimizer_grouped_parameters,
             lr=self.config["train"]["lr"],
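For context, a minimal sketch of what the new branch relies on: when LoRA is enabled the base weights are typically frozen (requires_grad=False), so handing raw_model.parameters() straight to the optimizer is harmless and only the adapter weights get updated. This is an illustrative sketch, not the repository's code: torch.optim.AdamW stands in for FusedAdam, the toy model and the numeric config values are placeholders, and only the config keys use_lora, lr, and weight_decay come from the diff above.

# Illustrative sketch only; AdamW stands in for FusedAdam, values are placeholders.
import torch
import torch.nn as nn

config = {"train": {"use_lora": True, "lr": 2e-4, "weight_decay": 1e-2}}

# Toy stand-in for a LoRA-wrapped model: the base weight is frozen,
# and only the added adapter parameter remains trainable.
raw_model = nn.Linear(8, 8)
raw_model.weight.requires_grad_(False)
raw_model.register_parameter("lora_A", nn.Parameter(torch.zeros(8, 8)))

if config["train"].get("use_lora", False):
    # Frozen parameters never receive gradients, so passing all parameters
    # to the optimizer is safe; only the LoRA adapter is actually updated.
    optimizer_grouped_parameters = raw_model.parameters()
    optim = torch.optim.AdamW(
        optimizer_grouped_parameters, lr=config["train"]["lr"]
    )

The practical effect of the commit is that, with use_lora enabled, the no_decay parameter grouping is skipped and all parameters are handed to FusedAdam as a single group; the per-group weight-decay split is only applied in the full fine-tuning path.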