support peft

This commit is contained in:
LiangSong 2023-05-08 22:26:39 +08:00
parent 7da40f1c83
commit 92caa94490
5 changed files with 23 additions and 9 deletions

View File

@@ -2,7 +2,7 @@
* @Author: LiangSong(sl12160010@gmail.com)
* @Date: 2023-03-10 21:18:35
* @LastEditors: LiangSong(sl12160010@gmail.com)
* @LastEditTime: 2023-05-06 23:33:11
* @LastEditTime: 2023-05-08 22:25:57
* @FilePath: /Open-Llama/README.md
* @Description:
*
@@ -19,7 +19,7 @@
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
</p>
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, and the reinforcement learning technique RLHF.
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, LoRA, and the reinforcement learning technique RLHF.
**You can try this model directly from the [Demo](http://home.ustc.edu.cn/~sl9292/).**
@@ -61,7 +61,8 @@ Below is a display of the model's multi-turn dialogue ability regarding code:
**[2023.5.8] Release v2.1**
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
At the same time, the peft library is introduced to **support training methods such as LoRA**.
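For reference, here is a minimal sketch of what enabling LoRA looks like with peft, mirroring the change to train_lm.py in this commit (the base checkpoint path below is only a placeholder, not something defined by this project):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# Placeholder checkpoint: substitute the base model you actually train.
base_model = AutoModelForCausalLM.from_pretrained("path/to/base_model")

# Same LoRA hyperparameters as the new block in train_lm.py.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    inference_mode=False,
    r=1,
    lora_alpha=32,
    lora_dropout=0.1,
)

model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()  # only the low-rank adapter weights are trainable
```

In train_lm.py this path is guarded by the `train.use_lora` flag in the training config.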
The following table compares the training speed of Open-Llama and the original Llama, and the performance data of Llama is quoted from the original Llama paper.
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|

View File

@@ -2,7 +2,7 @@
* @Author: LiangSong(sl12160010@gmail.com)
* @Date: 2023-03-10 21:18:35
* @LastEditors: LiangSong(sl12160010@gmail.com)
* @LastEditTime: 2023-05-06 23:32:31
* @LastEditTime: 2023-05-08 22:25:28
* @FilePath: /Open-Llama/README_zh.md
* @Description:
*
@@ -19,7 +19,7 @@
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
</p>
Open-Llama is an open-source project that provides a complete training pipeline for building large language models, from dataset preparation to tokenization, pre-training, instruction tuning, and the reinforcement learning technique RLHF.
Open-Llama is an open-source project that provides a complete training pipeline for building large language models, from dataset preparation to tokenization, pre-training, instruction tuning, LoRA, and the reinforcement learning technique RLHF.
**You can try this model directly from the [Demo](http://home.ustc.edu.cn/~sl9292/).**
@@ -62,7 +62,7 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
**[2023.5.8] Release v2.1**
This update adds support for training larger models: using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G GPUs**.
This update adds support for training larger models: using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G GPUs**. At the same time, the peft library is introduced to **support training methods such as LoRA**.
The following table compares the training speed of Open-Llama and the original Llama; the Llama performance data is quoted from the original Llama paper.
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|

View File

@@ -184,7 +184,7 @@ def construct_dataset(
    random.shuffle(all_data_files)
    if world_size is not None:
        num_shards = len(all_data_files)
        all_data_files = all_data_files[:num_shards // world_size * world_size]
        all_data_files = all_data_files[: num_shards // world_size * world_size]
    dataset = load_dataset(
        "json", data_files=all_data_files, split="train", streaming=True
    )
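The slice on the changed line above trims the shuffled file list down to the largest multiple of `world_size`, so every rank streams the same number of shards. A quick illustration of the arithmetic (file names and counts are made up for the example):

```python
# Illustrative only: 10 data files split across 4 ranks.
all_data_files = [f"part-{i:05d}.jsonl" for i in range(10)]  # hypothetical shard names
world_size = 4

num_shards = len(all_data_files)              # 10
keep = num_shards // world_size * world_size  # 8, the largest multiple of world_size
all_data_files = all_data_files[:keep]        # the last 2 files are dropped

assert len(all_data_files) % world_size == 0  # each rank can now take 2 shards
```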

View File

@@ -17,4 +17,5 @@ triton
functorch==1.13.1
xformers
gradio
peft
git+https://github.com/huggingface/transformers.git

View File

@@ -2,19 +2,19 @@
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-04-12 19:12:42
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-05-06 23:08:42
LastEditTime: 2023-05-08 22:23:52
FilePath: /Open-Llama/train_lm.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
"""
import yaml
import torch
import logging
from absl import app
from absl import flags
from accelerate import Accelerator
from torch.utils.data import DataLoader
from peft import LoraConfig, TaskType, get_peft_model
from datasets.distributed import split_dataset_by_node
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer
@@ -74,6 +74,18 @@ def main(argv):
        logging.warning("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
    else:
        raw_model = AutoModelForCausalLM.from_config(model_config)
    # LoRA: optionally wrap the base model with a peft adapter; only the adapter weights are trained
    if config["train"].get("use_lora", False):
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=["q_proj", "v_proj"],
            inference_mode=False,
            r=1,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        raw_model = get_peft_model(raw_model, peft_config)
        raw_model.print_trainable_parameters()
    if config["train"].get("gradient_checkpointing_enable", False):
        raw_model.gradient_checkpointing_enable()
    trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
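After a LoRA run, the trained adapter would typically be loaded back onto the base model for inference. A sketch of one way to do that with peft, under the assumption that the adapter was saved with `save_pretrained` (all paths below are placeholders, not produced by this commit):

```python
from transformers import AutoModelForCausalLM, LlamaTokenizer
from peft import PeftModel

# Placeholders: none of these paths are defined by this commit.
base_model = AutoModelForCausalLM.from_pretrained("path/to/base_model")
model = PeftModel.from_pretrained(base_model, "path/to/lora_adapter")
model.eval()

tokenizer = LlamaTokenizer.from_pretrained("path/to/tokenizer")
inputs = tokenizer("Hello, ", return_tensors="pt")
pred = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(pred[0], skip_special_tokens=True))
```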