fix regressions in system prompt handling (#2219)

* python: fix system prompt being ignored
* fix unintended whitespace after system prompt

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-04-15 11:39:48 -04:00 committed by GitHub
parent 2273cf145e
commit ac498f79ac
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 19 deletions

View File

@@ -755,6 +755,7 @@ void LLamaModel::embedInternal(
         tokens.resize(text.length()+4);
         int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), wantBOS, false);
         if (n_tokens) {
+            (void)eos_token;
             assert(useEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
             tokens.resize(n_tokens - useEOS); // erase EOS/SEP
         } else {
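
Note: the added (void)eos_token presumably silences an unused-variable warning in release builds, where NDEBUG compiles the assert away and eos_token would otherwise go unreferenced. A minimal standalone illustration of the idiom (checkLast is a made-up example, not gpt4all code):

#include <cassert>

// With NDEBUG defined, assert() expands to nothing, so a value referenced only
// inside the assert would trigger -Wunused-variable / -Wunused-parameter.
// Casting it to void marks it as intentionally unused in both build modes.
static int checkLast(const int *tokens, int n_tokens, int eos_token)
{
    (void)eos_token;
    assert(n_tokens > 0 && tokens[n_tokens - 1] != eos_token);
    return tokens[n_tokens - 1];
}

int main()
{
    const int tokens[] = {101, 2023, 2003};
    return checkLast(tokens, 3, /*eos_token*/ 102) == 2003 ? 0 : 1;
}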

View File

@@ -497,16 +497,16 @@ class GPT4All:
         if self._history is not None:
             # check if there is only one message, i.e. system prompt:
             reset = len(self._history) == 1
-            generate_kwargs["reset_context"] = reset
             self._history.append({"role": "user", "content": prompt})

             fct_func = self._format_chat_prompt_template.__func__  # type: ignore[attr-defined]
             if fct_func is GPT4All._format_chat_prompt_template:
                 if reset:
                     # ingest system prompt
-                    self.model.prompt_model(self._history[0]["content"], "%1",
+                    # use "%1%2" and not "%1" to avoid implicit whitespace
+                    self.model.prompt_model(self._history[0]["content"], "%1%2",
                                             empty_response_callback,
-                                            n_batch=n_batch, n_predict=0, special=True)
+                                            n_batch=n_batch, n_predict=0, reset_context=True, special=True)
                 prompt_template = self._current_prompt_template.format("%1", "%2")
             else:
                 warnings.warn(
@@ -519,6 +519,7 @@ class GPT4All:
                     self._history[0]["content"] if reset else "",
                 )
                 prompt_template = "%1"
+                generate_kwargs["reset_context"] = reset
         else:
             prompt_template = "%1"
             generate_kwargs["reset_context"] = True
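
Note: gpt4all prompt templates substitute the prompt text for "%1" and place the model's reply at "%2". The new comments attribute the stray whitespace to the bare "%1" form, so anchoring "%2" directly after "%1" keeps the ingested system prompt exactly as written, and reset_context=True now travels with that ingestion call instead of the later generate. A rough sketch of the placeholder substitution (expandTemplate is a hypothetical helper, not the bindings' or backend's implementation):

#include <iostream>
#include <string>

// Hypothetical helper: expand "%1" (prompt) and "%2" (reply) in a template.
static std::string expandTemplate(std::string tmpl, const std::string &prompt,
                                  const std::string &reply)
{
    auto subst = [&tmpl](const std::string &key, const std::string &val) {
        auto pos = tmpl.find(key);
        if (pos != std::string::npos)
            tmpl.replace(pos, key.size(), val);
    };
    subst("%1", prompt);
    subst("%2", reply);
    return tmpl;
}

int main()
{
    // System prompt ingestion runs with n_predict=0, so the reply slot stays
    // empty and "%1%2" yields the prompt verbatim, nothing appended after it.
    std::cout << expandTemplate("%1%2", "You are a helpful assistant.", "") << '\n';

    // An example chat-style template for comparison (illustrative only).
    std::cout << expandTemplate("### Human:\n%1\n### Assistant:\n%2",
                                "Hello!", "Hi there.") << '\n';
}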

View File

@ -68,7 +68,7 @@ def get_long_description():
setup( setup(
name=package_name, name=package_name,
version="2.5.1", version="2.5.2",
description="Python bindings for GPT4All", description="Python bindings for GPT4All",
long_description=get_long_description(), long_description=get_long_description(),
long_description_content_type="text/markdown", long_description_content_type="text/markdown",

View File

@@ -733,23 +733,13 @@ void ChatLLM::generateName()
     if (!isModelLoaded())
         return;

-    QString instructPrompt("### Instruction:\n"
-                           "Describe response above in three words.\n"
-                           "### Response:\n");
+    std::string instructPrompt("### Instruction:\n%1\n### Response:\n"); // standard Alpaca
     auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1);
-    auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1,
-                                  std::placeholders::_2);
+    auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
     auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1);
     LLModel::PromptContext ctx = m_ctx;
-#if defined(DEBUG)
-    printf("%s", qPrintable(instructPrompt));
-    fflush(stdout);
-#endif
-    m_llModelInfo.model->prompt(instructPrompt.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, ctx);
-#if defined(DEBUG)
-    printf("\n");
-    fflush(stdout);
-#endif
+    m_llModelInfo.model->prompt("Describe response above in three words.", instructPrompt, promptFunc, responseFunc,
+                                recalcFunc, ctx);
     std::string trimmed = trim_whitespace(m_nameResponse);
     if (trimmed != m_nameResponse) {
         m_nameResponse = trimmed;
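
Note: generateName() no longer bakes the instruction into the template; the literal request becomes the prompt and the Alpaca-style string is passed as the prompt template, mirroring the (prompt, promptTemplate, ...) form used in the other call sites of this commit. A quick sketch of the expanded text this produces (plain string replacement for illustration only, not how LLModel::prompt() does it internally):

#include <iostream>
#include <string>

int main()
{
    // Strings taken from the diff above; "%1" is the prompt placeholder.
    std::string instructPrompt("### Instruction:\n%1\n### Response:\n");
    const std::string request = "Describe response above in three words.";
    std::string expanded = instructPrompt;
    expanded.replace(expanded.find("%1"), 2, request);
    std::cout << expanded;
    // ### Instruction:
    // Describe response above in three words.
    // ### Response:
}
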
@@ -1056,7 +1046,8 @@ void ChatLLM::processSystemPrompt()
     fflush(stdout);
 #endif
     auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
-    m_llModelInfo.model->prompt(systemPrompt, "%1", promptFunc, nullptr, recalcFunc, m_ctx, true);
+    // use "%1%2" and not "%1" to avoid implicit whitespace
+    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true);
     m_ctx.n_predict = old_n_predict;
 #if defined(DEBUG)
     printf("\n");
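
Note: the unchanged context above zeroes n_predict via std::exchange so the system prompt is only decoded, never answered, and then restores it. A minimal standalone illustration of that save/restore pattern:

#include <cassert>
#include <utility>

int main()
{
    int n_predict = 128;
    // std::exchange stores the new value and returns the old one, so the
    // original setting can be restored once the system prompt has been
    // ingested (n_predict == 0 means "decode only, generate no tokens").
    int old_n_predict = std::exchange(n_predict, 0);
    assert(n_predict == 0);
    n_predict = old_n_predict; // restore
    assert(n_predict == 128);
}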