use split_dataset_by_node instead of accelerate.prepare to speed up data loading by 50%

LiangSong 2023-04-27 00:04:11 +08:00
parent 0377b43628
commit 97aff0e051
3 changed files with 33 additions and 14 deletions
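
In short, the commit stops passing the DataLoader through accelerator.prepare and instead shards the dataset itself per process with datasets.distributed.split_dataset_by_node, moving each batch to the device by hand. A minimal sketch of the two approaches, not the project's actual training script (toy data, placeholder batch size; RANK/WORLD_SIZE fall back to a single process when unset):

import os
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()

# Toy tokenized dataset standing in for the real pretraining corpus (placeholder data).
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3, 4]] * 64, "labels": [[1, 2, 3, 4]] * 64})
dataset = dataset.to_iterable_dataset().with_format("torch")

# Old approach: hand the DataLoader to Accelerate and let it re-shard every batch.
# train_loader = accelerator.prepare(DataLoader(dataset, batch_size=8))

# New approach: shard the dataset itself, so each rank only iterates its own slice.
dataset = split_dataset_by_node(
    dataset,
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
)
train_loader = DataLoader(dataset, batch_size=8)

for batch in train_loader:
    # Accelerate no longer moves batches, so place tensors on the device explicitly.
    batch = {k: v.to(accelerator.device, non_blocking=True) for k, v in batch.items()}
    break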

chat_server.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-07 23:03:31
+LastEditTime: 2023-04-26 23:58:23
 FilePath: /Open-Llama/chat_server.py
 Description:
@@ -10,20 +10,21 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
 import gradio as gr
-import sentencepiece as spm
-from dataset.tokenizer import Tokenizer
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
-sp_model = spm.SentencePieceProcessor(
-    model_file="configs/10w_vocab_wudao5_pile10.model"
+tokenizer = LlamaTokenizer(
+    "configs/10w_vocab_wudao5_pile10.model",
+    pad_token="<pad>",
+    add_bos_token=False,
+    add_eos_token=True,
 )
-tokenizer = Tokenizer(sp_model)
 raw_model = OpenLlamaForCausalLM(
     OpenLlamaConfig(
         vocab_size=tokenizer.vocab_size,
         initializer_range=0.01,
-        pad_token_id=tokenizer.pad_id,
+        pad_token_id=tokenizer.pad_token_id,
         rms_norm_eps=1e-5,
         hidden_dropout_prob=0.1,
         attention_dropout_prob=0.1,
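
The chat server now builds the tokenizer straight from the SentencePiece vocab file via transformers' LlamaTokenizer instead of the project's dataset.tokenizer.Tokenizer wrapper; the explicit pad_token is what makes tokenizer.pad_token_id available to OpenLlamaConfig. A rough usage sketch (it assumes the vocab file from the diff exists locally; the prompt string is only an example):

from transformers import LlamaTokenizer

# Same arguments as the new code: explicit pad token, no automatic BOS, automatic EOS.
tokenizer = LlamaTokenizer(
    "configs/10w_vocab_wudao5_pile10.model",
    pad_token="<pad>",
    add_bos_token=False,
    add_eos_token=True,
)

print(tokenizer.vocab_size, tokenizer.pad_token_id)

# With the Hugging Face API, return_tensors takes the string "pt" rather than True.
inputs = tokenizer(
    "user:hello\nsystem:",
    return_tensors="pt",
    add_special_tokens=False,
    return_attention_mask=False,
)
print(inputs["input_ids"].shape)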
@@ -80,12 +81,20 @@ with gr.Blocks() as demo:
             if completion is None:
                 inputs = "user:{}\nsystem:".format(prompt)
                 inputs = tokenizer(
-                    inputs, return_tensors=True, add_special_tokens=False
+                    inputs,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    return_attention_mask=False,
                 )
                 context.append(inputs["input_ids"])
             else:
                 inputs = "user:{}\nsystem:{}".format(prompt, completion)
-                inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
+                inputs = tokenizer(
+                    inputs,
+                    return_tensors="pt",
+                    add_special_tokens=True,
+                    return_attention_mask=False,
+                )
                 context.append(inputs["input_ids"])
         context = torch.cat(context, dim=-1)
         context = context[:, -1024:]
@@ -93,7 +102,8 @@ with gr.Blocks() as demo:
         context = context.cuda()
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
-        pred = tokenizer.decode(pred.cpu())[0]
+        pred = tokenizer.decode(pred.cpu()[0])
+        pred = pred.strip()
         print(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
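
One behavioural detail of the decode change: LlamaTokenizer.decode expects a single sequence of token ids, so the batch index moves inside the call; the old wrapper returned a list of strings, which is why the [0] used to sit outside. A tiny illustration with placeholder ids, reusing the tokenizer from the sketch above:

import torch

pred = torch.tensor([[319, 13563, 7451]])  # shape (1, seq_len); placeholder token ids

# Old wrapper: decode returned a list of strings, hence the trailing [0].
# text = tokenizer.decode(pred.cpu())[0]

# Hugging Face tokenizer: select the first (only) sequence, then decode and strip.
text = tokenizer.decode(pred.cpu()[0]).strip()
print(text)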

pretrain.py

@@ -2,18 +2,20 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:06:19
+LastEditTime: 2023-04-26 23:05:47
 FilePath: /Open-Llama/pretrain.py
 Description:
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
+import os
 import yaml
 import torch
 from absl import app
 from absl import flags
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
+from datasets.distributed import split_dataset_by_node
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 from dataset.dataset import construct_dataset
@@ -36,6 +38,11 @@ def main(argv):
     )
     data_config = config["data"]
     pretrain_dataset = construct_dataset(data_config, tokenizer)
+    pretrain_dataset = split_dataset_by_node(
+        pretrain_dataset,
+        rank=int(os.environ["RANK"]),
+        world_size=int(os.environ["WORLD_SIZE"]),
+    )
     train_loader = DataLoader(
         pretrain_dataset,
         batch_size=config["train"]["train_batch_size"],
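
The sharding call reads RANK and WORLD_SIZE from the environment, which torchrun and accelerate launch export for each process in multi-GPU runs. In a plain single-process run those variables may be unset, so a defensive variant (my own guard, not part of the commit) could look like this:

import os

# Fall back to a single-node, single-process layout when the launcher
# has not exported the distributed environment variables.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1:
    pretrain_dataset = split_dataset_by_node(
        pretrain_dataset, rank=rank, world_size=world_size
    )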

solver/trainer.py

@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:06:07
+LastEditTime: 2023-04-26 23:06:55
 FilePath: /Open-Llama/solver/trainer.py
 Description:
@@ -87,7 +87,7 @@ class Trainer:
     def prepare(self):
         (
-            self.train_loader,
+            _,
             self.model,
             self.optim,
             self.scheduler,
@@ -113,6 +113,8 @@ class Trainer:
         self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
             self.model.train()
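
Because the loader is no longer prepared by Accelerate, each batch now arrives on the CPU and the trainer copies it to the accelerator device itself. Worth noting: non_blocking=True only overlaps the copy with compute when the host tensors sit in pinned memory, so a DataLoader configured roughly like the sketch below pairs naturally with the new loop (parameter values are placeholders, reusing the accelerator and the node-sharded dataset from the first sketch; not the project's actual config):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,            # the node-sharded dataset from the first sketch
    batch_size=8,       # placeholder; the real value comes from the YAML config
    pin_memory=True,    # pinned host memory lets non_blocking copies run asynchronously
)

for data_step, batch in enumerate(train_loader):
    batch = {k: v.to(accelerator.device, non_blocking=True) for k, v in batch.items()}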