mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-06 11:55:04 -04:00
Add type hints to application services. (#8655)
This commit is contained in:
parent
2239813278
commit
31d721fbf6
5 changed files with 122 additions and 79 deletions
|
@ -15,21 +15,31 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
|
||||
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.appservice import (
|
||||
ApplicationService,
|
||||
ApplicationServiceState,
|
||||
AppServiceTransaction,
|
||||
)
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_exclusive_regex(services_cache):
|
||||
def _make_exclusive_regex(
|
||||
services_cache: List[ApplicationService],
|
||||
) -> Optional[Pattern]:
|
||||
# We precompile a regex constructed from all the regexes that the AS's
|
||||
# have registered for exclusive users.
|
||||
exclusive_user_regexes = [
|
||||
|
@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
|
|||
]
|
||||
if exclusive_user_regexes:
|
||||
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
|
||||
exclusive_user_regex = re.compile(exclusive_user_regex)
|
||||
exclusive_user_pattern = re.compile(
|
||||
exclusive_user_regex
|
||||
) # type: Optional[Pattern]
|
||||
else:
|
||||
# We handle this case specially otherwise the constructed regex
|
||||
# will always match
|
||||
exclusive_user_regex = None
|
||||
exclusive_user_pattern = None
|
||||
|
||||
return exclusive_user_regex
|
||||
return exclusive_user_pattern
|
||||
|
||||
|
||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname, hs.config.app_service_config_files
|
||||
)
|
||||
|
@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
def get_app_services(self):
|
||||
return self.services_cache
|
||||
|
||||
def get_if_app_services_interested_in_user(self, user_id):
|
||||
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||
"""Check if the user is one associated with an app service (exclusively)
|
||||
"""
|
||||
if self.exclusive_user_regex:
|
||||
|
@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
else:
|
||||
return False
|
||||
|
||||
def get_app_service_by_user_id(self, user_id):
|
||||
def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
|
||||
"""Retrieve an application service from their user ID.
|
||||
|
||||
All application services have associated with them a particular user ID.
|
||||
|
@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
a user ID to an application service.
|
||||
|
||||
Args:
|
||||
user_id(str): The user ID to see if it is an application service.
|
||||
user_id: The user ID to see if it is an application service.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.sender == user_id:
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_by_token(self, token):
|
||||
def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
|
||||
"""Get the application service with the given appservice token.
|
||||
|
||||
Args:
|
||||
token (str): The application service token.
|
||||
token: The application service token.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_by_id(self, as_id):
|
||||
def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
|
||||
"""Get the application service with the given appservice ID.
|
||||
|
||||
Args:
|
||||
as_id (str): The application service ID.
|
||||
as_id: The application service ID.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
The application service or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.id == as_id:
|
||||
|
@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
|||
class ApplicationServiceTransactionWorkerStore(
|
||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||
):
|
||||
async def get_appservices_by_state(self, state):
|
||||
async def get_appservices_by_state(
|
||||
self, state: ApplicationServiceState
|
||||
) -> List[ApplicationService]:
|
||||
"""Get a list of application services based on their state.
|
||||
|
||||
Args:
|
||||
state(ApplicationServiceState): The state to filter on.
|
||||
state: The state to filter on.
|
||||
Returns:
|
||||
A list of ApplicationServices, which may be empty.
|
||||
"""
|
||||
|
@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
services.append(service)
|
||||
return services
|
||||
|
||||
async def get_appservice_state(self, service):
|
||||
async def get_appservice_state(
|
||||
self, service: ApplicationService
|
||||
) -> Optional[ApplicationServiceState]:
|
||||
"""Get the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
service: The service whose state to set.
|
||||
Returns:
|
||||
An ApplicationServiceState.
|
||||
An ApplicationServiceState or none.
|
||||
"""
|
||||
result = await self.db_pool.simple_select_one(
|
||||
"application_services_state",
|
||||
|
@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
return result.get("state")
|
||||
return None
|
||||
|
||||
async def set_appservice_state(self, service, state) -> None:
|
||||
async def set_appservice_state(
|
||||
self, service: ApplicationService, state: ApplicationServiceState
|
||||
) -> None:
|
||||
"""Set the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
state(ApplicationServiceState): The connectivity state to apply.
|
||||
service: The service whose state to set.
|
||||
state: The connectivity state to apply.
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
|
@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"create_appservice_txn", _create_appservice_txn
|
||||
)
|
||||
|
||||
async def complete_appservice_txn(self, txn_id, service) -> None:
|
||||
async def complete_appservice_txn(
|
||||
self, txn_id: int, service: ApplicationService
|
||||
) -> None:
|
||||
"""Completes an application service transaction.
|
||||
|
||||
Args:
|
||||
txn_id(str): The transaction ID being completed.
|
||||
service(ApplicationService): The application service which was sent
|
||||
this transaction.
|
||||
txn_id: The transaction ID being completed.
|
||||
service: The application service which was sent this transaction.
|
||||
"""
|
||||
txn_id = int(txn_id)
|
||||
|
||||
|
@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
# has probably missed some events), so whine loudly but still continue,
|
||||
# since it shouldn't fail completion of the transaction.
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
if (last_txn_id + 1) != txn_id:
|
||||
if (txn_id + 1) != txn_id:
|
||||
logger.error(
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
|
@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"complete_appservice_txn", _complete_appservice_txn
|
||||
)
|
||||
|
||||
async def get_oldest_unsent_txn(self, service):
|
||||
"""Get the oldest transaction which has not been sent for this
|
||||
service.
|
||||
async def get_oldest_unsent_txn(
|
||||
self, service: ApplicationService
|
||||
) -> Optional[AppServiceTransaction]:
|
||||
"""Get the oldest transaction which has not been sent for this service.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The app service to get the oldest txn.
|
||||
service: The app service to get the oldest txn.
|
||||
Returns:
|
||||
An AppServiceTransaction or None.
|
||||
"""
|
||||
|
@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
service=service, id=entry["txn_id"], events=events, ephemeral=[]
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,),
|
||||
|
@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
async def set_appservice_last_pos(self, pos) -> None:
|
||||
async def set_appservice_last_pos(self, pos: int) -> None:
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
|
@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
async def get_new_events_for_appservice(self, current_id, limit):
|
||||
async def get_new_events_for_appservice(
|
||||
self, current_id: int, limit: int
|
||||
) -> Tuple[int, List[EventBase]]:
|
||||
"""Get all new events for an appservice"""
|
||||
|
||||
def get_new_events_for_appservice_txn(txn):
|
||||
|
@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
)
|
||||
|
||||
async def set_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str, pos: int
|
||||
self, service: ApplicationService, type: str, pos: Optional[int]
|
||||
) -> None:
|
||||
if type not in ("read_receipt", "presence"):
|
||||
raise ValueError(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue