"""
|
||
Author: LiangSong(sl12160010@gmail.com)
|
||
Date: 2023-03-20 21:39:47
|
||
LastEditors: LiangSong(sl12160010@gmail.com)
|
||
LastEditTime: 2023-04-06 23:01:50
|
||
FilePath: /Open-Llama/dataset/tokenizer.py
|
||
Description:
|
||
|
||
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
|
||
"""
|
||
import torch
|
||
|
||
|
||
class Tokenizer:
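    """Lightweight wrapper that gives a SentencePiece model a HuggingFace-style
    tokenizer interface: __call__/encode/encode_batch return dicts containing
    "input_ids" (and optionally "attention_mask"), and decode maps ids back to text.
    """
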
    def __init__(self, sp_model):
        self.sp_model = sp_model
        self.bos_id = self.sp_model.bos_id()
        self.eos_id = self.sp_model.eos_id()
        self.pad_id = self.sp_model.pad_id()
        self.vocab_size = self.sp_model.vocab_size()

    def __call__(
        self,
        inputs,
        padding=None,
        max_length=256,
        return_tensors=False,
        truncation=False,
        add_special_tokens=True,
        return_mask=False,
    ):
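        """Tokenize a single string or a batch of strings.

        Dispatches to encode() for str inputs and to encode_batch() for
        anything else (e.g. a list of strings); all options are forwarded.
        """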
        if isinstance(inputs, str):
            return self.encode(
                inputs,
                padding=padding,
                max_length=max_length,
                return_tensors=return_tensors,
                truncation=truncation,
                add_special_tokens=add_special_tokens,
                return_mask=return_mask,
            )
        else:
            return self.encode_batch(
                inputs,
                padding=padding,
                max_length=max_length,
                return_tensors=return_tensors,
                truncation=truncation,
                add_special_tokens=add_special_tokens,
                return_mask=return_mask,
            )

    def encode(
        self,
        inputs,
        padding=None,
        max_length=8192,
        return_tensors=False,
        truncation=False,
        add_special_tokens=True,
        return_mask=False,
    ):
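        """Encode a single string into token ids.

        Truncation keeps at most max_length - 1 ids so the final position is
        left for eos (or pad), mirroring the referenced Llama implementation.
        Returns {"input_ids": ...} plus "attention_mask" when return_mask is
        set; with return_tensors the values are 1 x seq_len tensors.
        """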
        assert isinstance(inputs, str)
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [1] * len(input_ids)
        if truncation:
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L780
            # Following the Transformers implementation: the last position is
            # always reserved for pad or eos.
            input_ids = input_ids[: max_length - 1]
            if return_mask:
                attention_mask = attention_mask[: max_length - 1]
        if add_special_tokens:
            input_ids = input_ids + [self.eos_id]
            if return_mask:
                attention_mask = attention_mask + [0]
        if padding == "max_length":
            input_ids = input_ids + [self.pad_id] * (max_length - len(input_ids))
            if return_mask:
                attention_mask = attention_mask + [0] * (
                    max_length - len(attention_mask)
                )
        if return_tensors:
            input_ids = torch.tensor([input_ids])
            out = {
                "input_ids": input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor([attention_mask])
                out["attention_mask"] = attention_mask
        else:
            out = {
                "input_ids": input_ids,
            }
            if return_mask:
                out["attention_mask"] = attention_mask
        return out

    def encode_batch(
        self,
        inputs,
        padding=None,
        max_length=8192,
        return_tensors=False,
        truncation=False,
        add_special_tokens=True,
        return_mask=False,
    ):
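        """Encode a list of strings, applying the same truncation / eos /
        padding logic as encode() to every sequence. With padding="max_length"
        and return_tensors the result is a batch x max_length tensor.
        """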
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [[1] * len(i) for i in input_ids]
        if truncation:
            # Reserve the final position for eos/pad, as in encode().
            input_ids = [i[: max_length - 1] for i in input_ids]
            if return_mask:
                attention_mask = [i[: max_length - 1] for i in attention_mask]
        if add_special_tokens:
            input_ids = [i + [self.eos_id] for i in input_ids]
            if return_mask:
                attention_mask = [i + [0] for i in attention_mask]
        if padding == "max_length":
            input_ids_pad = []
            if return_mask:
                attention_mask_pad = []
            for idx, i in enumerate(input_ids):
                input_ids_pad.append(i + [self.pad_id] * (max_length - len(i)))
                if return_mask:
                    j = attention_mask[idx]
                    attention_mask_pad.append(j + [0] * (max_length - len(j)))
            input_ids = input_ids_pad
            if return_mask:
                attention_mask = attention_mask_pad
        if return_tensors:
            input_ids = torch.tensor(input_ids)
            out = {
                "input_ids": input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor(attention_mask)
                out["attention_mask"] = attention_mask
        else:
            out = {
                "input_ids": input_ids,
            }
            if return_mask:
                out["attention_mask"] = attention_mask
        return out

    def decode(self, inputs, max_rounds=None):
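        """Decode a batch of token ids (a tensor, or anything with .tolist())
        back into strings.

        By default each sequence is cut at its first eos token. If max_rounds
        is an int (or a per-sequence list of ints), that many eos tokens are
        skipped first and the sequence is cut at the next one, so multi-round
        samples keep their earlier rounds.
        """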
        inputs = inputs.tolist()
        out = []
        for i, ids in enumerate(inputs):
            count = 0
            flag = False
            for j, token in enumerate(ids):
                if token == self.eos_id:
                    if max_rounds is None:
                        flag = True
                        break
                    elif isinstance(max_rounds, int):
                        if count < max_rounds:
                            count += 1
                        else:
                            flag = True
                            break
                    elif isinstance(max_rounds, list):
                        if count < max_rounds[i]:
                            count += 1
                        else:
                            flag = True
                            break
            if flag:
                ids = ids[:j]
            out.append(ids)
        out = self.sp_model.Decode(out)
        return out


if __name__ == "__main__":
    import sentencepiece as spm
    from unicodedata import normalize

    # SentencePiece may not be able to round-trip text containing reserved
    # symbols such as '▁'.
    sp_model = spm.SentencePieceProcessor(
        model_file="configs/10w_vocab_wudao5_pile10.model"
    )
    tokenizer = Tokenizer(sp_model)
    tmp = [
        "hello world",
        "这是开源项目的V1版本,this is the first version of a open-source project!",
        "# this is a python script\nfor i in range(10):\n print(i)\n for j in range(10):\n print(j)",
    ]
    print(tmp)
    out = tokenizer(
        tmp, padding="max_length", return_tensors=True, max_length=64, truncation=True
    )
    for k, v in out.items():
        print(k, v.shape)
    print(out["input_ids"])
    out = tokenizer.decode(out["input_ids"])
    print(out)
    for i, j in zip(tmp, out):
        assert normalize("NFKC", i) == j

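    # Sanity check: the tokenizer should round-trip (NFKC-normalized) samples
    # from the pretraining shards, except for texts containing the reserved
    # '▁' symbol.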
    from dataset.data_iter import create_shard_kwargs, DataIter

    patterns = ["data/pretrain_data/part-wudao*.jsonl.zst"]
    paths = create_shard_kwargs(patterns)
    data_iter = DataIter(paths)
    for i, data in enumerate(data_iter):
        assert (
            normalize("NFKC", data["content"])
            == sp_model.Decode(sp_model.Encode(data["content"]))
            or "▁" in data["content"]
        )
        if i == 1000:
            break