From bd5492404b9600a82bf63d713d48f94edf335e99 Mon Sep 17 00:00:00 2001 From: Christien Rioux Date: Thu, 6 Mar 2025 14:29:45 -0500 Subject: [PATCH] fix python async context --- veilid-python/veilid/api.py | 59 ++++++++++++++++++++++++++++---- veilid-python/veilid/json_api.py | 20 ++++++++--- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/veilid-python/veilid/api.py b/veilid-python/veilid/api.py index de043b67..b54786de 100644 --- a/veilid-python/veilid/api.py +++ b/veilid-python/veilid/api.py @@ -6,11 +6,20 @@ from .state import VeilidState class RoutingContext(ABC): + ref_count: int + + def __init__( + self, + ): + self.ref_count = 0 + async def __aenter__(self) -> Self: + self.ref_count += 1 return self async def __aexit__(self, *excinfo): - if not self.is_done(): + self.ref_count -= 1 + if self.ref_count == 0 and not self.is_done(): await self.release() @abstractmethod @@ -109,13 +118,22 @@ class RoutingContext(ABC): class TableDbTransaction(ABC): + ref_count: int + + def __init__( + self, + ): + self.ref_count = 0 + async def __aenter__(self) -> Self: + self.ref_count += 1 return self async def __aexit__(self, *excinfo): - if not self.is_done(): - await self.rollback() - + self.ref_count -= 1 + if self.ref_count == 0 and not self.is_done(): + await self.release() + @abstractmethod def is_done(self) -> bool: pass @@ -138,11 +156,20 @@ class TableDbTransaction(ABC): class TableDb(ABC): + ref_count: int + + def __init__( + self, + ): + self.ref_count = 0 + async def __aenter__(self) -> Self: + self.ref_count += 1 return self async def __aexit__(self, *excinfo): - if not self.is_done(): + self.ref_count -= 1 + if self.ref_count == 0 and not self.is_done(): await self.release() @abstractmethod @@ -179,11 +206,20 @@ class TableDb(ABC): class CryptoSystem(ABC): + ref_count: int + + def __init__( + self, + ): + self.ref_count = 0 + async def __aenter__(self) -> Self: + self.ref_count += 1 return self async def __aexit__(self, *excinfo): - if not self.is_done(): + self.ref_count -= 1 + if self.ref_count == 0 and not self.is_done(): await self.release() @abstractmethod @@ -306,11 +342,20 @@ class CryptoSystem(ABC): class VeilidAPI(ABC): + ref_count: int + + def __init__( + self, + ): + self.ref_count = 0 + async def __aenter__(self) -> Self: + self.ref_count += 1 return self async def __aexit__(self, *excinfo): - if not self.is_done(): + self.ref_count -= 1 + if self.ref_count == 0 and not self.is_done(): await self.release() @abstractmethod diff --git a/veilid-python/veilid/json_api.py b/veilid-python/veilid/json_api.py index 5969c5ea..7b75d6ed 100644 --- a/veilid-python/veilid/json_api.py +++ b/veilid-python/veilid/json_api.py @@ -99,6 +99,8 @@ class _JsonVeilidAPI(VeilidAPI): update_callback: Callable[[VeilidUpdate], Awaitable], validate_schema: bool = True, ): + super().__init__() + self.reader = reader self.writer = writer self.update_callback = update_callback @@ -308,7 +310,7 @@ class _JsonVeilidAPI(VeilidAPI): # Validate if we have a validator if response["op"] != req["op"]: - raise ValueError("Response op does not match request op") + raise ValueError(f"Response op does not match request op: {response['op']} != {req['op']}") if validate is not None: validate(req, response) @@ -459,7 +461,7 @@ class _JsonVeilidAPI(VeilidAPI): def validate_rc_op(request: dict, response: dict): if response["rc_op"] != request["rc_op"]: - raise ValueError("Response rc_op does not match request rc_op") + raise ValueError(f"Response rc_op does not match request rc_op: {response["rc_op"]} != {request["rc_op"]}") class _JsonRoutingContext(RoutingContext): @@ -468,6 +470,8 @@ class _JsonRoutingContext(RoutingContext): done: bool def __init__(self, api: _JsonVeilidAPI, rc_id: int): + super().__init__() + self.api = api self.rc_id = rc_id self.done = False @@ -728,7 +732,7 @@ class _JsonRoutingContext(RoutingContext): def validate_tx_op(request: dict, response: dict): if response["tx_op"] != request["tx_op"]: - raise ValueError("Response tx_op does not match request tx_op") + raise ValueError(f"Response tx_op does not match request tx_op: {response['tx_op']} != {request['tx_op']}") class _JsonTableDbTransaction(TableDbTransaction): @@ -737,6 +741,8 @@ class _JsonTableDbTransaction(TableDbTransaction): done: bool def __init__(self, api: _JsonVeilidAPI, tx_id: int): + super().__init__() + self.api = api self.tx_id = tx_id self.done = False @@ -810,7 +816,7 @@ class _JsonTableDbTransaction(TableDbTransaction): def validate_db_op(request: dict, response: dict): if response["db_op"] != request["db_op"]: - raise ValueError("Response db_op does not match request db_op") + raise ValueError(f"Response db_op does not match request db_op: {response['db_op']} != {request['db_op']}") class _JsonTableDb(TableDb): @@ -819,6 +825,8 @@ class _JsonTableDb(TableDb): done: bool def __init__(self, api: _JsonVeilidAPI, db_id: int): + super().__init__() + self.api = api self.db_id = db_id self.done = False @@ -929,7 +937,7 @@ class _JsonTableDb(TableDb): def validate_cs_op(request: dict, response: dict): if response["cs_op"] != request["cs_op"]: - raise ValueError("Response cs_op does not match request cs_op") + raise ValueError(f"Response cs_op does not match request cs_op: {response['cs_op']} != {request['cs_op']}") class _JsonCryptoSystem(CryptoSystem): @@ -938,6 +946,8 @@ class _JsonCryptoSystem(CryptoSystem): done: bool def __init__(self, api: _JsonVeilidAPI, cs_id: int): + super().__init__() + self.api = api self.cs_id = cs_id self.done = False