update tokenizer to LlamaTokenizer

LiangSong 2023-04-26 18:53:30 +08:00
parent f41f5558ec
commit 0377b43628
4 changed files with 8 additions and 8 deletions

View File

@@ -6,14 +6,14 @@ deepspeed_config:
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false
-  zero_stage: 2
+  zero_stage: 1
 distributed_type: DEEPSPEED
 fsdp_config: {}
 machine_rank: 0
-main_process_ip: null
-main_process_port: null
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
 num_processes: 8
+rdzv_backend: static
+same_network: true
 use_cpu: false
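
The accelerate config change above drops the explicit main_process_ip/main_process_port entries, pins rdzv_backend: static and same_network: true, and lowers the DeepSpeed ZeRO level from stage 2 (optimizer states plus gradients partitioned) to stage 1 (optimizer states only). A minimal sketch for inspecting the resulting settings; the file name configs/accelerate_config.yaml is a placeholder, since the actual path is not visible in this diff:

import yaml

# Placeholder path; substitute the repo's actual accelerate config file.
with open("configs/accelerate_config.yaml", "r", encoding="utf-8") as fp:
    cfg = yaml.load(fp, Loader=yaml.FullLoader)

# After this commit: optimizer states are sharded (ZeRO stage 1), and the
# static rendezvous backend is used on a single machine with 8 processes.
print(cfg["deepspeed_config"]["zero_stage"])              # expected: 1
print(cfg["distributed_type"], cfg["num_processes"])      # DEEPSPEED 8
print(cfg.get("rdzv_backend"), cfg.get("same_network"))   # static True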

View File

@@ -131,7 +131,7 @@ if __name__ == "__main__":
     import time
     from unicodedata import normalize
     from torch.utils.data import DataLoader
-    from transformers import OpenLlamaTokenizer
+    from transformers import LlamaTokenizer
 
     data_config = {
         "mode": "pretrain",
@@ -140,7 +140,7 @@ if __name__ == "__main__":
         "num_sequences": 10,
         "seq_length": 2048,
     }
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         "configs/llama_tokenizer_extended.model",
         pad_token="<pad>",
         add_bos_token=False,
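
Both hunks in this file swap OpenLlamaTokenizer for LlamaTokenizer while keeping the same arguments: the extended SentencePiece model at configs/llama_tokenizer_extended.model, an explicit <pad> token, and no automatic BOS. A hedged usage sketch of the new class; the sample text and the padding call are illustrative, only the constructor arguments come from the diff:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer(
    "configs/llama_tokenizer_extended.model",  # path taken from the diff
    pad_token="<pad>",
    add_bos_token=False,
)

# Illustrative only: pad/truncate to a fixed length, as the pretraining
# pipeline does with seq_length=2048.
batch = tokenizer(
    "An example sentence.",
    max_length=2048,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (1, 2048)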

View File

@@ -14,7 +14,7 @@ from absl import app
 from absl import flags
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, OpenLlamaTokenizer
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
 from dataset.dataset import construct_dataset
 from solver.trainer import Trainer
@@ -28,7 +28,7 @@ def main(argv):
     with open(FLAGS.config, "r", encoding="utf-8") as fp:
         config = yaml.load(fp, Loader=yaml.FullLoader)
 
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         config["data"]["tokenizer_model_path"],
         pad_token="<pad>",
         add_bos_token=False,
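
Only the tokenizer class changes here; the model side still uses OpenLlamaForCausalLM and OpenLlamaConfig. A hedged sketch of how the imports in this hunk plausibly fit together after the change; the YAML path and the vocab-size handling are assumptions, not shown in this diff:

import yaml
from transformers import LlamaTokenizer, OpenLlamaConfig, OpenLlamaForCausalLM

# Placeholder config path standing in for FLAGS.config.
with open("configs/pretrain_config.yaml", "r", encoding="utf-8") as fp:
    config = yaml.load(fp, Loader=yaml.FullLoader)

tokenizer = LlamaTokenizer(
    config["data"]["tokenizer_model_path"],
    pad_token="<pad>",
    add_bos_token=False,
)

# Assumption: the model vocabulary is sized to match the extended tokenizer.
model = OpenLlamaForCausalLM(OpenLlamaConfig(vocab_size=len(tokenizer)))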

View File

@@ -110,8 +110,8 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.global_step = 0
-        self.start_time = time.time()
         self.optim.zero_grad()
+        self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
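
The only change in the trainer is the position of self.start_time: it is now recorded after optim.zero_grad(), immediately before the first batch is drawn, presumably so that optimizer setup and the preceding prepare() call are excluded from any elapsed-time or throughput figures derived from it later in the loop. A minimal, self-contained sketch of that timing pattern, not the repo's actual logging code:

import time

def run_loop(loader, num_training_steps):
    # Reset the clock right before the first batch, mirroring the new ordering.
    start_time = time.time()
    for step, batch in enumerate(loader):
        if step >= num_training_steps:
            break
        # forward / backward / optimizer step would go here
        elapsed = time.time() - start_time
        print(f"step {step}: {(step + 1) / max(elapsed, 1e-8):.1f} steps/s")

# Toy usage with dummy batches standing in for the real DataLoader.
run_loop([{"input_ids": [0, 1, 2]}] * 4, num_training_steps=2)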