support peft
This commit is contained in:
parent
7da40f1c83
commit
92caa94490
|
@ -2,7 +2,7 @@
|
||||||
* @Author: LiangSong(sl12160010@gmail.com)
|
* @Author: LiangSong(sl12160010@gmail.com)
|
||||||
* @Date: 2023-03-10 21:18:35
|
* @Date: 2023-03-10 21:18:35
|
||||||
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
* @LastEditTime: 2023-05-06 23:33:11
|
* @LastEditTime: 2023-05-08 22:25:57
|
||||||
* @FilePath: /Open-Llama/README.md
|
* @FilePath: /Open-Llama/README.md
|
||||||
* @Description:
|
* @Description:
|
||||||
*
|
*
|
||||||
|
@ -19,7 +19,7 @@
|
||||||
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, and the reinforcement learning technique RLHF.
|
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, lora, and the reinforcement learning technique RLHF.
|
||||||
|
|
||||||
**You can try this model directly from the [Demo](http://home.ustc.edu.cn/~sl9292/).**
|
**You can try this model directly from the [Demo](http://home.ustc.edu.cn/~sl9292/).**
|
||||||
|
|
||||||
|
@ -62,6 +62,7 @@ Below is a display of the model's multi-turn dialogue ability regarding code:
|
||||||
**[2023.5.8] Release v2.1**
|
**[2023.5.8] Release v2.1**
|
||||||
|
|
||||||
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
|
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
|
||||||
|
At the same time, the peft library is introduced to **support training such as lora**.
|
||||||
The following table compares the training speed of Open-Llama and the original Llama, and the performance data of Llama is quoted from the original Llama paper.
|
The following table compares the training speed of Open-Llama and the original Llama, and the performance data of Llama is quoted from the original Llama paper.
|
||||||
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
||||||
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
* @Author: LiangSong(sl12160010@gmail.com)
|
* @Author: LiangSong(sl12160010@gmail.com)
|
||||||
* @Date: 2023-03-10 21:18:35
|
* @Date: 2023-03-10 21:18:35
|
||||||
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
* @LastEditTime: 2023-05-06 23:32:31
|
* @LastEditTime: 2023-05-08 22:25:28
|
||||||
* @FilePath: /Open-Llama/README_zh.md
|
* @FilePath: /Open-Llama/README_zh.md
|
||||||
* @Description:
|
* @Description:
|
||||||
*
|
*
|
||||||
|
@ -19,7 +19,7 @@
|
||||||
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Open-Llama是一个开源项目,提供了一整套用于构建大型语言模型的训练流程,从数据集准备到分词、预训练、指令调优,以及强化学习技术 RLHF。
|
Open-Llama是一个开源项目,提供了一整套用于构建大型语言模型的训练流程,从数据集准备到分词、预训练、指令调优,lora, 以及强化学习技术 RLHF。
|
||||||
|
|
||||||
**可从[Demo](http://home.ustc.edu.cn/~sl9292/)直接试用本模型。**
|
**可从[Demo](http://home.ustc.edu.cn/~sl9292/)直接试用本模型。**
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
|
||||||
|
|
||||||
**[2023.5.8] Release v2.1**
|
**[2023.5.8] Release v2.1**
|
||||||
|
|
||||||
本次更新加入对更大模型训练的支持,使用DeepSpeed stage3 + offload + activation checkpoint可以在**单机8卡A100-80G训练65B模型**。
|
本次更新加入对更大模型训练的支持,使用DeepSpeed stage3 + offload + activation checkpoint可以在**单机8卡A100-80G训练65B模型**。同时引入peft库**支持lora**等训练。
|
||||||
下表对比了Open-Llama和Llama原文的训练速度,Llama性能数据引自Llama原文。
|
下表对比了Open-Llama和Llama原文的训练速度,Llama性能数据引自Llama原文。
|
||||||
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
||||||
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
||||||
|
|
|
@ -17,4 +17,5 @@ triton
|
||||||
functorch==1.13.1
|
functorch==1.13.1
|
||||||
xformers
|
xformers
|
||||||
gradio
|
gradio
|
||||||
|
peft
|
||||||
git+https://github.com/huggingface/transformers.git
|
git+https://github.com/huggingface/transformers.git
|
16
train_lm.py
16
train_lm.py
|
@ -2,19 +2,19 @@
|
||||||
Author: LiangSong(sl12160010@gmail.com)
|
Author: LiangSong(sl12160010@gmail.com)
|
||||||
Date: 2023-04-12 19:12:42
|
Date: 2023-04-12 19:12:42
|
||||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||||
LastEditTime: 2023-05-06 23:08:42
|
LastEditTime: 2023-05-08 22:23:52
|
||||||
FilePath: /Open-Llama/train_lm.py
|
FilePath: /Open-Llama/train_lm.py
|
||||||
Description:
|
Description:
|
||||||
|
|
||||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||||
"""
|
"""
|
||||||
import yaml
|
import yaml
|
||||||
import torch
|
|
||||||
import logging
|
import logging
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
from absl import flags
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
from peft import LoraConfig, TaskType, get_peft_model
|
||||||
from datasets.distributed import split_dataset_by_node
|
from datasets.distributed import split_dataset_by_node
|
||||||
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer
|
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer
|
||||||
|
|
||||||
|
@ -74,6 +74,18 @@ def main(argv):
|
||||||
logging.warning("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
logging.warning("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
||||||
else:
|
else:
|
||||||
raw_model = AutoModelForCausalLM.from_config(model_config)
|
raw_model = AutoModelForCausalLM.from_config(model_config)
|
||||||
|
# lora
|
||||||
|
if config["train"].get("use_lora", False):
|
||||||
|
peft_config = LoraConfig(
|
||||||
|
task_type=TaskType.CAUSAL_LM,
|
||||||
|
target_modules=["q_proj", "v_proj"],
|
||||||
|
inference_mode=False,
|
||||||
|
r=1,
|
||||||
|
lora_alpha=32,
|
||||||
|
lora_dropout=0.1,
|
||||||
|
)
|
||||||
|
raw_model = get_peft_model(raw_model, peft_config)
|
||||||
|
raw_model.print_trainable_parameters()
|
||||||
if config["train"].get("gradient_checkpointing_enable", False):
|
if config["train"].get("gradient_checkpointing_enable", False):
|
||||||
raw_model.gradient_checkpointing_enable()
|
raw_model.gradient_checkpointing_enable()
|
||||||
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user