Python chat streaming (#1127)

* Support streaming in chat session

* Uncommented tests
This commit is contained in:
Andriy Mulyar 2023-07-03 12:59:39 -04:00 committed by GitHub
parent aced5e6615
commit 01bd3d6802
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 10 deletions

View File

@ -64,7 +64,6 @@ Use the GPT4All `chat_session` context manager to hold chat conversations with t
}
]
```
When using GPT4All models in the chat_session context:
- The model is given a prompt template which makes it chatty.
@ -79,7 +78,7 @@ When using GPT4All models in the chat_session context:
### Streaming Generations
To interact with GPT4All responses as the model generates, use the `streaming = True` flag during generation.
=== "GPT4All Example"
=== "GPT4All Streaming Example"
``` py
from gpt4all import GPT4All
model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
@ -93,4 +92,22 @@ To interact with GPT4All responses as the model generates, use the `streaming =
[' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
```
#### Streaming and Chat Sessions
When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
```python
from gpt4all import GPT4All
model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
with model.chat_session():
tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
print(model.current_chat_session)
```
::: gpt4all.gpt4all.GPT4All

View File

@ -210,9 +210,6 @@ class GPT4All:
if n_predict is not None:
generate_kwargs['n_predict'] = n_predict
if streaming and self._is_chat_session_activated:
raise NotImplementedError("Streaming tokens in a chat session is not currently supported.")
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "user", "content": prompt})
generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session)

View File

@ -25,11 +25,13 @@ def test_inference():
assert len(tokens) > 0
with model.chat_session():
try:
response = model.generate(prompt='hello', top_k=1, streaming=True)
assert False
except NotImplementedError:
assert True
tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
print(model.current_chat_session)
def do_long_input(model):