support gradient ckpt for peft

LiangSong 2023-05-08 23:40:03 +08:00
parent 3ba0c77053
commit 6814fdb59e


@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-08 22:23:52
+LastEditTime: 2023-05-08 23:39:35
 FilePath: /Open-Llama/train_lm.py
 Description:
@@ -76,6 +76,13 @@ def main(argv):
     raw_model = AutoModelForCausalLM.from_config(model_config)
     # lora
     if config["train"].get("use_lora", False):
+        # gradient ckpt bug, https://github.com/huggingface/transformers/issues/23170
+        if hasattr(raw_model, "enable_input_require_grads"):
+            raw_model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            raw_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
         peft_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             target_modules=["q_proj", "v_proj"],
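
For context, a minimal self-contained sketch of the same workaround outside of train_lm.py might look like the following. The base model name, the LoRA hyperparameters, and the get_peft_model / gradient_checkpointing_enable calls are assumptions for illustration and are not part of this commit.

# Sketch only (not from this commit): model name and LoRA hyperparameters are assumed.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

raw_model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b")

# With LoRA the base parameters (including the input embeddings) are frozen, so the
# embedding output has requires_grad=False and gradient checkpointing breaks the
# backward graph. Workaround discussed in
# https://github.com/huggingface/transformers/issues/23170
if hasattr(raw_model, "enable_input_require_grads"):
    raw_model.enable_input_require_grads()
else:
    # Older transformers versions: hook the embedding layer and mark its output
    # as requiring grad so checkpointed segments still get gradients.
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    raw_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],  # attention projections, as in the diff
    r=8,            # assumed rank
    lora_alpha=32,  # assumed scaling
    lora_dropout=0.05,
)
model = get_peft_model(raw_model, peft_config)
model.gradient_checkpointing_enable()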