update optimizer for lora
This commit is contained in:
parent
58586112c1
commit
3ba0c77053
|
@ -2,7 +2,7 @@
|
||||||
Author: LiangSong(sl12160010@gmail.com)
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
Date: 2023-04-24 20:05:21
|
Date: 2023-04-24 20:05:21
|
||||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
LastEditTime: 2023-05-06 23:04:14
|
LastEditTime: 2023-05-08 22:51:42
|
||||||
FilePath: /Open-Llama/solver/trainer.py
|
FilePath: /Open-Llama/solver/trainer.py
|
||||||
Description:
|
Description:
|
||||||
|
|
||||||
|
@ -56,24 +56,27 @@ class Trainer:
|
||||||
|
|
||||||
def get_optimizer(self):
|
def get_optimizer(self):
|
||||||
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
|
no_decay = ["bias", "LayerNorm.weight", "layernorm.weight"]
|
||||||
optimizer_grouped_parameters = [
|
if self.config["train"].get("use_lora", False):
|
||||||
{
|
optimizer_grouped_parameters = self.raw_model.parameters()
|
||||||
"params": [
|
else:
|
||||||
p
|
optimizer_grouped_parameters = [
|
||||||
for n, p in self.raw_model.named_parameters()
|
{
|
||||||
if not any(nd in n for nd in no_decay)
|
"params": [
|
||||||
],
|
p
|
||||||
"weight_decay": self.config["train"]["weight_decay"],
|
for n, p in self.raw_model.named_parameters()
|
||||||
},
|
if not any(nd in n for nd in no_decay)
|
||||||
{
|
],
|
||||||
"params": [
|
"weight_decay": self.config["train"]["weight_decay"],
|
||||||
p
|
},
|
||||||
for n, p in self.raw_model.named_parameters()
|
{
|
||||||
if any(nd in n for nd in no_decay)
|
"params": [
|
||||||
],
|
p
|
||||||
"weight_decay": 0.0,
|
for n, p in self.raw_model.named_parameters()
|
||||||
},
|
if any(nd in n for nd in no_decay)
|
||||||
]
|
],
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
self.optim = FusedAdam(
|
self.optim = FusedAdam(
|
||||||
optimizer_grouped_parameters,
|
optimizer_grouped_parameters,
|
||||||
lr=self.config["train"]["lr"],
|
lr=self.config["train"]["lr"],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user