support peft
This commit is contained in:
parent
7da40f1c83
commit
92caa94490
|
@ -2,7 +2,7 @@
|
|||
* @Author: LiangSong(sl12160010@gmail.com)
|
||||
* @Date: 2023-03-10 21:18:35
|
||||
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
* @LastEditTime: 2023-05-06 23:33:11
|
||||
* @LastEditTime: 2023-05-08 22:25:57
|
||||
* @FilePath: /Open-Llama/README.md
|
||||
* @Description:
|
||||
*
|
||||
|
@ -19,7 +19,7 @@
|
|||
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
||||
</p>
|
||||
|
||||
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, and the reinforcement learning technique RLHF.
|
||||
Open-Llama is an open-source project that offers a complete training pipeline for building large language models, ranging from dataset preparation to tokenization, pre-training, prompt tuning, lora, and the reinforcement learning technique RLHF.
|
||||
|
||||
**You can try this model directly from the [Demo](http://home.ustc.edu.cn/~sl9292/).**
|
||||
|
||||
|
@ -61,7 +61,8 @@ Below is a display of the model's multi-turn dialogue ability regarding code:
|
|||
|
||||
**[2023.5.8] Release v2.1**
|
||||
|
||||
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
|
||||
This update adds support for larger model training. Using DeepSpeed stage3 + offload + activation checkpoint, you can **train a 65B model on a single machine with 8 A100-80G**.
|
||||
At the same time, the peft library is introduced to **support training such as lora**.
|
||||
The following table compares the training speed of Open-Llama and the original Llama, and the performance data of Llama is quoted from the original Llama paper.
|
||||
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
||||
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
* @Author: LiangSong(sl12160010@gmail.com)
|
||||
* @Date: 2023-03-10 21:18:35
|
||||
* @LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
* @LastEditTime: 2023-05-06 23:32:31
|
||||
* @LastEditTime: 2023-05-08 22:25:28
|
||||
* @FilePath: /Open-Llama/README_zh.md
|
||||
* @Description:
|
||||
*
|
||||
|
@ -19,7 +19,7 @@
|
|||
<img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/s-JoL/Open-Llama">
|
||||
</p>
|
||||
|
||||
Open-Llama是一个开源项目,提供了一整套用于构建大型语言模型的训练流程,从数据集准备到分词、预训练、指令调优,以及强化学习技术 RLHF。
|
||||
Open-Llama是一个开源项目,提供了一整套用于构建大型语言模型的训练流程,从数据集准备到分词、预训练、指令调优,lora, 以及强化学习技术 RLHF。
|
||||
|
||||
**可从[Demo](http://home.ustc.edu.cn/~sl9292/)直接试用本模型。**
|
||||
|
||||
|
@ -62,7 +62,7 @@ print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
|
|||
|
||||
**[2023.5.8] Release v2.1**
|
||||
|
||||
本次更新加入对更大模型训练的支持,使用DeepSpeed stage3 + offload + activation checkpoint可以在**单机8卡A100-80G训练65B模型**。
|
||||
本次更新加入对更大模型训练的支持,使用DeepSpeed stage3 + offload + activation checkpoint可以在**单机8卡A100-80G训练65B模型**。同时引入peft库**支持lora**等训练。
|
||||
下表对比了Open-Llama和Llama原文的训练速度,Llama性能数据引自Llama原文。
|
||||
| | DeepSpeed Stage | Offload | Activation Checkpoint | Total Token | GPU hours | Speed token/s/gpu | Batch Size | CPU Memory |
|
||||
|----------------|-----------------|---------|-----------------------|-------------|-----------|-------------------|------------|------------|
|
||||
|
|
|
@ -184,7 +184,7 @@ def construct_dataset(
|
|||
random.shuffle(all_data_files)
|
||||
if world_size is not None:
|
||||
num_shards = len(all_data_files)
|
||||
all_data_files = all_data_files[:num_shards // world_size * world_size]
|
||||
all_data_files = all_data_files[: num_shards // world_size * world_size]
|
||||
dataset = load_dataset(
|
||||
"json", data_files=all_data_files, split="train", streaming=True
|
||||
)
|
||||
|
|
|
@ -17,4 +17,5 @@ triton
|
|||
functorch==1.13.1
|
||||
xformers
|
||||
gradio
|
||||
peft
|
||||
git+https://github.com/huggingface/transformers.git
|
16
train_lm.py
16
train_lm.py
|
@ -2,19 +2,19 @@
|
|||
Author: LiangSong(sl12160010@gmail.com)
|
||||
Date: 2023-04-12 19:12:42
|
||||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||||
LastEditTime: 2023-05-06 23:08:42
|
||||
LastEditTime: 2023-05-08 22:23:52
|
||||
FilePath: /Open-Llama/train_lm.py
|
||||
Description:
|
||||
|
||||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||||
"""
|
||||
import yaml
|
||||
import torch
|
||||
import logging
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from accelerate import Accelerator
|
||||
from torch.utils.data import DataLoader
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from datasets.distributed import split_dataset_by_node
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, LlamaTokenizer
|
||||
|
||||
|
@ -74,6 +74,18 @@ def main(argv):
|
|||
logging.warning("Loaded ckpt from: {}".format(config["train"]["ckpt"]))
|
||||
else:
|
||||
raw_model = AutoModelForCausalLM.from_config(model_config)
|
||||
# lora
|
||||
if config["train"].get("use_lora", False):
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
inference_mode=False,
|
||||
r=1,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
raw_model = get_peft_model(raw_model, peft_config)
|
||||
raw_model.print_trainable_parameters()
|
||||
if config["train"].get("gradient_checkpointing_enable", False):
|
||||
raw_model.gradient_checkpointing_enable()
|
||||
trainer = Trainer(config, raw_model, train_loader, tokenizer, accelerator)
|
||||
|
|
Loading…
Reference in New Issue
Block a user