mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
gpt-j: update inference to match latest llama.cpp insights
- Use F16 KV cache - Store transposed V in the cache - Avoid unnecessary Q copy Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> ggml upstream commit 0265f0813492602fec0e1159fe61de1bf0ccaf78
This commit is contained in:
parent
050e7f076e
commit
d5d72f0361
@ -375,37 +375,31 @@ bool gptj_eval(
|
||||
|
||||
// self-attention
|
||||
{
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur);
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur);
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur);
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur));
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, model.kv_self.v, N, n_embd,
|
||||
( n_ctx)*ggml_element_size(model.kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v));
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_rope(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
Qcur,
|
||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
|
||||
n_past, n_rot, 0, 0),
|
||||
0, 2, 1, 3);
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
|
||||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_rope(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
n_past, n_rot, 1, 0),
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
@ -425,17 +419,15 @@ bool gptj_eval(
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
struct ggml_tensor * V_trans =
|
||||
ggml_cpy(ctx0,
|
||||
ggml_permute(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
1, 2, 0, 3),
|
||||
ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head));
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, model.kv_self.v,
|
||||
n_past + N, n_embd/n_head, n_head,
|
||||
n_ctx*ggml_element_size(model.kv_self.v),
|
||||
n_ctx*ggml_element_size(model.kv_self.v)*n_embd/n_head,
|
||||
il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd);
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
Loading…
Reference in New Issue
Block a user