From 6814fdb59e78f6f455e78eda732396b051dcdbf2 Mon Sep 17 00:00:00 2001
From: LiangSong <sl12160010@gmail.com>
Date: Mon, 8 May 2023 23:40:03 +0800
Subject: [PATCH] support gradient ckpt for peft

---
 train_lm.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/train_lm.py b/train_lm.py
index 1917709..6daa76d 100644
--- a/train_lm.py
+++ b/train_lm.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-08 22:23:52
+LastEditTime: 2023-05-08 23:39:35
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -76,6 +76,13 @@ def main(argv):
     raw_model = AutoModelForCausalLM.from_config(model_config)
     # lora
     if config["train"].get("use_lora", False):
+        # gradient ckpt bug, https://github.com/huggingface/transformers/issues/23170
+        if hasattr(raw_model, "enable_input_require_grads"):
+            raw_model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            raw_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
         peft_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             target_modules=["q_proj", "v_proj"],
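
Note for reviewers: below is a minimal, self-contained sketch of what this
patch does in a LoRA + gradient-checkpointing setup, following the workaround
from transformers issue #23170. It uses only the transformers/peft calls shown
in the diff; the tiny LlamaConfig and the LoRA hyperparameters (r, lora_alpha)
are illustrative stand-ins, not values taken from Open-Llama's config.

    import torch
    from transformers import LlamaConfig, LlamaForCausalLM
    from peft import LoraConfig, TaskType, get_peft_model

    # A deliberately tiny LLaMA so the sketch runs offline; real training
    # builds raw_model from the project's model_config instead.
    config = LlamaConfig(
        hidden_size=256,
        intermediate_size=688,
        num_hidden_layers=2,
        num_attention_heads=4,
        vocab_size=32000,
    )
    raw_model = LlamaForCausalLM(config)

    # Gradient checkpointing only recomputes (and backpropagates through) a
    # segment if that segment's inputs require grad. LoRA freezes the
    # embedding table, so embedding outputs stop requiring grad, the
    # recomputed graph is disconnected from the loss, and the LoRA weights
    # receive no gradient. The patch restores the connection:
    if hasattr(raw_model, "enable_input_require_grads"):
        raw_model.enable_input_require_grads()
    else:
        # Fallback for older transformers releases without the helper:
        # force the embedding output to require grad via a forward hook.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        raw_model.get_input_embeddings().register_forward_hook(
            make_inputs_require_grad
        )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj"],  # as in the patch
        r=8,            # assumed rank
        lora_alpha=16,  # assumed scaling
    )
    model = get_peft_model(raw_model, peft_config)
    model.gradient_checkpointing_enable()

    # Sanity check: with the workaround, the backward pass completes and
    # LoRA parameters get non-None gradients despite checkpointing.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()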