support gradient ckpt for peft
commit 6814fdb59e
parent 3ba0c77053
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-08 22:23:52
+LastEditTime: 2023-05-08 23:39:35
 FilePath: /Open-Llama/train_lm.py
 Description:

@@ -76,6 +76,13 @@ def main(argv):
     raw_model = AutoModelForCausalLM.from_config(model_config)
     # lora
     if config["train"].get("use_lora", False):
+        # gradient ckpt bug, https://github.com/huggingface/transformers/issues/23170
+        if hasattr(raw_model, "enable_input_require_grads"):
+            raw_model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            raw_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
         peft_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             target_modules=["q_proj", "v_proj"],
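Why the added hook matters, as a minimal sketch in plain PyTorch (separate from the repository's code; the module names and sizes below are invented for illustration): with the reentrant gradient-checkpoint implementation, a checkpointed segment only re-enters autograd if at least one of its inputs requires grad. With LoRA the base model, including the input embedding, is frozen, so the embedding output carries no grad requirement and backward() fails before the adapter weights are reached. Forcing requires_grad on the embedding output, which is what enable_input_require_grads() or the registered forward hook does, restores the gradient path.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

embed = nn.Embedding(100, 16)        # stands in for the frozen input embedding
block = nn.Linear(16, 16)            # stands in for a block holding trainable adapter weights
embed.weight.requires_grad_(False)   # base model frozen, as with LoRA


def make_inputs_require_grad(module, input, output):
    # Same trick as the commit: pull the embedding output back into the autograd graph.
    output.requires_grad_(True)


embed.register_forward_hook(make_inputs_require_grad)

tokens = torch.randint(0, 100, (2, 8))
hidden = embed(tokens)
# Reentrant checkpointing now sees an input that requires grad and records the segment.
out = checkpoint(block, hidden, use_reentrant=True)
out.sum().backward()
print(block.weight.grad is not None)  # True; without the hook this backward() raises

Removing the hook registration reproduces the failure tracked in the linked transformers issue: checkpoint() warns that none of the inputs require grad and backward() raises because the checkpointed output has no grad_fn.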