Open-Llama/utils/convert_ckpt.py
2023-05-17 22:21:46 +07:00

77 lines
2.4 KiB
Python

"""
Author: s-JoL(sl12160010@gmail.com)
Date: 2023-04-28 19:55:13
LastEditors: s-JoL(sl12160010@gmail.com)
LastEditTime: 2023-05-06 23:30:29
FilePath: /Open-Llama/utils/convert_ckpt.py
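Description: Extend the raw LLaMA 7B checkpoint's token-embedding and output matrices to the merged tokenizer vocabulary, rename its parameters, and save the result.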
Copyright (c) 2023 by s-JoL(sl12160010@gmail.com), All Rights Reserved.
"""
import torch
import sentencepiece as spm
# Default locations, relative to the directory the script is launched from.
TOKENIZER_MODEL_PATH = "configs/tokenizer_models/llama_tokenizer_extended.model"
RAW_CKPT_PATH = "data/llama_raw_ckpt/7B/consolidated.00.pth"
OUTPUT_CKPT_PATH = "data/llama_raw_ckpt/7B/extended.pth"

# Rows added for new tokens are standard-normal samples multiplied by this,
# so the new tokens start out with near-zero vectors.
NEW_ROW_INIT_SCALE = 0.001

# Non-layer parameters: raw checkpoint name -> converted name.
TOP_LEVEL_RENAMES = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}

# Per-layer parameters: raw name template -> converted name template.
# "{}" is filled with the layer index.
LAYER_RENAMES = {
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    "layers.{}.attention.inner_attention.rope.freqs": (
        "model.layers.{}.self_attn.rotary_emb.inv_freq"
    ),
}


def extend_rows(weight, new_num_rows, scale=NEW_ROW_INIT_SCALE):
    """Return ``weight`` with small random rows appended up to ``new_num_rows``.

    Args:
        weight: 2-D tensor of shape ``(num_rows, dim)``.
        new_num_rows: Row count of the returned tensor.
        scale: Multiplier applied to the standard-normal samples for new rows.

    Returns:
        A ``(new_num_rows, dim)`` tensor with the same dtype and device as
        ``weight``. Its first ``num_rows`` rows are ``weight`` unchanged.

    Raises:
        ValueError: If ``new_num_rows`` is smaller than ``weight``'s row count.
    """
    num_rows, dim = weight.shape
    num_new = new_num_rows - num_rows
    if num_new < 0:
        raise ValueError(
            f"cannot shrink {num_rows} rows to {new_num_rows}; the tokenizer "
            "vocabulary must be at least as large as the checkpoint's"
        )
    # Sample in float32 (torch.randn's default) and cast afterwards.
    # Concatenating float32 rows onto a half-precision weight would otherwise
    # promote the whole matrix to float32.
    extra = torch.randn(num_new, dim) * scale
    extra = extra.to(dtype=weight.dtype, device=weight.device)
    return torch.cat([weight, extra], dim=0)


def extend_vocab(ckpt, vocab_size, scale=NEW_ROW_INIT_SCALE):
    """Grow the raw input-embedding and output matrices to ``vocab_size`` rows.

    ``ckpt`` is modified in place (and also returned). It must still use the
    raw names ``tok_embeddings.weight`` and ``output.weight``.
    """
    # Order matters only for reproducibility: tok_embeddings draws from the
    # RNG first, then output.
    for key in ("tok_embeddings.weight", "output.weight"):
        ckpt[key] = extend_rows(ckpt[key], vocab_size, scale)
    return ckpt


def count_layers(ckpt):
    """Return the number of distinct ``layers.<i>.`` indices in a raw checkpoint."""
    return len({key.split(".")[1] for key in ckpt if key.startswith("layers.")})


def rename_keys(ckpt, num_layers=None):
    """Rename raw checkpoint parameters to the converted naming scheme, in place.

    Args:
        ckpt: Mapping of raw parameter names to tensors. It is mutated and also
            returned.
        num_layers: Number of layers to rename. When ``None``, the number is
            inferred from the ``layers.<i>.`` keys.

    Raises:
        KeyError: If an expected parameter is missing from ``ckpt``.
    """
    if num_layers is None:
        num_layers = count_layers(ckpt)
    for old, new in TOP_LEVEL_RENAMES.items():
        ckpt[new] = ckpt.pop(old)
    # NOTE(review): wq/wk are moved without any permutation. Hugging Face's
    # LLaMA conversion permutes these for its rotate_half rotary layout.
    # Confirm the target model's rotary implementation expects this layout.
    for layer in range(num_layers):
        for old, new in LAYER_RENAMES.items():
            ckpt[new.format(layer)] = ckpt.pop(old.format(layer))
    return ckpt


def main(
    tokenizer_path=TOKENIZER_MODEL_PATH,
    raw_ckpt_path=RAW_CKPT_PATH,
    output_path=OUTPUT_CKPT_PATH,
):
    """Extend a raw checkpoint to the merged tokenizer vocabulary and rename it.

    Reads the SentencePiece model at ``tokenizer_path`` to get the target
    vocabulary size. Grows the embedding and output matrices of the
    checkpoint at ``raw_ckpt_path`` to that size, renames every parameter,
    and saves the result to ``output_path``.
    """
    sp_model = spm.SentencePieceProcessor(model_file=tokenizer_path)
    merged_vocab_size = sp_model.vocab_size()
    # Load onto the CPU so the conversion does not require the device the
    # checkpoint was saved from.
    ckpt = torch.load(raw_ckpt_path, map_location="cpu")
    extend_vocab(ckpt, merged_vocab_size)
    rename_keys(ckpt)
    torch.save(ckpt, output_path)


if __name__ == "__main__":
    main()