Python bindings: unicode decoding (#1281)

* rewrote the unicode decoding using the structure of multi-byte unicode symbols.
This commit is contained in:
385olt 2023-07-30 20:29:51 +02:00 committed by GitHub
parent 91a32c0e84
commit 3ed6d176a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,7 @@
import ctypes
import os
import platform
import queue
from queue import Queue
import re
import subprocess
import sys
@ -157,6 +157,9 @@ class LLModel:
self.context = None
self.llmodel_lib = llmodel
self.buffer = bytearray()
self.buff_expecting_cont_bytes: int = 0
def __del__(self):
if self.model is not None:
self.llmodel_lib.llmodel_model_destroy(self.model)
@ -291,6 +294,9 @@ class LLModel:
None
"""
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
logger.info(
"LLModel.prompt_model -- prompt:\n"
+ "%s\n"
@ -322,6 +328,7 @@ class LLModel:
self.context,
)
def prompt_model_streaming(
self,
prompt: str,
@ -331,7 +338,7 @@ class LLModel:
# Symbol to terminate from generator
TERMINATING_SYMBOL = object()
output_queue = queue.Queue()
output_queue: Queue = Queue()
# Put response tokens into an output queue
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
@ -371,8 +378,42 @@ class LLModel:
def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
def _raw_callback(token_id: int, response: bytes) -> bool:
nonlocal callback
return callback(token_id, response.decode("utf-8", "replace"))
nonlocal self, callback
decoded = []
for byte in response:
bits = "{:08b}".format(byte)
(high_ones, _, _) = bits.partition('0')
if len(high_ones) == 1:
# continuation byte
self.buffer.append(byte)
self.buff_expecting_cont_bytes -= 1
else:
# beginning of a byte sequence
if len(self.buffer) > 0:
decoded.append(self.buffer.decode('utf-8', 'replace'))
self.buffer.clear()
self.buffer.append(byte)
self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1)
if self.buff_expecting_cont_bytes <= 0:
# received the whole sequence or an out of place continuation byte
decoded.append(self.buffer.decode('utf-8', 'replace'))
self.buffer.clear()
self.buff_expecting_cont_bytes = 0
if len(decoded) == 0 and self.buff_expecting_cont_bytes > 0:
# wait for more continuation bytes
return True
return callback(token_id, ''.join(decoded))
return _raw_callback