Open-Llama/utils/convert_ckpt.py

import torch
import sentencepiece as spm

# Load the extended (merged) SentencePiece tokenizer to get the new vocabulary size.
sp_model = spm.SentencePieceProcessor(
    model_file="configs/llama_tokenizer_extended.model"
)
merged_vocab_size = sp_model.vocab_size()

# Load the raw LLaMA 7B checkpoint and read its original vocabulary size
# and hidden dimension from the token embedding matrix.
ckpt = torch.load("data/llama_raw_ckpt/7B/consolidated.00.pth")
raw_vocab_size, hidden_size = ckpt["tok_embeddings.weight"].shape
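
# Sanity check (added for illustration, not in the original script): the merged
# tokenizer is expected to be a strict superset of the original LLaMA vocabulary,
# otherwise there would be no new rows to append below.
assert merged_vocab_size > raw_vocab_size, (
    f"merged vocab ({merged_vocab_size}) must exceed raw vocab ({raw_vocab_size})"
)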

# Create rows for the new tokens, initialised with small random values,
# and append them to the input embedding matrix.
extended_tok_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_tok_embeddings = extended_tok_embeddings * 0.001
ckpt["tok_embeddings.weight"] = torch.cat(
    [ckpt["tok_embeddings.weight"], extended_tok_embeddings], dim=0
)

# Do the same for the output projection (LM head) weights.
extended_out_embeddings = torch.randn(merged_vocab_size - raw_vocab_size, hidden_size)
extended_out_embeddings = extended_out_embeddings * 0.001
ckpt["output.weight"] = torch.cat(
    [ckpt["output.weight"], extended_out_embeddings], dim=0
)

# Rename the non-layer weights from the original LLaMA naming scheme to the
# Hugging Face Transformers naming scheme.
rename_map = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}
for f, t in rename_map.items():
    v = ckpt.pop(f)
    ckpt[t] = v

# Per-layer parameter names in the raw checkpoint ...
from_names = [
    "layers.{}.attention.wq.weight",
    "layers.{}.attention.wk.weight",
    "layers.{}.attention.wv.weight",
    "layers.{}.attention.wo.weight",
    "layers.{}.feed_forward.w1.weight",
    "layers.{}.feed_forward.w2.weight",
    "layers.{}.feed_forward.w3.weight",
    "layers.{}.attention_norm.weight",
    "layers.{}.ffn_norm.weight",
    "layers.{}.attention.inner_attention.rope.freqs",
]
# ... and their corresponding names in the Hugging Face LLaMA implementation.
to_names = [
    "model.layers.{}.self_attn.q_proj.weight",
    "model.layers.{}.self_attn.k_proj.weight",
    "model.layers.{}.self_attn.v_proj.weight",
    "model.layers.{}.self_attn.o_proj.weight",
    "model.layers.{}.mlp.gate_proj.weight",
    "model.layers.{}.mlp.down_proj.weight",
    "model.layers.{}.mlp.up_proj.weight",
    "model.layers.{}.input_layernorm.weight",
    "model.layers.{}.post_attention_layernorm.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq",
]

# Rename every per-layer weight; LLaMA 7B has 32 transformer layers.
for layer in range(32):
    for f, t in zip(from_names, to_names):
        f = f.format(layer)
        t = t.format(layer)
        v = ckpt.pop(f)
        ckpt[t] = v

# Save the extended, renamed checkpoint.
torch.save(ckpt, "data/llama_raw_ckpt/7B/extended.pth")
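

# A minimal sketch (not part of the original script) of how the converted state
# dict could be loaded back with Hugging Face Transformers. The config values
# below are assumptions for LLaMA-7B; the function is defined only for
# illustration and is never called here.
def load_extended_checkpoint(path="data/llama_raw_ckpt/7B/extended.pth"):
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig(
        vocab_size=merged_vocab_size,  # extended vocabulary size
        hidden_size=hidden_size,
        intermediate_size=11008,       # assumed LLaMA-7B value
        num_hidden_layers=32,
        num_attention_heads=32,
    )
    model = LlamaForCausalLM(config)
    # strict=False because rotary-embedding buffers may differ across
    # transformers versions and are recomputed at model init anyway.
    model.load_state_dict(torch.load(path), strict=False)
    return model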