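# Extend a raw LLaMA 7B checkpoint to a merged (extended) tokenizer vocabulary
# and rename its weights to the Hugging Face Transformers layout.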
import torch
import sentencepiece as spm
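
# Load the extended (merged) SentencePiece tokenizer and read its vocabulary size.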
sp_model = spm.SentencePieceProcessor(
    model_file="configs/llama_tokenizer_extended.model"
)
merged_vocab_size = sp_model.vocab_size()
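
# Load the raw LLaMA 7B checkpoint; the shape of the input embedding matrix
# gives the original vocabulary size and the hidden dimension.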
ckpt = torch.load("data/llama_raw_ckpt/7B/consolidated.00.pth")
raw_vocab_size, hidden_size = ckpt["tok_embeddings.weight"].shape
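
# Append rows for the newly added tokens to the input embedding matrix,
# initialized with small random values (scaled by 0.001).
# Note: torch.randn returns float32, while the raw LLaMA weights are usually
# float16, so you may want to cast the new rows to the checkpoint dtype
# before concatenating.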
extended_tok_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_tok_embeddings = extended_tok_embeddings * 0.001
ckpt["tok_embeddings.weight"] = torch.cat(
    [ckpt["tok_embeddings.weight"], extended_tok_embeddings], dim=0
)
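
# Do the same for the output (LM head) projection.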
extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_out_embeddings = extended_out_embeddings * 0.001
ckpt["output.weight"] = torch.cat(
    [ckpt["output.weight"], extended_out_embeddings], dim=0
)
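
# Rename the top-level weights to their Hugging Face Transformers names.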
rename_map = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}

for f, t in rename_map.items():
    v = ckpt.pop(f)
    ckpt[t] = v
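
# Per-layer parameter name templates as they appear in the raw LLaMA checkpoint...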
from_names = [
    "layers.{}.attention.wq.weight",
    "layers.{}.attention.wk.weight",
    "layers.{}.attention.wv.weight",
    "layers.{}.attention.wo.weight",
    "layers.{}.feed_forward.w1.weight",
    "layers.{}.feed_forward.w2.weight",
    "layers.{}.feed_forward.w3.weight",
    "layers.{}.attention_norm.weight",
    "layers.{}.ffn_norm.weight",
    "layers.{}.attention.inner_attention.rope.freqs",
]
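
# ...and the corresponding names in the Hugging Face LLaMA implementation.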
to_names = [
    "model.layers.{}.self_attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight",
    "model.layers.{}.mlp.gate_proj.weight",
    "model.layers.{}.mlp.down_proj.weight",
    "model.layers.{}.mlp.up_proj.weight",
    "model.layers.{}.input_layernorm.weight",
    "model.layers.{}.post_attention_layernorm.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq",
]
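
# Rename every parameter of the 32 transformer layers (LLaMA 7B), then save
# the extended checkpoint.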
for layer in range(32):
    for f, t in zip(from_names, to_names):
        f = f.format(layer)
        t = t.format(layer)
        v = ckpt.pop(f)
        ckpt[t] = v

torch.save(ckpt, "data/llama_raw_ckpt/7B/extended.pth")