Python bindings: unicode decoding (#1281)
* Rewrote the Unicode decoding using the structure of multi-byte UTF-8 sequences.
This commit is contained in:
parent 91a32c0e84
commit 3ed6d176a5
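The patch replaces the old one-shot response.decode("utf-8", "replace") in _callback_decoder with a small amount of per-model state (self.buffer, self.buff_expecting_cont_bytes) that tracks partially received multi-byte UTF-8 sequences. For context, here is a minimal sketch of the UTF-8 byte-structure rule the new code relies on; the helper name and the snowman example are made up for illustration and are not part of the patched file:

    # Sketch of the UTF-8 byte-structure rule the patch builds on
    # (illustration only; expected_continuation_bytes is a made-up helper,
    # not part of the gpt4all bindings).

    def expected_continuation_bytes(byte: int) -> int:
        """How many continuation bytes should follow `byte` in UTF-8.

        A leading byte starts with N ones then a zero: 0b110xxxxx starts a
        2-byte sequence, 0b1110xxxx a 3-byte one, 0b11110xxx a 4-byte one.
        0b10xxxxxx is a continuation byte; 0b0xxxxxxx is a complete ASCII byte.
        """
        bits = "{:08b}".format(byte)
        high_ones, _, _ = bits.partition("0")  # run of leading 1-bits
        if len(high_ones) == 0:
            return 0    # ASCII byte, complete by itself
        if len(high_ones) == 1:
            return -1   # continuation byte, not a sequence start
        return len(high_ones) - 1

    snowman = "☃".encode("utf-8")                   # b'\xe2\x98\x83'
    print(expected_continuation_bytes(snowman[0]))  # 2  (0xE2 = 0b11100010)
    print(expected_continuation_bytes(snowman[1]))  # -1 (0x98 = 0b10011000, continuation)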
@@ -1,7 +1,7 @@
 import ctypes
 import os
 import platform
-import queue
+from queue import Queue
 import re
 import subprocess
 import sys
@@ -157,6 +157,9 @@ class LLModel:
         self.context = None
         self.llmodel_lib = llmodel
 
+        self.buffer = bytearray()
+        self.buff_expecting_cont_bytes: int = 0
+
     def __del__(self):
         if self.model is not None:
             self.llmodel_lib.llmodel_model_destroy(self.model)
@@ -291,6 +294,9 @@ class LLModel:
         None
         """
 
+        self.buffer.clear()
+        self.buff_expecting_cont_bytes = 0
+
         logger.info(
             "LLModel.prompt_model -- prompt:\n"
             + "%s\n"
@@ -322,6 +328,7 @@ class LLModel:
             self.context,
         )
 
+
     def prompt_model_streaming(
         self,
         prompt: str,
@@ -331,7 +338,7 @@ class LLModel:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()
 
-        output_queue = queue.Queue()
+        output_queue: Queue = Queue()
 
         # Put response tokens into an output queue
         def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@@ -371,8 +378,42 @@ class LLModel:
 
     def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
         def _raw_callback(token_id: int, response: bytes) -> bool:
-            nonlocal callback
-            return callback(token_id, response.decode("utf-8", "replace"))
+            nonlocal self, callback
+
+            decoded = []
+
+            for byte in response:
+
+                bits = "{:08b}".format(byte)
+                (high_ones, _, _) = bits.partition('0')
+
+                if len(high_ones) == 1:
+                    # continuation byte
+                    self.buffer.append(byte)
+                    self.buff_expecting_cont_bytes -= 1
+
+                else:
+                    # beginning of a byte sequence
+                    if len(self.buffer) > 0:
+                        decoded.append(self.buffer.decode('utf-8', 'replace'))
+
+                        self.buffer.clear()
+
+                    self.buffer.append(byte)
+                    self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1)
+
+                if self.buff_expecting_cont_bytes <= 0:
+                    # received the whole sequence or an out of place continuation byte
+                    decoded.append(self.buffer.decode('utf-8', 'replace'))
+
+                    self.buffer.clear()
+                    self.buff_expecting_cont_bytes = 0
+
+            if len(decoded) == 0 and self.buff_expecting_cont_bytes > 0:
+                # wait for more continuation bytes
+                return True
+
+            return callback(token_id, ''.join(decoded))
 
         return _raw_callback
 
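To see why the buffering matters, here is a simplified, standalone illustration of the streaming situation the patch addresses. It is not code from this repository, and it buffers with a plain try/except incremental decode rather than the patch's leading-bit counting, but the effect is the same: a multi-byte character split across two callbacks now decodes cleanly instead of turning into replacement characters.

    # Simplified stand-in for the streaming situation the patch addresses
    # (illustration only; not gpt4all code).

    snowman = "☃".encode("utf-8")        # b'\xe2\x98\x83', a 3-byte UTF-8 sequence
    chunks = [snowman[:2], snowman[2:]]  # the model may emit these in separate callbacks

    # Old behavior: each chunk decoded on its own with errors="replace".
    print([c.decode("utf-8", "replace") for c in chunks])  # ['\ufffd', '\ufffd'] -- garbled

    # New behavior (in spirit): keep bytes until they form a complete sequence.
    buffer = bytearray()
    pieces = []
    for chunk in chunks:
        buffer.extend(chunk)
        try:
            pieces.append(buffer.decode("utf-8"))
            buffer.clear()
        except UnicodeDecodeError:
            pass  # incomplete multi-byte sequence; wait for the next chunk
    print("".join(pieces))  # ☃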