mirror of https://github.com/nomic-ai/gpt4all.git
Python bindings: unicode decoding (#1281)
* rewrote the unicode decoding using the structure of multi-byte unicode symbols.
parent 91a32c0e84
commit 3ed6d176a5
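The "structure of multi-byte unicode symbols" the message refers to is the UTF-8 byte layout: the number of leading 1-bits classifies each byte. 0xxxxxxx is a complete one-byte (ASCII) character, 10xxxxxx is a continuation byte, and 110xxxxx / 1110xxxx / 11110xxx open 2-, 3-, and 4-byte sequences. The decoder in the diff below reads off that count with "{:08b}".format(byte) followed by bits.partition('0'). A minimal standalone sketch of the same classification (utf8_byte_role is a hypothetical helper for illustration, not part of the bindings):

def utf8_byte_role(byte: int) -> str:
    # "{:08b}" renders the byte as 8 binary digits; partition('0')
    # splits off the run of leading 1-bits before the first 0.
    high_ones, _, _ = "{:08b}".format(byte).partition('0')
    if len(high_ones) == 0:
        return "ascii"            # 0xxxxxxx: complete one-byte character
    if len(high_ones) == 1:
        return "continuation"     # 10xxxxxx: continues an open sequence
    return "lead of a %d-byte sequence" % len(high_ones)

# "€" (U+20AC) encodes to the three bytes e2 82 ac:
for b in "€".encode("utf-8"):
    print(hex(b), utf8_byte_role(b))
# 0xe2 lead of a 3-byte sequence
# 0x82 continuation
# 0xac continuation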
@@ -1,7 +1,7 @@
 import ctypes
 import os
 import platform
-import queue
+from queue import Queue
 import re
 import subprocess
 import sys
@@ -157,6 +157,9 @@ class LLModel:
         self.context = None
         self.llmodel_lib = llmodel

+        self.buffer = bytearray()
+        self.buff_expecting_cont_bytes: int = 0
+
     def __del__(self):
         if self.model is not None:
             self.llmodel_lib.llmodel_model_destroy(self.model)
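The two attributes added here are the decoder state that has to survive between token callbacks: self.buffer collects the bytes of a UTF-8 sequence that is still incomplete, and self.buff_expecting_cont_bytes counts how many continuation bytes the current lead byte still promises. They live on the instance because, as the final hunk shows, one character's bytes can arrive split across separate callbacks.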
@@ -291,6 +294,9 @@ class LLModel:
             None
         """

+        self.buffer.clear()
+        self.buff_expecting_cont_bytes = 0
+
         logger.info(
             "LLModel.prompt_model -- prompt:\n"
             + "%s\n"
@@ -322,6 +328,7 @@ class LLModel:
             self.context,
         )

+
     def prompt_model_streaming(
         self,
         prompt: str,
@@ -331,7 +338,7 @@ class LLModel:
        # Symbol to terminate from generator
        TERMINATING_SYMBOL = object()

-       output_queue = queue.Queue()
+       output_queue: Queue = Queue()

        # Put response tokens into an output queue
        def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@@ -371,8 +378,42 @@ class LLModel:

     def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
         def _raw_callback(token_id: int, response: bytes) -> bool:
-            nonlocal callback
-            return callback(token_id, response.decode("utf-8", "replace"))
+            nonlocal self, callback
+
+            decoded = []
+
+            for byte in response:
+
+                bits = "{:08b}".format(byte)
+                (high_ones, _, _) = bits.partition('0')
+
+                if len(high_ones) == 1:
+                    # continuation byte
+                    self.buffer.append(byte)
+                    self.buff_expecting_cont_bytes -= 1
+
+                else:
+                    # beginning of a byte sequence
+                    if len(self.buffer) > 0:
+                        decoded.append(self.buffer.decode('utf-8', 'replace'))
+
+                    self.buffer.clear()
+
+                    self.buffer.append(byte)
+                    self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1)
+
+                if self.buff_expecting_cont_bytes <= 0:
+                    # received the whole sequence or an out of place continuation byte
+                    decoded.append(self.buffer.decode('utf-8', 'replace'))
+
+                    self.buffer.clear()
+                    self.buff_expecting_cont_bytes = 0
+
+            if len(decoded) == 0 and self.buff_expecting_cont_bytes > 0:
+                # wait for more continuation bytes
+                return True
+
+            return callback(token_id, ''.join(decoded))
+
         return _raw_callback
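Why this helps: a model can emit the bytes of one character across separate tokens. The old response.decode("utf-8", "replace") turned each fragment into U+FFFD; the new loop buffers the lead byte and only decodes once the promised continuation bytes have arrived, returning True early (without invoking the user callback) while the buffer is still incomplete. A runnable restatement of the same loop outside the class, where StreamDecoder is a hypothetical stand-in, since the real _raw_callback runs against the native backend:

class StreamDecoder:
    def __init__(self):
        self.buffer = bytearray()
        self.buff_expecting_cont_bytes = 0

    def feed(self, response: bytes) -> str:
        decoded = []
        for byte in response:
            high_ones, _, _ = "{:08b}".format(byte).partition('0')
            if len(high_ones) == 1:            # continuation byte
                self.buffer.append(byte)
                self.buff_expecting_cont_bytes -= 1
            else:                              # lead byte or ASCII
                if len(self.buffer) > 0:       # flush an unfinished sequence
                    decoded.append(self.buffer.decode('utf-8', 'replace'))
                self.buffer.clear()
                self.buffer.append(byte)
                self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1)
            if self.buff_expecting_cont_bytes <= 0:
                # sequence complete (or a stray continuation byte): decode it
                decoded.append(self.buffer.decode('utf-8', 'replace'))
                self.buffer.clear()
                self.buff_expecting_cont_bytes = 0
        # (the real callback instead returns True here, skipping the user
        # callback, when nothing was decoded and more bytes are expected)
        return ''.join(decoded)

d = StreamDecoder()
print(repr(d.feed(b'\xe2\x82')))   # ''  -- lead byte buffered, one byte still owed
print(repr(d.feed(b'\xac')))       # '€' -- sequence completed across chunks
print(repr(b'\xe2\x82'.decode('utf-8', 'replace')))  # old behavior: '\ufffd'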