use split_dataset_by_node instead of accelerate.prepare to speed up data loading by 50%
This commit is contained in:
parent 0377b43628
commit 97aff0e051
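Why the change speeds up loading (context for the title, not part of the diff): when a DataLoader is passed through accelerate's prepare(), a streaming dataset is not sharded up front; every rank iterates the full stream and keeps only its own slice of batches, so I/O and tokenization work is duplicated across workers. datasets.distributed.split_dataset_by_node (Hugging Face datasets >= 2.8) instead assigns each rank its own shard before the DataLoader is built, so each worker reads roughly 1/world_size of the data. The 50% figure is the commit author's measurement.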
--- a/chat_server.py
+++ b/chat_server.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-06 22:30:10
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-07 23:03:31
+LastEditTime: 2023-04-26 23:58:23
 FilePath: /Open-Llama/chat_server.py
 Description:
 
@@ -10,20 +10,21 @@ Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
 import torch
 import gradio as gr
-import sentencepiece as spm
-from dataset.tokenizer import Tokenizer
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
 
-sp_model = spm.SentencePieceProcessor(
-    model_file="configs/10w_vocab_wudao5_pile10.model"
-)
-tokenizer = Tokenizer(sp_model)
+tokenizer = LlamaTokenizer(
+    "configs/10w_vocab_wudao5_pile10.model",
+    pad_token="<pad>",
+    add_bos_token=False,
+    add_eos_token=True,
+)
+
 raw_model = OpenLlamaForCausalLM(
     OpenLlamaConfig(
         vocab_size=tokenizer.vocab_size,
         initializer_range=0.01,
-        pad_token_id=tokenizer.pad_id,
+        pad_token_id=tokenizer.pad_token_id,
         rms_norm_eps=1e-5,
         hidden_dropout_prob=0.1,
         attention_dropout_prob=0.1,
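The hunk above swaps the repo's SentencePiece wrapper for transformers' LlamaTokenizer. Its first positional argument is a SentencePiece vocab_file, so the same configs/10w_vocab_wudao5_pile10.model loads directly, and the pad id is now exposed through the standard pad_token_id attribute used in the config below. A minimal usage sketch of the tokenizer as configured above (exact ids depend on the repo's custom vocab, and the "<pad>" piece must exist in it):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer(
    "configs/10w_vocab_wudao5_pile10.model",  # vocab_file: a SentencePiece model
    pad_token="<pad>",
    add_bos_token=False,  # no BOS prepended when add_special_tokens=True
    add_eos_token=True,   # EOS appended when add_special_tokens=True
)
batch = tokenizer(
    "user:hi\nsystem:",
    return_tensors="pt",
    add_special_tokens=False,
    return_attention_mask=False,
)
print(batch["input_ids"].shape)  # torch.Size([1, seq_len])
print(tokenizer.pad_token_id)    # id of "<pad>" in the custom vocab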
@@ -80,12 +81,20 @@ with gr.Blocks() as demo:
         if completion is None:
             inputs = "user:{}\nsystem:".format(prompt)
-            inputs = tokenizer(
-                inputs, return_tensors=True, add_special_tokens=False
-            )
+            inputs = tokenizer(
+                inputs,
+                return_tensors="pt",
+                add_special_tokens=False,
+                return_attention_mask=False,
+            )
             context.append(inputs["input_ids"])
         else:
             inputs = "user:{}\nsystem:{}".format(prompt, completion)
-            inputs = tokenizer(inputs, return_tensors=True, add_special_tokens=True)
+            inputs = tokenizer(
+                inputs,
+                return_tensors="pt",
+                add_special_tokens=True,
+                return_attention_mask=False,
+            )
             context.append(inputs["input_ids"])
         context = torch.cat(context, dim=-1)
         context = context[:, -1024:]
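Two details in the hunk above, for readers following along: Hugging Face tokenizers expect return_tensors to name a framework ("pt", "np", "tf"), so the old return_tensors=True only worked with the removed custom wrapper; and return_attention_mask=False trims the returned dict to input_ids, which is all that generate() is fed below.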
@@ -93,7 +102,8 @@ with gr.Blocks() as demo:
         context = context.cuda()
         pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
         pred = pred[:, inputs_len:]
-        pred = tokenizer.decode(pred.cpu())[0]
+        pred = tokenizer.decode(pred.cpu()[0])
+        pred = pred.strip()
         print(pred)
         bot_message = parse_codeblock(pred)
         history[-1][1] = bot_message
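The decode fix is the other behavioral consequence of the tokenizer swap: LlamaTokenizer.decode() takes a single sequence of ids, whereas the removed wrapper decoded a whole batch and returned a list. A short sketch of the distinction (names as in chat_server.py; not part of the commit):

pred = model.generate(input_ids=context, max_new_tokens=512, do_sample=True)
text = tokenizer.decode(pred.cpu()[0])      # one sequence -> str
texts = tokenizer.batch_decode(pred.cpu())  # whole batch -> list[str]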
--- a/pretrain.py
+++ b/pretrain.py
@@ -2,18 +2,20 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-12 19:12:42
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:06:19
+LastEditTime: 2023-04-26 23:05:47
 FilePath: /Open-Llama/pretrain.py
 Description:
 
 Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
 """
+import os
 import yaml
 import torch
 from absl import app
 from absl import flags
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
+from datasets.distributed import split_dataset_by_node
 from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
 from dataset.dataset import construct_dataset
@@ -36,6 +38,11 @@ def main(argv):
     )
     data_config = config["data"]
     pretrain_dataset = construct_dataset(data_config, tokenizer)
+    pretrain_dataset = split_dataset_by_node(
+        pretrain_dataset,
+        rank=int(os.environ["RANK"]),
+        world_size=int(os.environ["WORLD_SIZE"]),
+    )
     train_loader = DataLoader(
         pretrain_dataset,
         batch_size=config["train"]["train_batch_size"],
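RANK and WORLD_SIZE are set by the distributed launcher (torchrun or accelerate launch); a bare python pretrain.py would raise a KeyError here. For streaming datasets, split_dataset_by_node assigns whole shards to ranks when the shard count divides evenly by world_size, and otherwise keeps every world_size-th example on each rank. A hypothetical single-process guard, not part of the commit:

import os
from datasets.distributed import split_dataset_by_node

rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
if world_size > 1:  # skip the split when not launched with a distributed runner
    pretrain_dataset = split_dataset_by_node(
        pretrain_dataset, rank=rank, world_size=world_size
    )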
--- a/solver/trainer.py
+++ b/solver/trainer.py
@@ -2,7 +2,7 @@
 Author: LiangSong(sl12160010@gmail.com)
 Date: 2023-04-24 20:05:21
 LastEditors: LiangSong(sl12160010@gmail.com)
-LastEditTime: 2023-04-24 20:06:07
+LastEditTime: 2023-04-26 23:06:55
 FilePath: /Open-Llama/solver/trainer.py
 Description:
 
@@ -87,7 +87,7 @@ class Trainer:
 
     def prepare(self):
         (
-            self.train_loader,
+            _,
             self.model,
             self.optim,
             self.scheduler,
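With the dataset already sharded per rank, the DataLoader no longer needs accelerate's wrapping: the prepared loader returned by accelerator.prepare is discarded into _ and the raw self.train_loader is iterated directly, which is why the next hunk moves batches to the device by hand. The surrounding call presumably keeps its argument order; this form is inferred from the unpacking, not shown in the diff:

(
    _,              # prepared DataLoader discarded; the raw loader is used as-is
    self.model,
    self.optim,
    self.scheduler,
) = self.accelerator.prepare(
    self.train_loader, self.model, self.optim, self.scheduler
)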
@@ -113,6 +113,8 @@ class Trainer:
         self.optim.zero_grad()
         self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
+            for k, v in batch.items():
+                batch[k] = v.to(self.accelerator.device, non_blocking=True)
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
             self.model.train()
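Because the loader is no longer wrapped by accelerate, batches now arrive on the CPU, hence the explicit .to(device) loop above. Note that non_blocking=True only overlaps the host-to-device copy with compute when the source tensors sit in pinned memory; whether the repo enables that is not shown in this diff, but it would be a DataLoader option (names from pretrain.py above):

train_loader = DataLoader(
    pretrain_dataset,
    batch_size=config["train"]["train_batch_size"],
    pin_memory=True,  # page-locked host buffers make non_blocking copies asynchronous
)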