update tokenizer to LlamaTokenizer
commit 0377b43628 (parent f41f5558ec)
@@ -6,14 +6,14 @@ deepspeed_config:
   offload_optimizer_device: none
   offload_param_device: none
   zero3_init_flag: false
-  zero_stage: 2
+  zero_stage: 1
 distributed_type: DEEPSPEED
 fsdp_config: {}
 machine_rank: 0
-main_process_ip: null
-main_process_port: null
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
 num_processes: 8
+rdzv_backend: static
+same_network: true
 use_cpu: false
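The accelerate/DeepSpeed config hunk above moves from ZeRO stage 2 (optimizer state plus gradient sharding) to stage 1 (optimizer state sharding only), drops the main_process_ip/port entries, and adds the rdzv_backend/same_network keys that newer accelerate versions write into their config files. As a hedged illustration of what the new values mean, the same settings could be expressed programmatically through accelerate's DeepSpeedPlugin; the sketch below is illustrative only (the repository drives these values through the YAML file and accelerate launch) and assumes deepspeed is installed:

# Illustrative sketch, not repository code: the updated DeepSpeed settings
# expressed via accelerate's DeepSpeedPlugin. ZeRO stage 1 shards only the
# optimizer states across ranks; stage 2 (the old value) would also shard
# gradients. Offloading stays disabled, matching the YAML above.
from accelerate import Accelerator, DeepSpeedPlugin

ds_plugin = DeepSpeedPlugin(
    zero_stage=1,                      # was 2 before this commit
    offload_optimizer_device="none",
    offload_param_device="none",
    zero3_init_flag=False,
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=ds_plugin)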
@@ -131,7 +131,7 @@ if __name__ == "__main__":
     import time
     from unicodedata import normalize
     from torch.utils.data import DataLoader
-    from transformers import OpenLlamaTokenizer
+    from transformers import LlamaTokenizer
 
     data_config = {
         "mode": "pretrain",
@@ -140,7 +140,7 @@ if __name__ == "__main__":
         "num_sequences": 10,
         "seq_length": 2048,
     }
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         "configs/llama_tokenizer_extended.model",
         pad_token="<pad>",
         add_bos_token=False,
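The tokenizer swap replaces OpenLlamaTokenizer with LlamaTokenizer while keeping the same SentencePiece model file and keyword arguments, likely because the Open-Llama-specific tokenizer class is deprecated in recent transformers releases in favor of LlamaTokenizer. A minimal sketch of how the updated scripts construct and use the tokenizer; the sample text and decode round trip are illustrative only, and the snippet assumes the sentencepiece package and the referenced model file are available:

# Sketch, not part of the commit: building the tokenizer the same way the
# updated scripts do, then encoding and decoding a sample string.
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer(
    "configs/llama_tokenizer_extended.model",  # SentencePiece model referenced in the diff
    pad_token="<pad>",        # explicit pad token, as in the repository
    add_bos_token=False,      # no leading <s>, matching the pretraining setup
)
ids = tokenizer("hello world")["input_ids"]
text = tokenizer.decode(ids)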
@@ -14,7 +14,7 @@ from absl import app
 from absl import flags
 from accelerate import Accelerator
 from torch.utils.data import DataLoader
-from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, OpenLlamaTokenizer
+from transformers import OpenLlamaForCausalLM, OpenLlamaConfig, LlamaTokenizer
 
 from dataset.dataset import construct_dataset
 from solver.trainer import Trainer
@@ -28,7 +28,7 @@ def main(argv):
 
     with open(FLAGS.config, "r", encoding="utf-8") as fp:
         config = yaml.load(fp, Loader=yaml.FullLoader)
-    tokenizer = OpenLlamaTokenizer(
+    tokenizer = LlamaTokenizer(
         config["data"]["tokenizer_model_path"],
         pad_token="<pad>",
         add_bos_token=False,
@@ -110,8 +110,8 @@ class Trainer:
         self.get_lr_scheduler()
         self.prepare()
         self.global_step = 0
-        self.start_time = time.time()
         self.optim.zero_grad()
+        self.start_time = time.time()
         for self.data_step, batch in enumerate(self.train_loader):
            if self.data_step >= self.config["train"]["num_training_steps"]:
                break
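The final hunk only reorders self.start_time = time.time() so the timestamp is taken immediately before the training loop rather than before self.optim.zero_grad(); any wall-clock metric derived from it therefore excludes the remaining setup work. A hypothetical sketch of how such a metric could be read off the attributes visible in the diff; neither the helper nor tokens_per_batch comes from the repository:

# Hypothetical helper, illustrative only: throughput measured from the point
# where the commit now sets trainer.start_time (just before the first batch).
import time

def tokens_per_second(trainer, tokens_per_batch: int) -> float:
    # data_step counts batches consumed so far; start_time is set right
    # before the loop, so setup time no longer inflates the denominator.
    elapsed = time.time() - trainer.start_time
    return (trainer.data_step + 1) * tokens_per_batch / max(elapsed, 1e-8)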