update config
parent 0466673f76
commit f05e929aad
@@ -37,8 +37,8 @@ ckpt = torch.load(
     "data/saved_ckpt/instruction_tuning_math_code_multiturn/36001.pt",
     map_location="cpu",
 )
-if 'module' in ckpt:
-    ckpt = ckpt['module']
+if "module" in ckpt:
+    ckpt = ckpt["module"]
 raw_model.load_state_dict(ckpt)
 raw_model.eval()
 model = raw_model.cuda()
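The change above only switches the quoting style, but the surrounding pattern is the interesting part: checkpoints saved through DeepSpeed typically nest the model weights under a "module" key, and loading with map_location="cpu" keeps a large state dict off GPU 0 until it is copied into the model. A minimal sketch of that loading pattern (the model and path below are placeholders, not the actual Open-Llama code):

    import torch
    from torch import nn

    # Placeholder model and checkpoint path for illustration only; the repo
    # builds its own Open-Llama model and points at a real checkpoint here.
    raw_model = nn.Linear(8, 8)
    ckpt_path = "path/to/checkpoint.pt"

    # Load onto CPU so the full state dict is not materialized on one GPU.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # DeepSpeed-style checkpoints wrap the weights under a "module" key.
    if "module" in ckpt:
        ckpt = ckpt["module"]

    raw_model.load_state_dict(ckpt)
    raw_model.eval()
    model = raw_model.cuda()
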
|
||||
|
|
|
@@ -16,8 +16,8 @@ model:
   shared_input_output_embedding: False
 train:
   train_batch_size: 2
-  num_training_steps: 1000000
-  num_warmup_steps: 2000
+  num_training_steps: 40000
+  num_warmup_steps: 500
   initializer_range: 1.0e-2
   lr: 2.0e-4
   weight_decay: 1.0e-1
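This config change cuts the run from 1,000,000 to 40,000 training steps and the warmup from 2,000 to 500 steps. These two values parameterize the learning-rate schedule; the wiring is not part of this diff, but a rough sketch of the usual pattern, assuming a transformers-style cosine schedule and placeholder config path and model (the actual scheduler and config handling live elsewhere in Open-Llama):

    import torch
    import yaml
    from transformers import get_cosine_schedule_with_warmup

    # Placeholder config path and model; in the repo both come from train_lm.py.
    with open("path/to/config.yaml") as f:
        config = yaml.safe_load(f)
    model = torch.nn.Linear(8, 8)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["train"]["lr"],
        weight_decay=config["train"]["weight_decay"],
    )

    # With the updated config: 500 warmup steps out of 40,000 total.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config["train"]["num_warmup_steps"],
        num_training_steps=config["train"]["num_training_steps"],
    )
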
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 20:26:34
+LastEditTime: 2023-04-29 21:59:51
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -112,6 +112,7 @@ class Trainer:
             self.train_loader = self.accelerator.skip_first_batches(
                 self.train_loader, num_batches=skip_steps
             )
+        self.accelerator.wait_for_everyone()
 
     def train_step(self, batch):
         out = self.model(**batch)
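The added self.accelerator.wait_for_everyone() acts as a barrier after the dataloader has been fast-forwarded on resume, so no process starts iterating before the others have finished skipping batches. A simplified sketch of that resume path (not the Trainer implementation itself; the names and values below are placeholders):

    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    # `skip_steps` is assumed to come from the restored global step of the
    # checkpoint; here it is a placeholder value.
    accelerator = Accelerator()
    train_loader = accelerator.prepare(DataLoader(range(1000), batch_size=2))
    skip_steps = 100

    if skip_steps > 0:
        train_loader = accelerator.skip_first_batches(
            train_loader, num_batches=skip_steps
        )
    # Synchronize so every rank finishes rebuilding its dataloader before
    # any process enters the training loop.
    accelerator.wait_for_everyone()
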
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-29 20:37:52
+LastEditTime: 2023-05-02 18:26:50
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -71,9 +71,9 @@ def main(argv):
         )
     )
     if config["train"]["ckpt"] is not None:
-        ckpt = torch.load(config["train"]["ckpt"])
-        if 'module' in ckpt:
-            ckpt = ckpt['module']
+        ckpt = torch.load(config["train"]["ckpt"], map_location="cpu")
+        if "module" in ckpt:
+            ckpt = ckpt["module"]
         raw_model.load_state_dict(ckpt)
         logging.warn("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
     trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)