mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
backend: fix buffer overrun in repeat penalty code
Caught with AddressSanitizer running a basic prompt test against llmodel standalone. This fix allows ASan builds to complete a simple prompt without illegal accesses but there are still notably several leaks.
This commit is contained in:
parent
26cb31c4e6
commit
e6fd0a240d
@ -993,9 +993,10 @@ void GPTJ::prompt(const std::string &prompt,
|
|||||||
gpt_vocab::id id = 0;
|
gpt_vocab::id id = 0;
|
||||||
{
|
{
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
|
||||||
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
||||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
|
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
|
||||||
promptCtx.n_ctx,
|
n_prev_toks,
|
||||||
promptCtx.logits,
|
promptCtx.logits,
|
||||||
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||||
promptCtx.repeat_penalty,
|
promptCtx.repeat_penalty,
|
||||||
|
@ -180,9 +180,10 @@ void LLamaModel::prompt(const std::string &prompt,
|
|||||||
int32_t totalPredictions = 0;
|
int32_t totalPredictions = 0;
|
||||||
for (int i = 0; i < promptCtx.n_predict; i++) {
|
for (int i = 0; i < promptCtx.n_predict; i++) {
|
||||||
// sample next token
|
// sample next token
|
||||||
|
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
|
||||||
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
|
llama_token id = llama_sample_top_p_top_k(d_ptr->ctx,
|
||||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.repeat_last_n,
|
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
|
||||||
promptCtx.repeat_last_n, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
n_prev_toks, promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||||
promptCtx.repeat_penalty);
|
promptCtx.repeat_penalty);
|
||||||
|
|
||||||
// Check if the context has run out...
|
// Check if the context has run out...
|
||||||
|
@ -918,9 +918,10 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
int id = 0;
|
int id = 0;
|
||||||
{
|
{
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size());
|
||||||
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
id = gpt_sample_top_k_top_p(d_ptr->vocab, n_vocab,
|
||||||
promptCtx.tokens.data() + promptCtx.n_ctx - promptCtx.n_ctx,
|
promptCtx.tokens.data() + promptCtx.tokens.size() - n_prev_toks,
|
||||||
promptCtx.n_ctx,
|
n_prev_toks,
|
||||||
promptCtx.logits,
|
promptCtx.logits,
|
||||||
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
promptCtx.top_k, promptCtx.top_p, promptCtx.temp,
|
||||||
promptCtx.repeat_penalty,
|
promptCtx.repeat_penalty,
|
||||||
|
Loading…
Reference in New Issue
Block a user