fix python async context

This commit is contained in:
Christien Rioux 2025-03-06 14:29:45 -05:00
parent c75de1196f
commit bd5492404b
2 changed files with 67 additions and 12 deletions

View File

@ -6,11 +6,20 @@ from .state import VeilidState
class RoutingContext(ABC): class RoutingContext(ABC):
ref_count: int
def __init__(
self,
):
self.ref_count = 0
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
self.ref_count += 1
return self return self
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
if not self.is_done(): self.ref_count -= 1
if self.ref_count == 0 and not self.is_done():
await self.release() await self.release()
@abstractmethod @abstractmethod
@ -109,13 +118,22 @@ class RoutingContext(ABC):
class TableDbTransaction(ABC): class TableDbTransaction(ABC):
ref_count: int
def __init__(
self,
):
self.ref_count = 0
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
self.ref_count += 1
return self return self
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
if not self.is_done(): self.ref_count -= 1
await self.rollback() if self.ref_count == 0 and not self.is_done():
await self.release()
@abstractmethod @abstractmethod
def is_done(self) -> bool: def is_done(self) -> bool:
pass pass
@ -138,11 +156,20 @@ class TableDbTransaction(ABC):
class TableDb(ABC): class TableDb(ABC):
ref_count: int
def __init__(
self,
):
self.ref_count = 0
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
self.ref_count += 1
return self return self
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
if not self.is_done(): self.ref_count -= 1
if self.ref_count == 0 and not self.is_done():
await self.release() await self.release()
@abstractmethod @abstractmethod
@ -179,11 +206,20 @@ class TableDb(ABC):
class CryptoSystem(ABC): class CryptoSystem(ABC):
ref_count: int
def __init__(
self,
):
self.ref_count = 0
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
self.ref_count += 1
return self return self
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
if not self.is_done(): self.ref_count -= 1
if self.ref_count == 0 and not self.is_done():
await self.release() await self.release()
@abstractmethod @abstractmethod
@ -306,11 +342,20 @@ class CryptoSystem(ABC):
class VeilidAPI(ABC): class VeilidAPI(ABC):
ref_count: int
def __init__(
self,
):
self.ref_count = 0
async def __aenter__(self) -> Self: async def __aenter__(self) -> Self:
self.ref_count += 1
return self return self
async def __aexit__(self, *excinfo): async def __aexit__(self, *excinfo):
if not self.is_done(): self.ref_count -= 1
if self.ref_count == 0 and not self.is_done():
await self.release() await self.release()
@abstractmethod @abstractmethod

View File

@ -99,6 +99,8 @@ class _JsonVeilidAPI(VeilidAPI):
update_callback: Callable[[VeilidUpdate], Awaitable], update_callback: Callable[[VeilidUpdate], Awaitable],
validate_schema: bool = True, validate_schema: bool = True,
): ):
super().__init__()
self.reader = reader self.reader = reader
self.writer = writer self.writer = writer
self.update_callback = update_callback self.update_callback = update_callback
@ -308,7 +310,7 @@ class _JsonVeilidAPI(VeilidAPI):
# Validate if we have a validator # Validate if we have a validator
if response["op"] != req["op"]: if response["op"] != req["op"]:
raise ValueError("Response op does not match request op") raise ValueError(f"Response op does not match request op: {response['op']} != {req['op']}")
if validate is not None: if validate is not None:
validate(req, response) validate(req, response)
@ -459,7 +461,7 @@ class _JsonVeilidAPI(VeilidAPI):
def validate_rc_op(request: dict, response: dict): def validate_rc_op(request: dict, response: dict):
if response["rc_op"] != request["rc_op"]: if response["rc_op"] != request["rc_op"]:
raise ValueError("Response rc_op does not match request rc_op") raise ValueError(f"Response rc_op does not match request rc_op: {response["rc_op"]} != {request["rc_op"]}")
class _JsonRoutingContext(RoutingContext): class _JsonRoutingContext(RoutingContext):
@ -468,6 +470,8 @@ class _JsonRoutingContext(RoutingContext):
done: bool done: bool
def __init__(self, api: _JsonVeilidAPI, rc_id: int): def __init__(self, api: _JsonVeilidAPI, rc_id: int):
super().__init__()
self.api = api self.api = api
self.rc_id = rc_id self.rc_id = rc_id
self.done = False self.done = False
@ -728,7 +732,7 @@ class _JsonRoutingContext(RoutingContext):
def validate_tx_op(request: dict, response: dict): def validate_tx_op(request: dict, response: dict):
if response["tx_op"] != request["tx_op"]: if response["tx_op"] != request["tx_op"]:
raise ValueError("Response tx_op does not match request tx_op") raise ValueError(f"Response tx_op does not match request tx_op: {response['tx_op']} != {request['tx_op']}")
class _JsonTableDbTransaction(TableDbTransaction): class _JsonTableDbTransaction(TableDbTransaction):
@ -737,6 +741,8 @@ class _JsonTableDbTransaction(TableDbTransaction):
done: bool done: bool
def __init__(self, api: _JsonVeilidAPI, tx_id: int): def __init__(self, api: _JsonVeilidAPI, tx_id: int):
super().__init__()
self.api = api self.api = api
self.tx_id = tx_id self.tx_id = tx_id
self.done = False self.done = False
@ -810,7 +816,7 @@ class _JsonTableDbTransaction(TableDbTransaction):
def validate_db_op(request: dict, response: dict): def validate_db_op(request: dict, response: dict):
if response["db_op"] != request["db_op"]: if response["db_op"] != request["db_op"]:
raise ValueError("Response db_op does not match request db_op") raise ValueError(f"Response db_op does not match request db_op: {response['db_op']} != {request['db_op']}")
class _JsonTableDb(TableDb): class _JsonTableDb(TableDb):
@ -819,6 +825,8 @@ class _JsonTableDb(TableDb):
done: bool done: bool
def __init__(self, api: _JsonVeilidAPI, db_id: int): def __init__(self, api: _JsonVeilidAPI, db_id: int):
super().__init__()
self.api = api self.api = api
self.db_id = db_id self.db_id = db_id
self.done = False self.done = False
@ -929,7 +937,7 @@ class _JsonTableDb(TableDb):
def validate_cs_op(request: dict, response: dict): def validate_cs_op(request: dict, response: dict):
if response["cs_op"] != request["cs_op"]: if response["cs_op"] != request["cs_op"]:
raise ValueError("Response cs_op does not match request cs_op") raise ValueError(f"Response cs_op does not match request cs_op: {response['cs_op']} != {request['cs_op']}")
class _JsonCryptoSystem(CryptoSystem): class _JsonCryptoSystem(CryptoSystem):
@ -938,6 +946,8 @@ class _JsonCryptoSystem(CryptoSystem):
done: bool done: bool
def __init__(self, api: _JsonVeilidAPI, cs_id: int): def __init__(self, api: _JsonVeilidAPI, cs_id: int):
super().__init__()
self.api = api self.api = api
self.cs_id = cs_id self.cs_id = cs_id
self.done = False self.done = False