From 97aff0e05100ca9f9b7101594d0fd7cd052c256c Mon Sep 17 00:00:00 2001
From: LiangSong <sl12160010@gmail.com>
Date: Thu, 27 Apr 2023 00:04:11 +0800
Subject: [PATCH] use split_dataset_by_node instead of accelerate.prepare to
 speed up data loading by 50%

---
 chat_server.py    | 32 +++++++++++++++++++++-----------
 pretrain.py       |  9 ++++++++-
 solver/trainer.py |  6 ++++--
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/chat_server.py b/chat_server.py
index 2500d15..d6db705 100644
--- a/chat_server.py
+++ b/chat_server.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-07 23:03:31
+LastEditTime: 2023-04-26 23:58:23
 FilePath: /Open-Llama/chat_server.py
 Description: 
 
@@ -10,20 +10,21 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
 import gradio as gr
-import sentencepiece as spm
-from dataset.tokenizer import Tokenizer
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
-sp_model = spm.SentencePieceProcessor(
-    model_file="configs/10w_vocab_wudao5_pile10.model"
+tokenizer = LlamaTokenizer(
+    "configs/10w_vocab_wudao5_pile10.model",
+    pad_token="<pad>",
+    add_bos_token=False,
+    add_eos_token=True,
 )
-tokenizer = Tokenizer(sp_model)
+
 
 raw_model = OpenLlamaForCausalLM(
     OpenLlamaConfig(
         vocab_size=tokenizer.vocab_size,
         initializer_range=0.01,
-        pad_token_id=tokenizer.pad_id,
+        pad_token_id=tokenizer.pad_token_id,
         rms_norm_eps=1e-5,
         hidden_dropout_prob=0.1,
         attention_dropout_prob=0.1,
@@ -80,12 +81,20 @@ with gr.Blocks() as demo:
             if completion is None:
                 inputs = "user:{}\nsystem:".format(prompt)
                 inputs = tokenizer(
-                    inputs, return_tensors=True, add_special_tokens=False
+                    inputs,
+                    return_tensors="pt",
+                    add_special_tokens=False,
+                    return_attention_mask=False,
                 )
                 context.append(inputs["input_ids"])
             else:
                 inputs = "user:{}\nsystem:{}".format(prompt, completion)
-                inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
+                inputs = tokenizer(
+                    inputs,
+                    return_tensors="pt",
+                    add_special_tokens=True,
+                    return_attention_mask=False,
+                )
                 context.append(inputs["input_ids"])
         context = torch.cat(context, dim=-1)
         context = context[:, -1024:]
@@ -93,7 +102,8 @@ with gr.Blocks() as demo:
         context = context.cuda()
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
-        pred = tokenizer.decode(pred.cpu())[0]
+        pred = tokenizer.decode(pred.cpu()[0])
+        pred = pred.strip()
         print(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
diff --git a/pretrain.py b/pretrain.py
index 5c63bc3..f253a2c 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -2,18 +2,20 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:06:19
+LastEditTime: 2023-04-26 23:05:47
 FilePath: /Open-Llama/pretrain.py
 Description: 
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved. 
""" +import os import yaml import torch from absl import app from absl import flags from accelerate import Accelerator from torch.utils.data import DataLoader +from datasets.distributed import split_dataset_by_node from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer from dataset.dataset import construct_dataset @@ -36,6 +38,11 @@ def main(argv): ) data_config = config["data"] pretrain_dataset = construct_dataset(data_config, tokenizer) + pretrain_dataset = split_dataset_by_node( + pretrain_dataset, + rank=int(os.environ["RANK"]), + world_size=int(os.environ["WORLD_SIZE"]), + ) train_loader = DataLoader( pretrain_dataset, batch_size=config["train"]["train_batch_size"], diff --git a/solver/trainer.py b/solver/trainer.py index 30d78bc..f56db0c 100644 --- a/solver/trainer.py +++ b/solver/trainer.py @@ -2,7 +2,7 @@ Author: LiangSong(sl12160010@gmail.com) Date: 2023-04-24 20:05:21 LastEditors: LiangSong(sl12160010@gmail.com) -LastEditTime: 2023-04-24 20:06:07 +LastEditTime: 2023-04-26 23:06:55 FilePath: /Open-Llama/solver/trainer.py Description: @@ -87,7 +87,7 @@ class Trainer: def prepare(self): ( - self.train_loader, + _, self.model, self.optim, self.scheduler, @@ -113,6 +113,8 @@ class Trainer: self.optim.zero_grad() self.start_time = time.time() for self.data_step, batch in enumerate(self.train_loader): + for k, v in batch.items(): + batch[k] = v.to(self.accelerator.device, non_blocking=True) if self.data_step >= self.config["train"]["num_training_steps"]: break self.model.train()