update tokenizer to LlamaTokenizer
This commit is contained in:
parent f41f5558ec
commit 0377b43628
@@ -6,14 +6,14 @@ deepspeed_config:
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false
-  zero_stage: 2
+  zero_stage: 1
 distributed_type: DEEPSPEED
 fsdp_config: {}
 machine_rank: 0
 main_process_ip: null
 main_process_port: null
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
 num_processes: 8
 rdzv_backend: static
 same_network: true
 use_cpu: false
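Note: besides the tokenizer swap named in the commit title, this hunk also lowers ZeRO from stage 2 (optimizer states and gradients sharded across ranks) to stage 1 (optimizer states only). A minimal sketch, assuming the run is launched through Hugging Face Accelerate with DeepSpeed installed, of expressing the same settings programmatically instead of via the YAML file; the values simply mirror the hunk above:

from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Sketch only: programmatic equivalent of the accelerate YAML above.
ds_plugin = DeepSpeedPlugin(
    zero_stage=1,                      # was 2 before this commit
    offload_optimizer_device="none",
    offload_param_device="none",
    zero3_init_flag=False,
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)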
@@ -131,7 +131,7 @@ if __name__ == "__main__":
     import time
     from unicodedata import normalize
     from torch.utils.data import DataLoader
-    from transformers import OpenLlamaTokenizer
+    from transformers import LlamaTokenizer

     data_config = {
         "mode": "pretrain",
@@ -140,7 +140,7 @@ if __name__ == "__main__":
         "num_sequences": 10,
         "seq_length": 2048,
     }
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         "configs/llama_tokenizer_extended.model",
         pad_token="<pad>",
         add_bos_token=False,
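Note: a minimal sketch of the tokenizer as constructed in the hunks above; the sample sentence and max_length are illustrative only. LlamaTokenizer reads the SentencePiece model straight from the given file path, the explicit pad token makes padding possible, and add_bos_token=False keeps <s> from being prepended to every encoded sequence:

from transformers import LlamaTokenizer

# Sketch only: exercising the constructor call shown in the diff.
tokenizer = LlamaTokenizer(
    "configs/llama_tokenizer_extended.model",  # SentencePiece vocab file from this repo
    pad_token="<pad>",       # explicit pad token so batches can be padded
    add_bos_token=False,     # do not prepend <s> to every encoded sequence
)
ids = tokenizer("hello world", padding="max_length", max_length=8)["input_ids"]
print(ids, tokenizer.pad_token_id)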
@@ -14,7 +14,7 @@ from absl import app
 from absl import flags
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, OpenLlamaTokenizer
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer

 from dataset.dataset import construct_dataset
 from solver.trainer import Trainer
@@ -28,7 +28,7 @@ def main(argv):

     with open(FLAGS.config, "r", encoding="utf-8") as fp:
         config = yaml.load(fp, Loader=yaml.FullLoader)
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         config["data"]["tokenizer_model_path"],
         pad_token="<pad>",
         add_bos_token=False,
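Note: the model side keeps the OpenLlama classes and only the tokenizer import changes. If llama_tokenizer_extended.model enlarges the vocabulary, the embedding table in OpenLlamaConfig has to follow it. A hedged sketch; the hidden sizes below are placeholders for a toy model, not values taken from this repo:

from transformers import OpenLlamaConfig, OpenLlamaForCausalLM, LlamaTokenizer

# Sketch only: size the embeddings from the (possibly extended) tokenizer vocabulary.
tokenizer = LlamaTokenizer(
    "configs/llama_tokenizer_extended.model",
    pad_token="<pad>",
    add_bos_token=False,
)
config = OpenLlamaConfig(
    vocab_size=len(tokenizer),   # follow the extended tokenizer
    hidden_size=512,             # placeholder toy-model sizes
    intermediate_size=1376,
    num_hidden_layers=4,
    num_attention_heads=8,
)
model = OpenLlamaForCausalLM(config)
print(model.get_input_embeddings().weight.shape)  # (len(tokenizer), 512)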
@@ -110,8 +110,8 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.global_step = 0
-        self.start_time = time.time()
         self.optim.zero_grad()
+        self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
             if self.data_step >= self.config["train"]["num_training_steps"]:
                 break
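Note: the last hunk only moves the start_time reset below zero_grad(), so the clock starts immediately before the data loop, presumably so that logged step times exclude the setup above it. A small self-contained sketch of that pattern; the loader and step function are stand-ins, not code from this repo:

import time

train_loader = range(100)            # stand-in for the real DataLoader
def train_step(batch):               # stand-in for forward/backward/optimizer step
    return float(batch)

start_time = time.time()             # clock starts here, just before the loop
for step, batch in enumerate(train_loader):
    loss = train_step(batch)
    if (step + 1) % 25 == 0:
        per_step = (time.time() - start_time) / (step + 1)
        print(f"step {step + 1}: loss={loss:.2f}, {per_step * 1e3:.3f} ms/step")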