mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 12:25:44 +00:00
Add type hints to various handlers. (#9223)
With this change all handlers except the e2e_* ones have type hints enabled.
This commit is contained in:
parent
26837d5dbe
commit
1baab20352
1
changelog.d/9223.misc
Normal file
1
changelog.d/9223.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to handlers code.
|
14
mypy.ini
14
mypy.ini
@ -26,6 +26,8 @@ files =
|
|||||||
synapse/handlers/_base.py,
|
synapse/handlers/_base.py,
|
||||||
synapse/handlers/account_data.py,
|
synapse/handlers/account_data.py,
|
||||||
synapse/handlers/account_validity.py,
|
synapse/handlers/account_validity.py,
|
||||||
|
synapse/handlers/acme.py,
|
||||||
|
synapse/handlers/acme_issuing_service.py,
|
||||||
synapse/handlers/admin.py,
|
synapse/handlers/admin.py,
|
||||||
synapse/handlers/appservice.py,
|
synapse/handlers/appservice.py,
|
||||||
synapse/handlers/auth.py,
|
synapse/handlers/auth.py,
|
||||||
@ -36,6 +38,7 @@ files =
|
|||||||
synapse/handlers/directory.py,
|
synapse/handlers/directory.py,
|
||||||
synapse/handlers/events.py,
|
synapse/handlers/events.py,
|
||||||
synapse/handlers/federation.py,
|
synapse/handlers/federation.py,
|
||||||
|
synapse/handlers/groups_local.py,
|
||||||
synapse/handlers/identity.py,
|
synapse/handlers/identity.py,
|
||||||
synapse/handlers/initial_sync.py,
|
synapse/handlers/initial_sync.py,
|
||||||
synapse/handlers/message.py,
|
synapse/handlers/message.py,
|
||||||
@ -52,8 +55,13 @@ files =
|
|||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
synapse/handlers/room_member_worker.py,
|
synapse/handlers/room_member_worker.py,
|
||||||
synapse/handlers/saml_handler.py,
|
synapse/handlers/saml_handler.py,
|
||||||
|
synapse/handlers/search.py,
|
||||||
|
synapse/handlers/set_password.py,
|
||||||
synapse/handlers/sso.py,
|
synapse/handlers/sso.py,
|
||||||
|
synapse/handlers/state_deltas.py,
|
||||||
|
synapse/handlers/stats.py,
|
||||||
synapse/handlers/sync.py,
|
synapse/handlers/sync.py,
|
||||||
|
synapse/handlers/typing.py,
|
||||||
synapse/handlers/user_directory.py,
|
synapse/handlers/user_directory.py,
|
||||||
synapse/handlers/ui_auth,
|
synapse/handlers/ui_auth,
|
||||||
synapse/http/client.py,
|
synapse/http/client.py,
|
||||||
@ -194,3 +202,9 @@ ignore_missing_imports = True
|
|||||||
|
|
||||||
[mypy-hiredis]
|
[mypy-hiredis]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-josepy.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-txacme.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
import twisted.internet.error
|
import twisted.internet.error
|
||||||
@ -22,6 +23,9 @@ from twisted.web.resource import Resource
|
|||||||
|
|
||||||
from synapse.app import check_bind_error
|
from synapse.app import check_bind_error
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ACME_REGISTER_FAIL_ERROR = """
|
ACME_REGISTER_FAIL_ERROR = """
|
||||||
@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
|
|||||||
|
|
||||||
|
|
||||||
class AcmeHandler:
|
class AcmeHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.reactor = hs.get_reactor()
|
self.reactor = hs.get_reactor()
|
||||||
self._acme_domain = hs.config.acme_domain
|
self._acme_domain = hs.config.acme_domain
|
||||||
|
|
||||||
async def start_listening(self):
|
async def start_listening(self) -> None:
|
||||||
from synapse.handlers import acme_issuing_service
|
from synapse.handlers import acme_issuing_service
|
||||||
|
|
||||||
# Configure logging for txacme, if you need to debug
|
# Configure logging for txacme, if you need to debug
|
||||||
@ -85,7 +89,7 @@ class AcmeHandler:
|
|||||||
logger.error(ACME_REGISTER_FAIL_ERROR)
|
logger.error(ACME_REGISTER_FAIL_ERROR)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def provision_certificate(self):
|
async def provision_certificate(self) -> None:
|
||||||
|
|
||||||
logger.warning("Reprovisioning %s", self._acme_domain)
|
logger.warning("Reprovisioning %s", self._acme_domain)
|
||||||
|
|
||||||
@ -110,5 +114,3 @@ class AcmeHandler:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed saving!")
|
logger.exception("Failed saving!")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return True
|
|
||||||
|
@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
|
|||||||
imported conditionally.
|
imported conditionally.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict, Iterable, List
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
import pem
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import serialization
|
from cryptography.hazmat.primitives import serialization
|
||||||
from josepy import JWKRSA
|
from josepy import JWKRSA
|
||||||
@ -36,20 +38,27 @@ from txacme.util import generate_private_key
|
|||||||
from zope.interface import implementer
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.internet.interfaces import IReactorTCP
|
||||||
from twisted.python.filepath import FilePath
|
from twisted.python.filepath import FilePath
|
||||||
from twisted.python.url import URL
|
from twisted.python.url import URL
|
||||||
|
from twisted.web.resource import IResource
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
|
def create_issuing_service(
|
||||||
|
reactor: IReactorTCP,
|
||||||
|
acme_url: str,
|
||||||
|
account_key_file: str,
|
||||||
|
well_known_resource: IResource,
|
||||||
|
) -> AcmeIssuingService:
|
||||||
"""Create an ACME issuing service, and attach it to a web Resource
|
"""Create an ACME issuing service, and attach it to a web Resource
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reactor: twisted reactor
|
reactor: twisted reactor
|
||||||
acme_url (str): URL to use to request certificates
|
acme_url: URL to use to request certificates
|
||||||
account_key_file (str): where to store the account key
|
account_key_file: where to store the account key
|
||||||
well_known_resource (twisted.web.IResource): web resource for .well-known.
|
well_known_resource: web resource for .well-known.
|
||||||
we will attach a child resource for "acme-challenge".
|
we will attach a child resource for "acme-challenge".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -83,18 +92,20 @@ class ErsatzStore:
|
|||||||
A store that only stores in memory.
|
A store that only stores in memory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
certs = attr.ib(default=attr.Factory(dict))
|
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
|
||||||
|
|
||||||
def store(self, server_name, pem_objects):
|
def store(
|
||||||
|
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
|
||||||
|
) -> defer.Deferred:
|
||||||
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
def load_or_create_client_key(key_file):
|
def load_or_create_client_key(key_file: str) -> JWKRSA:
|
||||||
"""Load the ACME account key from a file, creating it if it does not exist.
|
"""Load the ACME account key from a file, creating it if it does not exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key_file (str): name of the file to use as the account key
|
key_file: name of the file to use as the account key
|
||||||
"""
|
"""
|
||||||
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
|
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
|
||||||
# hardcode the 'client.key' filename
|
# hardcode the 'client.key' filename
|
||||||
|
@ -15,9 +15,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||||
from synapse.types import GroupID, get_domain_from_id
|
from synapse.types import GroupID, JsonDict, get_domain_from_id
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -56,7 +60,7 @@ def _create_rerouter(func_name):
|
|||||||
|
|
||||||
|
|
||||||
class GroupsLocalWorkerHandler:
|
class GroupsLocalWorkerHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.room_list_handler = hs.get_room_list_handler()
|
self.room_list_handler = hs.get_room_list_handler()
|
||||||
@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
|
|||||||
get_group_role = _create_rerouter("get_group_role")
|
get_group_role = _create_rerouter("get_group_role")
|
||||||
get_group_roles = _create_rerouter("get_group_roles")
|
get_group_roles = _create_rerouter("get_group_roles")
|
||||||
|
|
||||||
async def get_group_summary(self, group_id, requester_user_id):
|
async def get_group_summary(
|
||||||
|
self, group_id: str, requester_user_id: str
|
||||||
|
) -> JsonDict:
|
||||||
"""Get the group summary for a group.
|
"""Get the group summary for a group.
|
||||||
|
|
||||||
If the group is remote we check that the users have valid attestations.
|
If the group is remote we check that the users have valid attestations.
|
||||||
@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_users_in_group(self, group_id, requester_user_id):
|
async def get_users_in_group(
|
||||||
|
self, group_id: str, requester_user_id: str
|
||||||
|
) -> JsonDict:
|
||||||
"""Get users in a group
|
"""Get users in a group
|
||||||
"""
|
"""
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
res = await self.groups_server_handler.get_users_in_group(
|
return await self.groups_server_handler.get_users_in_group(
|
||||||
group_id, requester_user_id
|
group_id, requester_user_id
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|
||||||
group_server_name = get_domain_from_id(group_id)
|
group_server_name = get_domain_from_id(group_id)
|
||||||
|
|
||||||
@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_joined_groups(self, user_id):
|
async def get_joined_groups(self, user_id: str) -> JsonDict:
|
||||||
group_ids = await self.store.get_joined_groups(user_id)
|
group_ids = await self.store.get_joined_groups(user_id)
|
||||||
return {"groups": group_ids}
|
return {"groups": group_ids}
|
||||||
|
|
||||||
async def get_publicised_groups_for_user(self, user_id):
|
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
result = await self.store.get_publicised_groups_for_user(user_id)
|
result = await self.store.get_publicised_groups_for_user(user_id)
|
||||||
|
|
||||||
@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
|
|||||||
# TODO: Verify attestations
|
# TODO: Verify attestations
|
||||||
return {"groups": result}
|
return {"groups": result}
|
||||||
|
|
||||||
async def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
async def bulk_get_publicised_groups(
|
||||||
destinations = {}
|
self, user_ids: Iterable[str], proxy: bool = True
|
||||||
|
) -> JsonDict:
|
||||||
|
destinations = {} # type: Dict[str, Set[str]]
|
||||||
local_users = set()
|
local_users = set()
|
||||||
|
|
||||||
for user_id in user_ids:
|
for user_id in user_ids:
|
||||||
@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
|
|||||||
raise SynapseError(400, "Some user_ids are not local")
|
raise SynapseError(400, "Some user_ids are not local")
|
||||||
|
|
||||||
results = {}
|
results = {}
|
||||||
failed_results = []
|
failed_results = [] # type: List[str]
|
||||||
for destination, dest_user_ids in destinations.items():
|
for destination, dest_user_ids in destinations.items():
|
||||||
try:
|
try:
|
||||||
r = await self.transport_client.bulk_get_publicised_groups(
|
r = await self.transport_client.bulk_get_publicised_groups(
|
||||||
@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
|
|||||||
|
|
||||||
|
|
||||||
class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
# Ensure attestations get renewed
|
# Ensure attestations get renewed
|
||||||
@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
||||||
|
|
||||||
async def create_group(self, group_id, user_id, content):
|
async def create_group(
|
||||||
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
"""Create a group
|
"""Create a group
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
local_attestation = None
|
local_attestation = None
|
||||||
remote_attestation = None
|
remote_attestation = None
|
||||||
else:
|
else:
|
||||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
raise SynapseError(400, "Unable to create remote groups")
|
||||||
content["attestation"] = local_attestation
|
|
||||||
|
|
||||||
content["user_profile"] = await self.profile_handler.get_profile(user_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
res = await self.transport_client.create_group(
|
|
||||||
get_domain_from_id(group_id), group_id, user_id, content
|
|
||||||
)
|
|
||||||
except HttpResponseException as e:
|
|
||||||
raise e.to_synapse_error()
|
|
||||||
except RequestSendFailed:
|
|
||||||
raise SynapseError(502, "Failed to contact group server")
|
|
||||||
|
|
||||||
remote_attestation = res["attestation"]
|
|
||||||
await self.attestations.verify_attestation(
|
|
||||||
remote_attestation,
|
|
||||||
group_id=group_id,
|
|
||||||
user_id=user_id,
|
|
||||||
server_name=get_domain_from_id(group_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
is_publicised = content.get("publicise", False)
|
is_publicised = content.get("publicise", False)
|
||||||
token = await self.store.register_user_group_membership(
|
token = await self.store.register_user_group_membership(
|
||||||
@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def join_group(self, group_id, user_id, content):
|
async def join_group(
|
||||||
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
"""Request to join a group
|
"""Request to join a group
|
||||||
"""
|
"""
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def accept_invite(self, group_id, user_id, content):
|
async def accept_invite(
|
||||||
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
"""Accept an invite to a group
|
"""Accept an invite to a group
|
||||||
"""
|
"""
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def invite(self, group_id, user_id, requester_user_id, config):
|
async def invite(
|
||||||
|
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
"""Invite a user to a group
|
"""Invite a user to a group
|
||||||
"""
|
"""
|
||||||
content = {"requester_user_id": requester_user_id, "config": config}
|
content = {"requester_user_id": requester_user_id, "config": config}
|
||||||
@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def on_invite(self, group_id, user_id, content):
|
async def on_invite(
|
||||||
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
|
) -> JsonDict:
|
||||||
"""One of our users were invited to a group
|
"""One of our users were invited to a group
|
||||||
"""
|
"""
|
||||||
# TODO: Support auto join and rejection
|
# TODO: Support auto join and rejection
|
||||||
@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
return {"state": "invite", "user_profile": user_profile}
|
return {"state": "invite", "user_profile": user_profile}
|
||||||
|
|
||||||
async def remove_user_from_group(
|
async def remove_user_from_group(
|
||||||
self, group_id, user_id, requester_user_id, content
|
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
|
||||||
):
|
) -> JsonDict:
|
||||||
"""Remove a user from a group
|
"""Remove a user from a group
|
||||||
"""
|
"""
|
||||||
if user_id == requester_user_id:
|
if user_id == requester_user_id:
|
||||||
@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def user_removed_from_group(self, group_id, user_id, content):
|
async def user_removed_from_group(
|
||||||
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
|
) -> None:
|
||||||
"""One of our users was removed/kicked from a group
|
"""One of our users was removed/kicked from a group
|
||||||
"""
|
"""
|
||||||
# TODO: Check if user in group
|
# TODO: Check if user in group
|
||||||
|
@ -15,23 +15,28 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
from unpaddedbase64 import decode_base64, encode_base64
|
from unpaddedbase64 import decode_base64, encode_base64
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import NotFoundError, SynapseError
|
from synapse.api.errors import NotFoundError, SynapseError
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SearchHandler(BaseHandler):
|
class SearchHandler(BaseHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
self.storage = hs.get_storage()
|
self.storage = hs.get_storage()
|
||||||
@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
|
|||||||
|
|
||||||
return historical_room_ids
|
return historical_room_ids
|
||||||
|
|
||||||
async def search(self, user, content, batch=None):
|
async def search(
|
||||||
|
self, user: UserID, content: JsonDict, batch: Optional[str] = None
|
||||||
|
) -> JsonDict:
|
||||||
"""Performs a full text search for a user.
|
"""Performs a full text search for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (UserID)
|
user
|
||||||
content (dict): Search parameters
|
content: Search parameters
|
||||||
batch (str): The next_batch parameter. Used for pagination.
|
batch: The next_batch parameter. Used for pagination.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict to be returned to the client with results of search
|
dict to be returned to the client with results of search
|
||||||
@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
|
|||||||
# If doing a subset of all rooms seearch, check if any of the rooms
|
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||||
# are from an upgraded room, and search their contents as well
|
# are from an upgraded room, and search their contents as well
|
||||||
if search_filter.rooms:
|
if search_filter.rooms:
|
||||||
historical_room_ids = []
|
historical_room_ids = [] # type: List[str]
|
||||||
for room_id in search_filter.rooms:
|
for room_id in search_filter.rooms:
|
||||||
# Add any previous rooms to the search if they exist
|
# Add any previous rooms to the search if they exist
|
||||||
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
||||||
@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
|
|||||||
|
|
||||||
rank_map = {} # event_id -> rank of event
|
rank_map = {} # event_id -> rank of event
|
||||||
allowed_events = []
|
allowed_events = []
|
||||||
room_groups = {} # Holds result of grouping by room, if applicable
|
# Holds result of grouping by room, if applicable
|
||||||
sender_group = {} # Holds result of grouping by sender, if applicable
|
room_groups = {} # type: Dict[str, JsonDict]
|
||||||
|
# Holds result of grouping by sender, if applicable
|
||||||
|
sender_group = {} # type: Dict[str, JsonDict]
|
||||||
|
|
||||||
# Holds the next_batch for the entire result set if one of those exists
|
# Holds the next_batch for the entire result set if one of those exists
|
||||||
global_next_batch = None
|
global_next_batch = None
|
||||||
@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
|
|||||||
s["results"].append(e.event_id)
|
s["results"].append(e.event_id)
|
||||||
|
|
||||||
elif order_by == "recent":
|
elif order_by == "recent":
|
||||||
room_events = []
|
room_events = [] # type: List[EventBase]
|
||||||
i = 0
|
i = 0
|
||||||
|
|
||||||
pagination_token = batch_token
|
pagination_token = batch_token
|
||||||
@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
|
|||||||
|
|
||||||
state_results = {}
|
state_results = {}
|
||||||
if include_state:
|
if include_state:
|
||||||
rooms = {e.room_id for e in allowed_events}
|
for room_id in {e.room_id for e in allowed_events}:
|
||||||
for room_id in rooms:
|
|
||||||
state = await self.state_handler.get_current_state(room_id)
|
state = await self.state_handler.get_current_state(room_id)
|
||||||
state_results[room_id] = list(state.values())
|
state_results[room_id] = list(state.values())
|
||||||
|
|
||||||
state_results.values()
|
|
||||||
|
|
||||||
# We're now about to serialize the events. We should not make any
|
# We're now about to serialize the events. We should not make any
|
||||||
# blocking calls after this. Otherwise the 'age' will be wrong
|
# blocking calls after this. Otherwise the 'age' will be wrong
|
||||||
|
|
||||||
@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
|
|||||||
|
|
||||||
if state_results:
|
if state_results:
|
||||||
s = {}
|
s = {}
|
||||||
for room_id, state in state_results.items():
|
for room_id, state_events in state_results.items():
|
||||||
s[room_id] = await self._event_serializer.serialize_events(
|
s[room_id] = await self._event_serializer.serialize_events(
|
||||||
state, time_now
|
state_events, time_now
|
||||||
)
|
)
|
||||||
|
|
||||||
rooms_cat_res["state"] = s
|
rooms_cat_res["state"] = s
|
||||||
|
@ -13,24 +13,26 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||||
from synapse.types import Requester
|
from synapse.types import Requester
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SetPasswordHandler(BaseHandler):
|
class SetPasswordHandler(BaseHandler):
|
||||||
"""Handler which deals with changing user account passwords"""
|
"""Handler which deals with changing user account passwords"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self._device_handler = hs.get_device_handler()
|
self._device_handler = hs.get_device_handler()
|
||||||
self._password_policy_handler = hs.get_password_policy_handler()
|
|
||||||
|
|
||||||
async def set_password(
|
async def set_password(
|
||||||
self,
|
self,
|
||||||
@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
|
|||||||
password_hash: str,
|
password_hash: str,
|
||||||
logout_devices: bool,
|
logout_devices: bool,
|
||||||
requester: Optional[Requester] = None,
|
requester: Optional[Requester] = None,
|
||||||
):
|
) -> None:
|
||||||
if not self.hs.config.password_localdb_enabled:
|
if not self.hs.config.password_localdb_enabled:
|
||||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@ -14,15 +14,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateDeltasHandler:
|
class StateDeltasHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
|
async def _get_key_change(
|
||||||
|
self,
|
||||||
|
prev_event_id: Optional[str],
|
||||||
|
event_id: Optional[str],
|
||||||
|
key_name: str,
|
||||||
|
public_value: str,
|
||||||
|
) -> Optional[bool]:
|
||||||
"""Given two events check if the `key_name` field in content changed
|
"""Given two events check if the `key_name` field in content changed
|
||||||
from not matching `public_value` to doing so.
|
from not matching `public_value` to doing so.
|
||||||
|
|
||||||
|
@ -12,13 +12,19 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
from typing_extensions import Counter as CounterType
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.metrics import event_processing_positions
|
from synapse.metrics import event_processing_positions
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -31,7 +37,7 @@ class StatsHandler:
|
|||||||
Heavily derived from UserDirectoryHandler
|
Heavily derived from UserDirectoryHandler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
@ -44,7 +50,7 @@ class StatsHandler:
|
|||||||
self.stats_enabled = hs.config.stats_enabled
|
self.stats_enabled = hs.config.stats_enabled
|
||||||
|
|
||||||
# The current position in the current_state_delta stream
|
# The current position in the current_state_delta stream
|
||||||
self.pos = None
|
self.pos = None # type: Optional[int]
|
||||||
|
|
||||||
# Guard to ensure we only process deltas one at a time
|
# Guard to ensure we only process deltas one at a time
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
@ -56,7 +62,7 @@ class StatsHandler:
|
|||||||
# we start populating stats
|
# we start populating stats
|
||||||
self.clock.call_later(0, self.notify_new_event)
|
self.clock.call_later(0, self.notify_new_event)
|
||||||
|
|
||||||
def notify_new_event(self):
|
def notify_new_event(self) -> None:
|
||||||
"""Called when there may be more deltas to process
|
"""Called when there may be more deltas to process
|
||||||
"""
|
"""
|
||||||
if not self.stats_enabled or self._is_processing:
|
if not self.stats_enabled or self._is_processing:
|
||||||
@ -72,7 +78,7 @@ class StatsHandler:
|
|||||||
|
|
||||||
run_as_background_process("stats.notify_new_event", process)
|
run_as_background_process("stats.notify_new_event", process)
|
||||||
|
|
||||||
async def _unsafe_process(self):
|
async def _unsafe_process(self) -> None:
|
||||||
# If self.pos is None then means we haven't fetched it from DB
|
# If self.pos is None then means we haven't fetched it from DB
|
||||||
if self.pos is None:
|
if self.pos is None:
|
||||||
self.pos = await self.store.get_stats_positions()
|
self.pos = await self.store.get_stats_positions()
|
||||||
@ -110,10 +116,10 @@ class StatsHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for room_id, fields in room_count.items():
|
for room_id, fields in room_count.items():
|
||||||
room_deltas.setdefault(room_id, {}).update(fields)
|
room_deltas.setdefault(room_id, Counter()).update(fields)
|
||||||
|
|
||||||
for user_id, fields in user_count.items():
|
for user_id, fields in user_count.items():
|
||||||
user_deltas.setdefault(user_id, {}).update(fields)
|
user_deltas.setdefault(user_id, Counter()).update(fields)
|
||||||
|
|
||||||
logger.debug("room_deltas: %s", room_deltas)
|
logger.debug("room_deltas: %s", room_deltas)
|
||||||
logger.debug("user_deltas: %s", user_deltas)
|
logger.debug("user_deltas: %s", user_deltas)
|
||||||
@ -131,19 +137,20 @@ class StatsHandler:
|
|||||||
|
|
||||||
self.pos = max_pos
|
self.pos = max_pos
|
||||||
|
|
||||||
async def _handle_deltas(self, deltas):
|
async def _handle_deltas(
|
||||||
|
self, deltas: Iterable[JsonDict]
|
||||||
|
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
|
||||||
"""Called with the state deltas to process
|
"""Called with the state deltas to process
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[dict[str, Counter], dict[str, counter]]
|
|
||||||
Two dicts: the room deltas and the user deltas,
|
Two dicts: the room deltas and the user deltas,
|
||||||
mapping from room/user ID to changes in the various fields.
|
mapping from room/user ID to changes in the various fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
room_to_stats_deltas = {}
|
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||||
user_to_stats_deltas = {}
|
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||||
|
|
||||||
room_to_state_updates = {}
|
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
|
||||||
|
|
||||||
for delta in deltas:
|
for delta in deltas:
|
||||||
typ = delta["type"]
|
typ = delta["type"]
|
||||||
@ -173,7 +180,7 @@ class StatsHandler:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
event_content = {}
|
event_content = {} # type: JsonDict
|
||||||
|
|
||||||
sender = None
|
sender = None
|
||||||
if event_id is not None:
|
if event_id is not None:
|
||||||
@ -257,13 +264,13 @@ class StatsHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if has_changed_joinedness:
|
if has_changed_joinedness:
|
||||||
delta = +1 if membership == Membership.JOIN else -1
|
membership_delta = +1 if membership == Membership.JOIN else -1
|
||||||
|
|
||||||
user_to_stats_deltas.setdefault(user_id, Counter())[
|
user_to_stats_deltas.setdefault(user_id, Counter())[
|
||||||
"joined_rooms"
|
"joined_rooms"
|
||||||
] += delta
|
] += membership_delta
|
||||||
|
|
||||||
room_stats_delta["local_users_in_room"] += delta
|
room_stats_delta["local_users_in_room"] += membership_delta
|
||||||
|
|
||||||
elif typ == EventTypes.Create:
|
elif typ == EventTypes.Create:
|
||||||
room_state["is_federatable"] = (
|
room_state["is_federatable"] = (
|
||||||
|
@ -15,13 +15,13 @@
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import TYPE_CHECKING, List, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.tcp.streams import TypingStream
|
from synapse.replication.tcp.streams import TypingStream
|
||||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.wheel_timer import WheelTimer
|
from synapse.util.wheel_timer import WheelTimer
|
||||||
@ -65,17 +65,17 @@ class FollowerTypingHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# map room IDs to serial numbers
|
# map room IDs to serial numbers
|
||||||
self._room_serials = {}
|
self._room_serials = {} # type: Dict[str, int]
|
||||||
# map room IDs to sets of users currently typing
|
# map room IDs to sets of users currently typing
|
||||||
self._room_typing = {}
|
self._room_typing = {} # type: Dict[str, Set[str]]
|
||||||
|
|
||||||
self._member_last_federation_poke = {}
|
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||||
self._latest_room_serial = 0
|
self._latest_room_serial = 0
|
||||||
|
|
||||||
self.clock.looping_call(self._handle_timeouts, 5000)
|
self.clock.looping_call(self._handle_timeouts, 5000)
|
||||||
|
|
||||||
def _reset(self):
|
def _reset(self) -> None:
|
||||||
"""Reset the typing handler's data caches.
|
"""Reset the typing handler's data caches.
|
||||||
"""
|
"""
|
||||||
# map room IDs to serial numbers
|
# map room IDs to serial numbers
|
||||||
@ -86,7 +86,7 @@ class FollowerTypingHandler:
|
|||||||
self._member_last_federation_poke = {}
|
self._member_last_federation_poke = {}
|
||||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||||
|
|
||||||
def _handle_timeouts(self):
|
def _handle_timeouts(self) -> None:
|
||||||
logger.debug("Checking for typing timeouts")
|
logger.debug("Checking for typing timeouts")
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
@ -96,7 +96,7 @@ class FollowerTypingHandler:
|
|||||||
for member in members:
|
for member in members:
|
||||||
self._handle_timeout_for_member(now, member)
|
self._handle_timeout_for_member(now, member)
|
||||||
|
|
||||||
def _handle_timeout_for_member(self, now: int, member: RoomMember):
|
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
|
||||||
if not self.is_typing(member):
|
if not self.is_typing(member):
|
||||||
# Nothing to do if they're no longer typing
|
# Nothing to do if they're no longer typing
|
||||||
return
|
return
|
||||||
@ -114,10 +114,10 @@ class FollowerTypingHandler:
|
|||||||
# each person typing.
|
# each person typing.
|
||||||
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
|
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
|
||||||
|
|
||||||
def is_typing(self, member):
|
def is_typing(self, member: RoomMember) -> bool:
|
||||||
return member.user_id in self._room_typing.get(member.room_id, [])
|
return member.user_id in self._room_typing.get(member.room_id, [])
|
||||||
|
|
||||||
async def _push_remote(self, member, typing):
|
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
|
||||||
if not self.federation:
|
if not self.federation:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ class FollowerTypingHandler:
|
|||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
||||||
):
|
) -> None:
|
||||||
"""Should be called whenever we receive updates for typing stream.
|
"""Should be called whenever we receive updates for typing stream.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class FollowerTypingHandler:
|
|||||||
|
|
||||||
async def _send_changes_in_typing_to_remotes(
|
async def _send_changes_in_typing_to_remotes(
|
||||||
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
|
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
|
||||||
):
|
) -> None:
|
||||||
"""Process a change in typing of a room from replication, sending EDUs
|
"""Process a change in typing of a room from replication, sending EDUs
|
||||||
for any local users.
|
for any local users.
|
||||||
"""
|
"""
|
||||||
@ -194,12 +194,12 @@ class FollowerTypingHandler:
|
|||||||
if self.is_mine_id(user_id):
|
if self.is_mine_id(user_id):
|
||||||
await self._push_remote(RoomMember(room_id, user_id), False)
|
await self._push_remote(RoomMember(room_id, user_id), False)
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self) -> int:
|
||||||
return self._latest_room_serial
|
return self._latest_room_serial
|
||||||
|
|
||||||
|
|
||||||
class TypingWriterHandler(FollowerTypingHandler):
|
class TypingWriterHandler(FollowerTypingHandler):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
assert hs.config.worker.writers.typing == hs.get_instance_name()
|
assert hs.config.worker.writers.typing == hs.get_instance_name()
|
||||||
@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||||
|
|
||||||
self._member_typing_until = {} # clock time we expect to stop
|
# clock time we expect to stop
|
||||||
|
self._member_typing_until = {} # type: Dict[RoomMember, int]
|
||||||
|
|
||||||
# caches which room_ids changed at which serials
|
# caches which room_ids changed at which serials
|
||||||
self._typing_stream_change_cache = StreamChangeCache(
|
self._typing_stream_change_cache = StreamChangeCache(
|
||||||
"TypingStreamChangeCache", self._latest_room_serial
|
"TypingStreamChangeCache", self._latest_room_serial
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_timeout_for_member(self, now: int, member: RoomMember):
|
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
|
||||||
super()._handle_timeout_for_member(now, member)
|
super()._handle_timeout_for_member(now, member)
|
||||||
|
|
||||||
if not self.is_typing(member):
|
if not self.is_typing(member):
|
||||||
@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
self._stopped_typing(member)
|
self._stopped_typing(member)
|
||||||
return
|
return
|
||||||
|
|
||||||
async def started_typing(self, target_user, requester, room_id, timeout):
|
async def started_typing(
|
||||||
|
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
|
||||||
|
) -> None:
|
||||||
target_user_id = target_user.to_string()
|
target_user_id = target_user.to_string()
|
||||||
auth_user_id = requester.user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
|
|
||||||
@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
if was_present:
|
if was_present:
|
||||||
# No point sending another notification
|
# No point sending another notification
|
||||||
return None
|
return
|
||||||
|
|
||||||
self._push_update(member=member, typing=True)
|
self._push_update(member=member, typing=True)
|
||||||
|
|
||||||
async def stopped_typing(self, target_user, requester, room_id):
|
async def stopped_typing(
|
||||||
|
self, target_user: UserID, requester: Requester, room_id: str
|
||||||
|
) -> None:
|
||||||
target_user_id = target_user.to_string()
|
target_user_id = target_user.to_string()
|
||||||
auth_user_id = requester.user.to_string()
|
auth_user_id = requester.user.to_string()
|
||||||
|
|
||||||
@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
self._stopped_typing(member)
|
self._stopped_typing(member)
|
||||||
|
|
||||||
def user_left_room(self, user, room_id):
|
def user_left_room(self, user: UserID, room_id: str) -> None:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
if self.is_mine_id(user_id):
|
if self.is_mine_id(user_id):
|
||||||
member = RoomMember(room_id=room_id, user_id=user_id)
|
member = RoomMember(room_id=room_id, user_id=user_id)
|
||||||
self._stopped_typing(member)
|
self._stopped_typing(member)
|
||||||
|
|
||||||
def _stopped_typing(self, member):
|
def _stopped_typing(self, member: RoomMember) -> None:
|
||||||
if member.user_id not in self._room_typing.get(member.room_id, set()):
|
if member.user_id not in self._room_typing.get(member.room_id, set()):
|
||||||
# No point
|
# No point
|
||||||
return None
|
return
|
||||||
|
|
||||||
self._member_typing_until.pop(member, None)
|
self._member_typing_until.pop(member, None)
|
||||||
self._member_last_federation_poke.pop(member, None)
|
self._member_last_federation_poke.pop(member, None)
|
||||||
|
|
||||||
self._push_update(member=member, typing=False)
|
self._push_update(member=member, typing=False)
|
||||||
|
|
||||||
def _push_update(self, member, typing):
|
def _push_update(self, member: RoomMember, typing: bool) -> None:
|
||||||
if self.hs.is_mine_id(member.user_id):
|
if self.hs.is_mine_id(member.user_id):
|
||||||
# Only send updates for changes to our own users.
|
# Only send updates for changes to our own users.
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
self._push_update_local(member=member, typing=typing)
|
self._push_update_local(member=member, typing=typing)
|
||||||
|
|
||||||
async def _recv_edu(self, origin, content):
|
async def _recv_edu(self, origin: str, content: JsonDict) -> None:
|
||||||
room_id = content["room_id"]
|
room_id = content["room_id"]
|
||||||
user_id = content["user_id"]
|
user_id = content["user_id"]
|
||||||
|
|
||||||
@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
|
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
|
||||||
self._push_update_local(member=member, typing=content["typing"])
|
self._push_update_local(member=member, typing=content["typing"])
|
||||||
|
|
||||||
def _push_update_local(self, member, typing):
|
def _push_update_local(self, member: RoomMember, typing: bool) -> None:
|
||||||
room_set = self._room_typing.setdefault(member.room_id, set())
|
room_set = self._room_typing.setdefault(member.room_id, set())
|
||||||
if typing:
|
if typing:
|
||||||
room_set.add(member.user_id)
|
room_set.add(member.user_id)
|
||||||
@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
|
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
|
||||||
last_id
|
last_id
|
||||||
)
|
) # type: Optional[Iterable[str]]
|
||||||
|
|
||||||
if changed_rooms is None:
|
if changed_rooms is None:
|
||||||
changed_rooms = self._room_serials
|
changed_rooms = self._room_serials
|
||||||
@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
|
|||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
||||||
):
|
) -> None:
|
||||||
# The writing process should never get updates from replication.
|
# The writing process should never get updates from replication.
|
||||||
raise Exception("Typing writer instance got typing info over replication")
|
raise Exception("Typing writer instance got typing info over replication")
|
||||||
|
|
||||||
|
|
||||||
class TypingNotificationEventSource:
|
class TypingNotificationEventSource:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
# We can't call get_typing_handler here because there's a cycle:
|
# We can't call get_typing_handler here because there's a cycle:
|
||||||
@ -427,7 +432,7 @@ class TypingNotificationEventSource:
|
|||||||
#
|
#
|
||||||
self.get_typing_handler = hs.get_typing_handler
|
self.get_typing_handler = hs.get_typing_handler
|
||||||
|
|
||||||
def _make_event_for(self, room_id):
|
def _make_event_for(self, room_id: str) -> JsonDict:
|
||||||
typing = self.get_typing_handler()._room_typing[room_id]
|
typing = self.get_typing_handler()._room_typing[room_id]
|
||||||
return {
|
return {
|
||||||
"type": "m.typing",
|
"type": "m.typing",
|
||||||
@ -462,7 +467,9 @@ class TypingNotificationEventSource:
|
|||||||
|
|
||||||
return (events, handler._latest_room_serial)
|
return (events, handler._latest_room_serial)
|
||||||
|
|
||||||
async def get_new_events(self, from_key, room_ids, **kwargs):
|
async def get_new_events(
|
||||||
|
self, from_key: int, room_ids: Iterable[str], **kwargs
|
||||||
|
) -> Tuple[List[JsonDict], int]:
|
||||||
with Measure(self.clock, "typing.get_new_events"):
|
with Measure(self.clock, "typing.get_new_events"):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
handler = self.get_typing_handler()
|
handler = self.get_typing_handler()
|
||||||
@ -478,5 +485,5 @@ class TypingNotificationEventSource:
|
|||||||
|
|
||||||
return (events, handler._latest_room_serial)
|
return (events, handler._latest_room_serial)
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self) -> int:
|
||||||
return self.get_typing_handler()._latest_room_serial
|
return self.get_typing_handler()._latest_room_serial
|
||||||
|
@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||||||
if self.pos is None:
|
if self.pos is None:
|
||||||
self.pos = await self.store.get_user_directory_stream_pos()
|
self.pos = await self.store.get_user_directory_stream_pos()
|
||||||
|
|
||||||
# If still None then the initial background update hasn't happened yet
|
|
||||||
if self.pos is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Loop round handling deltas until we're up to date
|
# Loop round handling deltas until we're up to date
|
||||||
while True:
|
while True:
|
||||||
with Measure(self.clock, "user_dir_delta"):
|
with Measure(self.clock, "user_dir_delta"):
|
||||||
@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||||||
|
|
||||||
if change: # The user joined
|
if change: # The user joined
|
||||||
event = await self.store.get_event(event_id, allow_none=True)
|
event = await self.store.get_event(event_id, allow_none=True)
|
||||||
|
# It isn't expected for this event to not exist, but we
|
||||||
|
# don't want the entire background process to break.
|
||||||
|
if event is None:
|
||||||
|
continue
|
||||||
|
|
||||||
profile = ProfileInfo(
|
profile = ProfileInfo(
|
||||||
avatar_url=event.content.get("avatar_url"),
|
avatar_url=event.content.get("avatar_url"),
|
||||||
display_name=event.content.get("displayname"),
|
display_name=event.content.get("displayname"),
|
||||||
|
@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
|||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
from synapse.types import Collection
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||||||
|
|
||||||
async def search_rooms(
|
async def search_rooms(
|
||||||
self,
|
self,
|
||||||
room_ids: List[str],
|
room_ids: Collection[str],
|
||||||
search_term: str,
|
search_term: str,
|
||||||
keys: List[str],
|
keys: List[str],
|
||||||
limit,
|
limit,
|
||||||
|
@ -15,11 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import Counter
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from typing_extensions import Counter
|
||||||
|
|
||||||
from twisted.internet.defer import DeferredLock
|
from twisted.internet.defer import DeferredLock
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore):
|
|||||||
return slice_list
|
return slice_list
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
|
async def get_earliest_token_for_stats(
|
||||||
|
self, stats_type: str, id: str
|
||||||
|
) -> Optional[int]:
|
||||||
"""
|
"""
|
||||||
Fetch the "earliest token". This is used by the room stats delta
|
Fetch the "earliest token". This is used by the room stats delta
|
||||||
processor to ignore deltas that have been processed between the
|
processor to ignore deltas that have been processed between the
|
||||||
@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def bulk_update_stats_delta(
|
async def bulk_update_stats_delta(
|
||||||
self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
|
self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Bulk update stats tables for a given stream_id and updates the stats
|
"""Bulk update stats tables for a given stream_id and updates the stats
|
||||||
incremental position.
|
incremental position.
|
||||||
@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore):
|
|||||||
|
|
||||||
async def get_changes_room_total_events_and_bytes(
|
async def get_changes_room_total_events_and_bytes(
|
||||||
self, min_pos: int, max_pos: int
|
self, min_pos: int, max_pos: int
|
||||||
) -> Dict[str, Dict[str, int]]:
|
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
|
||||||
"""Fetches the counts of events in the given range of stream IDs.
|
"""Fetches the counts of events in the given range of stream IDs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore):
|
|||||||
max_pos,
|
max_pos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
|
def get_changes_room_total_events_and_bytes_txn(
|
||||||
|
self, txn, low_pos: int, high_pos: int
|
||||||
|
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
|
||||||
"""Gets the total_events and total_event_bytes counts for rooms and
|
"""Gets the total_events and total_event_bytes counts for rooms and
|
||||||
senders, in a range of stream_orderings (including backfilled events).
|
senders, in a range of stream_orderings (including backfilled events).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn
|
txn
|
||||||
low_pos (int): Low stream ordering
|
low_pos: Low stream ordering
|
||||||
high_pos (int): High stream ordering
|
high_pos: High stream ordering
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
|
The room and user deltas for total_events/total_event_bytes in the
|
||||||
room and user deltas for total_events/total_event_bytes in the
|
|
||||||
format of `stats_id` -> fields
|
format of `stats_id` -> fields
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||||||
desc="get_user_in_directory",
|
desc="get_user_in_directory",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_user_directory_stream_pos(self, stream_id: str) -> None:
|
async def update_user_directory_stream_pos(self, stream_id: int) -> None:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="user_directory_stream_pos",
|
table="user_directory_stream_pos",
|
||||||
keyvalues={},
|
keyvalues={},
|
||||||
|
Loading…
Reference in New Issue
Block a user