update config

parent 0466673f76
commit f05e929aad
```diff
@@ -37,8 +37,8 @@ ckpt = torch.load(
     "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
     map_location="cpu",
 )
-if 'module' in ckpt:
-    ckpt = ckpt['module']
+if "module" in ckpt:
+    ckpt = ckpt["module"]
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
```
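For context, a minimal sketch of the loading pattern this hunk touches (the model and checkpoint path below are illustrative stand-ins, not the repo's code): DeepSpeed-style checkpoints nest the weights under a `"module"` key, while a plain `torch.save(state_dict)` checkpoint does not, and `map_location="cpu"` keeps the tensors off the GPU until `load_state_dict` copies them into the already-placed model.

```python
import torch
from torch import nn

# Illustrative stand-in; the real script loads a LLaMA-style model.
model = nn.Linear(4, 4)

# DeepSpeed-style checkpoints wrap the state dict under a "module" key;
# plain torch.save(state_dict) checkpoints do not. map_location="cpu"
# avoids staging the full checkpoint on a GPU during the load.
ckpt = torch.load("checkpoint.pt", map_location="cpu")  # placeholder path
if "module" in ckpt:
    ckpt = ckpt["module"]
model.load_state_dict(ckpt)
model.eval()
```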
```diff
@@ -16,8 +16,8 @@ model:
   shared_input_output_embedding: False
 train:
   train_batch_size: 2
-  num_training_steps: 1000000
-  num_warmup_steps: 2000
+  num_training_steps: 40000
+  num_warmup_steps: 500
   initializer_range: 1.0e-2
   lr: 2.0e-4
   weight_decay: 1.0e-1
```
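The two edited values usually travel together: warmup length is kept a small fraction of the total step budget, so shrinking `num_training_steps` from 1000000 to 40000 shrinks the warmup proportionally. A hedged sketch of how such values typically feed a warmup scheduler (this hunk doesn't show which scheduler the repo actually uses; `get_cosine_schedule_with_warmup` is an assumption, and the model/optimizer are stand-ins):

```python
import torch
from torch import nn
from transformers import get_cosine_schedule_with_warmup

# Stand-in model/optimizer; lr and weight_decay mirror the config above.
model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2.0e-4, weight_decay=1.0e-1)

# LR ramps up from 0 over the warmup steps, then decays over the remainder.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,      # new value (was 2000)
    num_training_steps=40000,  # new value (was 1000000)
)
```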
solver/trainer.py

```diff
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 20:26:34
+LastEditTime: 2023-04-29 21:59:51
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -112,6 +112,7 @@ class Trainer:
             self.train_loader = self.accelerator.skip_first_batches(
                 self.train_loader, num_batches=skip_steps
             )
+        self.accelerator.wait_for_everyone()
 
     def train_step(self, batch):
         out = self.model(**batch)
```
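A hedged, self-contained sketch of the resume pattern around the added line (the loader contents and `skip_steps` below are placeholders): `skip_first_batches` fast-forwards the dataloader past batches already consumed before a restart, and `wait_for_everyone()` acts as a barrier so no rank races ahead while the others are still skipping.

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
# Placeholder data; the real trainer prepares a tokenized LM dataloader.
train_loader = accelerator.prepare(DataLoader(list(range(100)), batch_size=2))

skip_steps = 10  # illustrative: batches already trained before the restart
train_loader = accelerator.skip_first_batches(train_loader, num_batches=skip_steps)
accelerator.wait_for_everyone()  # barrier: sync all ranks before training resumes
```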
train_lm.py

```diff
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 20:37:52
+LastEditTime: 2023-05-02 18:26:50
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -71,9 +71,9 @@ def main(argv):
         )
     )
     if config["train"]["ckpt"] is not None:
-        ckpt = torch.load(config["train"]["ckpt"])
-        if 'module' in ckpt:
-            ckpt = ckpt['module']
+        ckpt = torch.load(config["train"]["ckpt"], map_location="cpu")
+        if "module" in ckpt:
+            ckpt = ckpt["module"]
         raw_model.load_state_dict(ckpt)
         logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
```
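One aside on an unchanged context line: `logging.warn` is a long-deprecated alias for `logging.warning` in the Python standard library, and `%`-style arguments defer message formatting until the record is actually emitted. A minimal equivalent (the path is a placeholder):

```python
import logging

# logging.warn() still works but is deprecated; logging.warning() is the
# supported spelling, and passing args lazily avoids eager str.format().
logging.warning("Loaded ckpt from: %s", "data/saved_ckpt/example.pt")
```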