'''
Author: LiangSong(sl12160010@gmail.com)
Date: 2023-03-20 21:39:47
LastEditors: LiangSong(sl12160010@gmail.com)
LastEditTime: 2023-03-26 23:09:39
FilePath: /Open-Llama/dataset/tokenizer.py
Description:
Copyright (c) 2023 by LiangSong(sl12160010@gmail.com), All Rights Reserved.
'''
import torch


class Tokenizer:
    """Lightweight wrapper around a SentencePiece model that exposes a
    HuggingFace-style __call__/encode/decode interface."""

    def __init__(self, sp_model):
        self.sp_model = sp_model
        self.bos_id = self.sp_model.bos_id()
        self.eos_id = self.sp_model.eos_id()
        self.pad_id = self.sp_model.pad_id()
        self.vocab_size = self.sp_model.vocab_size()

    def __call__(self, inputs, padding=None, max_length=256, return_tensors=False, truncation=False,
                 add_special_tokens=True, return_mask=False):
        # A single string is encoded directly; any other input (e.g. a list of strings)
        # is treated as a batch.
        if isinstance(inputs, str):
            return self.encode(inputs, padding=padding, max_length=max_length,
                               return_tensors=return_tensors, truncation=truncation,
                               add_special_tokens=add_special_tokens, return_mask=return_mask)
        else:
            return self.encode_batch(inputs, padding=padding, max_length=max_length,
                                     return_tensors=return_tensors, truncation=truncation,
                                     add_special_tokens=add_special_tokens, return_mask=return_mask)

    def encode(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
               add_special_tokens=True, return_mask=False):
        assert isinstance(inputs, str)
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [1] * len(input_ids)
        if truncation:
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L780
            # Following the implementation in Transformers: the last position is always reserved for pad or eos.
            input_ids = input_ids[: max_length - 1]
            if return_mask:
                attention_mask = attention_mask[: max_length - 1]
        if add_special_tokens:
            input_ids = input_ids + [self.eos_id]
            if return_mask:
                attention_mask = attention_mask + [0]
        if padding == 'max_length':
            input_ids = input_ids + [self.pad_id] * (max_length - len(input_ids))
            if return_mask:
                attention_mask = attention_mask + [0] * (max_length - len(attention_mask))
        if return_tensors:
            input_ids = torch.tensor([input_ids])
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor([attention_mask])
                out['attention_mask'] = attention_mask
        else:
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                out['attention_mask'] = attention_mask
        return out

    def encode_batch(self, inputs, padding=None, max_length=8192, return_tensors=False, truncation=False,
                     add_special_tokens=True, return_mask=False):
        input_ids = self.sp_model.Encode(inputs)
        if return_mask:
            attention_mask = [[1] * len(i) for i in input_ids]
        if truncation:
            input_ids = [i[: max_length - 1] for i in input_ids]
            if return_mask:
                attention_mask = [i[: max_length - 1] for i in attention_mask]
        if add_special_tokens:
            input_ids = [i + [self.eos_id] for i in input_ids]
            if return_mask:
                attention_mask = [i + [0] for i in attention_mask]
        if padding == 'max_length':
            input_ids_pad = []
            if return_mask:
                attention_mask_pad = []
            for idx, i in enumerate(input_ids):
                input_ids_pad.append(i + [self.pad_id] * (max_length - len(i)))
                if return_mask:
                    j = attention_mask[idx]
                    attention_mask_pad.append(j + [0] * (max_length - len(j)))
            input_ids = input_ids_pad
            if return_mask:
                attention_mask = attention_mask_pad
        if return_tensors:
            input_ids = torch.tensor(input_ids)
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                attention_mask = torch.tensor(attention_mask)
                out['attention_mask'] = attention_mask
        else:
            out = {
                'input_ids': input_ids,
            }
            if return_mask:
                out['attention_mask'] = attention_mask
        return out

    def decode(self, inputs):
        # Drop everything from the first eos token onwards, then let SentencePiece decode.
        inputs = inputs.tolist()
        out = []
        for i in inputs:
            if self.eos_id in i:
                eos_idx = i.index(self.eos_id)
                i = i[:eos_idx]
            out.append(i)
        out = self.sp_model.Decode(out)
        return out
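

# The SentencePiece model loaded in the __main__ block below
# ('configs/10w_vocab_wudao5_pile10.model') is not produced by this file.
# The helper that follows is only a rough sketch of how a comparable model
# could be trained with the standard SentencePieceTrainer API; the corpus
# path, vocab size, model type and special-token ids are assumptions, not
# the settings actually used by this project.
def train_sp_model_sketch(corpus_path='data/corpus.txt',  # hypothetical plain-text corpus
                          model_prefix='configs/example_vocab',
                          vocab_size=100000):
    import sentencepiece as spm

    spm.SentencePieceTrainer.Train(
        input=corpus_path,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        model_type='bpe',  # assumption; sentencepiece defaults to 'unigram'
        character_coverage=0.9995,  # common choice for corpora containing CJK text
        unk_id=0,
        bos_id=1,
        eos_id=2,
        pad_id=3,  # give pad a real id so Tokenizer.pad_id is usable for padding
    )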


if __name__ == '__main__':
    import sentencepiece as spm
    from unicodedata import normalize

    # SentencePiece may not round-trip text containing reserved symbols such as '▁'.
    sp_model = spm.SentencePieceProcessor(model_file='configs/10w_vocab_wudao5_pile10.model')
    tokenizer = Tokenizer(sp_model)
    tmp = ['hello world',
           '这是开源项目的V1版本this is the first version of a open-source project!',
           '# this is a python script\nfor i in range(10):\n print(i)\n for j in range(10):\n print(j)']
    print(tmp)
    out = tokenizer(tmp, padding='max_length', return_tensors=True, max_length=64, truncation=True)
    for k, v in out.items():
        print(k, v.shape)
    print(out['input_ids'])
    out = tokenizer.decode(out['input_ids'])
    print(out)
    for i, j in zip(tmp, out):
        assert normalize('NFKC', i) == j

    from dataset.data_iter import create_shard_kwargs, create_data_iter

    patterns = [
        'data/pretrain_data/part-wudao*.jsonl.zst',
    ]
    paths = create_shard_kwargs(patterns)
    data_iter = create_data_iter(paths)
    for i, data in enumerate(data_iter):
        # '▁' is the reserved SentencePiece whitespace symbol mentioned above;
        # texts that already contain it are not expected to round-trip exactly.
        assert (normalize('NFKC', data['content']) == sp_model.Decode(sp_model.Encode(data['content']))
                or '▁' in data['content'])
        if i == 1000:
            break