diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index a66a34a6..0ec5d61c 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -375,37 +375,31 @@ bool gptj_eval(
 
         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur);
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur);
-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur);
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_q_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_k_proj_w, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
 
             // store key and value to memory
             {
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_mul_mat(ctx0, model.layers[il].c_attn_v_proj_w, cur));
+
                 struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
-                struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past));
+                struct ggml_tensor * v = ggml_view_2d(ctx0, model.kv_self.v, N, n_embd,
+                        (   n_ctx)*ggml_element_size(model.kv_self.v),
+                        (il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v));
 
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
             }
 
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
-            struct ggml_tensor * Q =
-                ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_cpy(ctx0,
-                                Qcur,
-                                ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
-                            n_past, n_rot, 0, 0),
-                        0, 2, 1, 3);
+            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
 
             // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            n_past, n_rot, 1, 0),
+                        ggml_reshape_3d(ctx0,
+                            ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
+                            n_embd/n_head, n_head, n_past + N),
                         0, 2, 1, 3);
 
             // K * Q
@@ -425,17 +419,15 @@ bool gptj_eval(
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
-            struct ggml_tensor * V_trans =
-                ggml_cpy(ctx0,
-                        ggml_permute(ctx0,
-                            ggml_reshape_3d(ctx0,
-                                ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd),
-                                n_embd/n_head, n_head, n_past + N),
-                            1, 2, 0, 3),
-                        ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+            struct ggml_tensor * V =
+                ggml_view_3d(ctx0, model.kv_self.v,
+                        n_past + N, n_embd/n_head, n_head,
+                        n_ctx*ggml_element_size(model.kv_self.v),
+                        n_ctx*ggml_element_size(model.kv_self.v)*n_embd/n_head,
+                        il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd);
 
             // KQV = transpose(V) * KQ_soft_max
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
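
Review note: the change applies RoPE to Q and K once at projection time (so the cached K no longer has to be re-roped on every eval), and stores V transposed in the KV cache via a strided ggml_view_2d write, so attention can read it back with ggml_view_3d instead of copying into a fresh contiguous tensor (the old V_trans ggml_cpy) each token. Below is a minimal sketch of the addressing the new ggml_view_2d implies, using a plain float buffer and toy sizes; the names and dimensions are illustrative, not from gptj.cpp.

// Sketch (assumed float cache, toy sizes) of the transposed V layout: each
// layer owns an n_embd x n_ctx slab, and embedding dim e of token t lands at
// row e, column t, so n_past+N values per dim are contiguous when read back.
#include <cstdio>
#include <vector>

int main() {
    const int n_layer = 2, n_ctx = 8, n_embd = 4;
    const int il = 1, n_past = 3, N = 2;   // target layer, cached tokens, new tokens

    std::vector<float> kv_v(n_layer * n_ctx * n_embd, 0.0f);

    // Vcur after ggml_transpose: Vcur_t[e][t] = dim e of new token t.
    const float Vcur_t[4][2] = {{1,2},{3,4},{5,6},{7,8}};

    // Mirrors ggml_view_2d(v, N, n_embd, n_ctx*es, (il*n_ctx*n_embd + n_past)*es):
    // row stride of n_ctx elements, offset = layer base plus n_past columns in.
    const size_t base = (size_t)il * n_ctx * n_embd + n_past;
    for (int e = 0; e < n_embd; ++e)
        for (int t = 0; t < N; ++t)
            kv_v[base + (size_t)e * n_ctx + t] = Vcur_t[e][t];

    // Read back row e for tokens 0..n_past+N-1 (earlier slots are still zero
    // in this toy; in gptj.cpp they hold values from previous evals).
    for (int e = 0; e < n_embd; ++e) {
        for (int t = 0; t < n_past + N; ++t)
            std::printf("%4.1f ", kv_v[(size_t)il * n_ctx * n_embd + (size_t)e * n_ctx + t]);
        std::printf("\n");
    }
    return 0;
}

The same arithmetic explains the attention-side ggml_view_3d in the second hunk: per head, each of the n_embd/n_head rows is n_ctx elements apart, heads are n_ctx*n_embd/n_head elements apart, and the layer base offset is il*n_ctx*n_embd.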