mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
python: implement close() and context manager interface (#2177)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
dddaf49428
commit
3313c7de0d
@ -9,7 +9,7 @@ import sys
|
||||
import threading
|
||||
from enum import Enum
|
||||
from queue import Queue
|
||||
from typing import Any, Callable, Generic, Iterable, TypeVar, overload
|
||||
from typing import Any, Callable, Generic, Iterable, NoReturn, TypeVar, overload
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
import importlib.resources as importlib_resources
|
||||
@ -200,13 +200,22 @@ class LLModel:
|
||||
if model is None:
|
||||
s = err.value
|
||||
raise RuntimeError(f"Unable to instantiate model: {'null' if s is None else s.decode()}")
|
||||
self.model = model
|
||||
self.model: ctypes.c_void_p | None = model
|
||||
|
||||
def __del__(self, llmodel=llmodel):
|
||||
if hasattr(self, 'model'):
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
if self.model is not None:
|
||||
llmodel.llmodel_model_destroy(self.model)
|
||||
self.model = None
|
||||
|
||||
def _raise_closed(self) -> NoReturn:
|
||||
raise ValueError("Attempted operation on a closed LLModel")
|
||||
|
||||
def _list_gpu(self, mem_required: int) -> list[LLModelGPUDevice]:
|
||||
assert self.model is not None
|
||||
num_devices = ctypes.c_int32(0)
|
||||
devices_ptr = llmodel.llmodel_available_gpu_devices(self.model, mem_required, ctypes.byref(num_devices))
|
||||
if not devices_ptr:
|
||||
@ -214,6 +223,9 @@ class LLModel:
|
||||
return devices_ptr[:num_devices.value]
|
||||
|
||||
def init_gpu(self, device: str):
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
|
||||
mem_required = llmodel.llmodel_required_mem(self.model, self.model_path, self.n_ctx, self.ngl)
|
||||
|
||||
if llmodel.llmodel_gpu_init_gpu_device_by_string(self.model, mem_required, device.encode()):
|
||||
@ -246,14 +258,21 @@ class LLModel:
|
||||
-------
|
||||
True if model loaded successfully, False otherwise
|
||||
"""
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
|
||||
return llmodel.llmodel_loadModel(self.model, self.model_path, self.n_ctx, self.ngl)
|
||||
|
||||
def set_thread_count(self, n_threads):
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
if not llmodel.llmodel_isModelLoaded(self.model):
|
||||
raise Exception("Model not loaded")
|
||||
llmodel.llmodel_setThreadCount(self.model, n_threads)
|
||||
|
||||
def thread_count(self):
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
if not llmodel.llmodel_isModelLoaded(self.model):
|
||||
raise Exception("Model not loaded")
|
||||
return llmodel.llmodel_threadCount(self.model)
|
||||
@ -322,6 +341,9 @@ class LLModel:
|
||||
if not text:
|
||||
raise ValueError("text must not be None or empty")
|
||||
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
|
||||
if (single_text := isinstance(text, str)):
|
||||
text = [text]
|
||||
|
||||
@ -387,6 +409,9 @@ class LLModel:
|
||||
None
|
||||
"""
|
||||
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
|
||||
self.buffer.clear()
|
||||
self.buff_expecting_cont_bytes = 0
|
||||
|
||||
@ -419,6 +444,9 @@ class LLModel:
|
||||
def prompt_model_streaming(
|
||||
self, prompt: str, prompt_template: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
|
||||
) -> Iterable[str]:
|
||||
if self.model is None:
|
||||
self._raise_closed()
|
||||
|
||||
output_queue: Queue[str | Sentinel] = Queue()
|
||||
|
||||
# Put response tokens into an output queue
|
||||
|
@ -11,6 +11,7 @@ import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Literal, Protocol, overload
|
||||
|
||||
import requests
|
||||
@ -22,7 +23,7 @@ from . import _pyllmodel
|
||||
from ._pyllmodel import EmbedResult as EmbedResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
from typing_extensions import Self, TypeAlias
|
||||
|
||||
if sys.platform == 'darwin':
|
||||
import fcntl
|
||||
@ -54,6 +55,18 @@ class Embed4All:
|
||||
model_name = 'all-MiniLM-L6-v2.gguf2.f16.gguf'
|
||||
self.gpt4all = GPT4All(model_name, n_threads=n_threads, **kwargs)
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Delete the model instance and free associated system resources."""
|
||||
self.gpt4all.close()
|
||||
|
||||
# return_dict=False
|
||||
@overload
|
||||
def embed(
|
||||
@ -190,6 +203,18 @@ class GPT4All:
|
||||
self._history: list[MessageType] | None = None
|
||||
self._current_prompt_template: str = "{0}"
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, typ: type[BaseException] | None, value: BaseException | None, tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Delete the model instance and free associated system resources."""
|
||||
self.model.close()
|
||||
|
||||
@property
|
||||
def current_chat_session(self) -> list[MessageType] | None:
|
||||
return None if self._history is None else list(self._history)
|
||||
|
@ -68,7 +68,7 @@ def get_long_description():
|
||||
|
||||
setup(
|
||||
name=package_name,
|
||||
version="2.3.2",
|
||||
version="2.3.3",
|
||||
description="Python bindings for GPT4All",
|
||||
long_description=get_long_description(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
Loading…
Reference in New Issue
Block a user