Use ggml scratch bufs for mpt and gptj models (#1104)
* backend/gptj: use scratch buffers (reduces total memory required and makes the eval buf not grow with n_past)
* backend/mpt: use scratch bufs
* fix format-related compile warnings
parent 70cbff70cc
commit 40a3faeb05
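The change in both backends follows the same pattern: while the compute graph is being built, ggml_set_scratch() points the context at a fixed side buffer, so per-layer intermediate tensors are allocated there and overwritten on the next layer instead of piling up in the main eval buffer; passing a null scratch restores normal allocation for tensors that must persist (the logits). Below is a minimal sketch of that usage against the ggml API of this era; the layer body, buffer names, and sizes are illustrative placeholders, not code from this commit.

#include <ggml.h>
#include <cstdint>
#include <vector>

// Sketch: build a few layers while alternating two caller-owned scratch
// regions, then disable scratch before the tensor that must persist.
static struct ggml_tensor * build_layers_sketch(struct ggml_context * ctx,
                                                struct ggml_tensor * inp,
                                                int n_layer,
                                                std::vector<uint8_t> & scr0,
                                                std::vector<uint8_t> & scr1) {
    struct ggml_tensor * cur = inp;
    for (int il = 0; il < n_layer; ++il) {
        // attention-side work: intermediates land in scratch region 0
        ggml_set_scratch(ctx, { 0, scr0.size(), scr0.data() });
        cur = ggml_norm(ctx, cur);

        // feed-forward-side work: switch to scratch region 1
        ggml_set_scratch(ctx, { 0, scr1.size(), scr1.data() });
        cur = ggml_gelu(ctx, cur);
    }
    // back to the main eval buffer for output that outlives graph build
    ggml_set_scratch(ctx, { 0, 0, nullptr });
    return cur;
}

Two regions are alternated, presumably so an op reading inputs placed in one region can write its outputs into the other without clobbering them; the final ggml_set_scratch with a null data pointer is what keeps the logits in the eval buffer.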
backend/gptj:

@@ -5,6 +5,7 @@
 #include "llmodel_shared.h"
 
 #include <cassert>
+#include <cinttypes>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
@@ -85,7 +86,9 @@ struct gptj_model {
     struct ggml_context * ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
 
-    llm_buffer buf;
+    llm_buffer eval_buf;
+    llm_buffer scr0_buf;
+    llm_buffer scr1_buf;
 
     ~gptj_model() {
         if (ctx) {
@@ -393,7 +396,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
         }
 
         if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%lu, %lu], expected [%d, %d]\n",
+            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%" PRId64 ", %" PRId64 "], expected [%d, %d]\n",
                 __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
             return false;
         }
@@ -438,6 +441,9 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
         printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
     }
 
+    model.scr0_buf.resize(256u * 1024 * 1024);
+    model.scr1_buf.resize(256u * 1024 * 1024);
+
     return true;
 }
 
@@ -484,24 +490,24 @@ bool gptj_eval(
     const int n_rot = hparams.n_rot;
 
     const size_t init_buf_size = 1024_MiB;
-    if (!model.buf.addr || model.buf.size < init_buf_size)
-        model.buf.resize(init_buf_size);
+    if (!model.eval_buf.addr || model.eval_buf.size < init_buf_size)
+        model.eval_buf.resize(init_buf_size);
 
-    if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
+    if (mem_per_token > 0 && mem_per_token*N > model.eval_buf.size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
+        printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.eval_buf.size, buf_size_new);
 
         // reallocate
-        model.buf.resize(buf_size_new);
-        if (model.buf.addr == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
+        model.eval_buf.resize(buf_size_new);
+        if (model.eval_buf.addr == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.eval_buf.size);
             return false;
         }
     }
 
     struct ggml_init_params params = {
-        .mem_size = model.buf.size,
-        .mem_buffer = model.buf.addr,
+        .mem_size = model.eval_buf.size,
+        .mem_buffer = model.eval_buf.addr,
         .no_alloc = false
     };
 
@@ -517,7 +523,7 @@ bool gptj_eval(
 
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * cur;
+        ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
         // norm
         {
             cur = ggml_norm(ctx0, inpL);
@@ -612,6 +618,7 @@ bool gptj_eval(
 
         struct ggml_tensor * inpFF = cur;
 
+        ggml_set_scratch(ctx0, {0, model.scr1_buf.size, model.scr1_buf.addr, });
         // feed-forward network
        // this is independent of the self-attention result, so it could be done in parallel to the self-attention
         {
@@ -645,6 +652,8 @@ bool gptj_eval(
         inpL = ggml_add(ctx0, cur, inpL);
     }
 
+    ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
+
     // norm
     {
         inpL = ggml_norm(ctx0, inpL);
@@ -657,6 +666,8 @@ bool gptj_eval(
                 ggml_repeat(ctx0, model.ln_f_b, inpL));
     }
 
+    ggml_set_scratch(ctx0, { 0, 0, nullptr, });
+
     // lm_head
     {
         inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL);
backend/mpt:

@@ -5,6 +5,7 @@
 #include "llmodel_shared.h"
 
 #include <cassert>
+#include <cinttypes>
 #include <cmath>
 #include <cstdio>
 #include <cstring>
@@ -80,7 +81,9 @@ struct mpt_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 
 
-    llm_buffer buf;
+    llm_buffer eval_buf;
+    llm_buffer scr0_buf;
+    llm_buffer scr1_buf;
 
     ~mpt_model() {
         if (ctx) {
@@ -370,8 +373,8 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
         }
 
         if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
-            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
-                __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
+            fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%" PRId64 ", %" PRId64 "], expected [%d, %d]\n",
+                __func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
             return false;
         }
 
@@ -403,6 +406,9 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
         printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
     }
 
+    model.scr0_buf.resize(256u * 1024 * 1024);
+    model.scr1_buf.resize(256u * 1024 * 1024);
+
     return true;
 }
 
@@ -438,24 +444,24 @@ bool mpt_eval(
     const int n_vocab = hparams.n_vocab;
 
     const size_t init_buf_size = 1024_MiB;
-    if (!model.buf.addr || model.buf.size < init_buf_size)
-        model.buf.resize(init_buf_size);
+    if (!model.eval_buf.addr || model.eval_buf.size < init_buf_size)
+        model.eval_buf.resize(init_buf_size);
 
-    if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
+    if (mem_per_token > 0 && mem_per_token*N > model.eval_buf.size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
         // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
 
         // reallocate
-        model.buf.resize(buf_size_new);
-        if (model.buf.addr == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
+        model.eval_buf.resize(buf_size_new);
+        if (model.eval_buf.addr == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.eval_buf.size);
             return false;
         }
     }
 
     struct ggml_init_params params = {
-        .mem_size = model.buf.size,
-        .mem_buffer = model.buf.addr,
+        .mem_size = model.eval_buf.size,
+        .mem_buffer = model.eval_buf.addr,
         .no_alloc = false
     };
 
@@ -470,6 +476,7 @@ bool mpt_eval(
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);
 
     for (int il = 0; il < n_layer; ++il) {
+        ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
 
         struct ggml_tensor * inpSA = inpL;
         struct ggml_tensor * cur = inpSA;
@@ -561,7 +568,7 @@ bool mpt_eval(
                 cur);
         }
 
+        ggml_set_scratch(ctx0, {0, model.scr1_buf.size, model.scr1_buf.addr, });
         // residual
         struct ggml_tensor * resSA = ggml_add(ctx0, cur, inpSA);
         // feed-forward network
@@ -586,6 +593,7 @@ bool mpt_eval(
         // self-attention + FF
         inpL = ggml_add(ctx0, cur, resSA);
     }
+    ggml_set_scratch(ctx0, {0, model.scr0_buf.size, model.scr0_buf.addr, });
 
     struct ggml_tensor * out = inpL;
     // -> logits
@@ -594,6 +602,7 @@ bool mpt_eval(
         out = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm_f_w, out),
                     out);
+        ggml_set_scratch(ctx0, { 0, 0, nullptr, });
        out = ggml_mul_mat(ctx0, model.wte, out);
     }
 