support gradient ckpt for peft
commit 6814fdb59e
parent 3ba0c77053
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-05-08 22:23:52
+LastEditTime: 2023-05-08 23:39:35
 FilePath: /Open-Llama/train_lm.py
 Description:
 
@@ -76,6 +76,13 @@ def main(argv):
     raw_model = AutoModelForCausalLM.from_config(model_config)
     # lora
     if config["train"].get("use_lora", False):
+        # gradient ckpt bug, https://github.com/huggingface/transformers/issues/23170
+        if hasattr(raw_model, "enable_input_require_grads"):
+            raw_model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            raw_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
         peft_config = LoraConfig(
             task_type=TaskType.CAUSAL_LM,
             target_modules=["q_proj", "v_proj"],
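For reference, a minimal standalone sketch of the same workaround outside train_lm.py. The checkpoint name ("facebook/opt-125m"), the LoRA hyperparameters (r, lora_alpha, lora_dropout), and the dummy training step are illustrative assumptions, not values from this commit.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

model_name = "facebook/opt-125m"  # placeholder model whose attention uses q_proj/v_proj
model = AutoModelForCausalLM.from_pretrained(model_name)

# Same fix as the diff above: with the base weights frozen by LoRA, the inputs
# entering each checkpointed segment do not require grad, so backward() breaks
# (https://github.com/huggingface/transformers/issues/23170). Forcing the
# embedding output to require grad keeps the autograd graph connected.
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()
else:
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

model.gradient_checkpointing_enable()
model.config.use_cache = False  # caching is incompatible with gradient checkpointing

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],  # same target modules as in the diff
    r=8,               # assumed rank
    lora_alpha=32,     # assumed scaling
    lora_dropout=0.05, # assumed dropout
)
model = get_peft_model(model, peft_config)
model.train()

# One dummy step to confirm gradients reach the LoRA weights under checkpointing.
tokenizer = AutoTokenizer.from_pretrained(model_name)
batch = tokenizer("hello world", return_tensors="pt")
out = model(**batch, labels=batch["input_ids"])
out.loss.backward()

Without the enable_input_require_grads / forward-hook step, the backward() call fails under gradient checkpointing, because the checkpointed segments are re-run from inputs that do not require grad and the LoRA weights inside them receive no gradients.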