fix: use right conversion script

This commit is contained in:
Zach Nussbaum 2023-05-11 11:20:43 -04:00
parent f8fdcccc5d
commit 1ed71fbbf8

View File

@@ -76,24 +76,32 @@ fout = open(fname_out, "wb")
vocab = tokenizer.vocab vocab = tokenizer.vocab
hparams["multiple_of"] = 1 hparams["multiple_of"] = 1
fout.write(struct.pack("i", 0x67676d6d)) # magic: ggml in hex fout.write(struct.pack("I", 0x67676d6d)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["vocab_size"])) fout.write(struct.pack("I", model.config.vocab_size))
fout.write(struct.pack("i", hparams["max_seq_len"])) fout.write(struct.pack("I", model.config.max_seq_len))
fout.write(struct.pack("i", hparams["d_model"])) fout.write(struct.pack("I", model.config.n_layers))
fout.write(struct.pack("i", hparams["n_heads"])) fout.write(struct.pack("I", model.config.n_heads))
fout.write(struct.pack("i", hparams["n_layers"])) fout.write(struct.pack("I", model.config.d_model))
# n_rot (unused) fout.write(struct.pack("f", model.config.attn_config['alibi_bias_max']))
fout.write(struct.pack("i", 0)) clip_qkv = model.config.attn_config['clip_qkv']
fout.write(struct.pack("i", ftype)) fout.write(struct.pack("f", clip_qkv if clip_qkv is not None else 0))
fout.write(struct.pack("I", ftype))
# # Is this correct?? # # Is this correct??
# dot_token = tokenizer.encode(".")[0] # dot_token = tokenizer.encode(".")[0]
# write tokens to ggml file # write tokens to ggml file
fout.write(struct.pack("i", hparams["vocab_size"])) dot_token = tokenizer.encode('.')[0]
fout.write(struct.pack("I", model.config.vocab_size))
for i in range(hparams["vocab_size"]): for i in range(model.config.vocab_size):
text = tokenizer.decode([i]).encode('utf-8') text = tokenizer.decode([dot_token, i]).encode('utf-8')
fout.write(struct.pack("i", len(text))) # remove the first byte (it's always '.')
text = text[1:]
enclen = len(text)
if i in tokenizer.all_special_ids:
print(f"special token: {text}")
enclen = enclen | 1<<31
fout.write(struct.pack("I", enclen))
fout.write(text) fout.write(text)
list_vars = model.state_dict() list_vars = model.state_dict()
@@ -101,73 +109,35 @@ for name in list_vars.keys():
data = list_vars[name].squeeze().numpy() data = list_vars[name].squeeze().numpy()
print("Processing variable: " + name + " with shape: ", data.shape) print("Processing variable: " + name + " with shape: ", data.shape)
# we don't need these n_dims = len(data.shape);
if name.endswith("attn.masked_bias") or name.endswith(".attn.bias"):
print(" Skipping variable: " + name)
continue
if "Wqkv.weight" in name: # ftype == 0 -> float32, ftype == 1 -> float16
# chunk qkv ftype_cur = 0;
query, key, value = np.split(data, 3, axis=0) if ftype != 0:
# Keep token embeddings in fp32
new_name = name.split("Wqkv.weight")[0] if name[-7:] == ".weight" and n_dims == 2 and ".wte" not in name:
print(" Converting to float16")
for (data, name) in [(query, new_name + "q_proj.weight"), (key, new_name + "k_proj.weight"), (value, new_name + "v_proj.weight")]: data = data.astype(np.float16)
print(f"Processing variable: {name} with shape: {data.shape}") ftype_cur = 1
n_dims = len(data.shape);
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0;
if ftype != 0:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
# header
str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str);
# data
data.tofile(fout)
else:
n_dims = len(data.shape);
# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0;
if ftype != 0:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else: else:
if data.dtype != np.float32: print(" Converting to float32")
print(" Converting to float32") data = data.astype(np.float32)
data = data.astype(np.float32) ftype_cur = 0
ftype_cur = 0 else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
# header # header
str = name.encode('utf-8') str = name.encode('utf-8')
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims): for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str); fout.write(str);
# data # data
data.tofile(fout) data.tofile(fout)
fout.close() fout.close()