mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-15 18:27:07 -05:00
Update black, and run auto formatting over the codebase (#9381)
- Update black version to the latest
- Run black auto formatting over the codebase
- Run autoformatting according to [`docs/code_style.md
`](80d6dc9783/docs/code_style.md
)
- Update `code_style.md` docs around installing black to use the correct version
This commit is contained in:
parent
5636e597c3
commit
0a00b7ff14
1
changelog.d/9381.misc
Normal file
1
changelog.d/9381.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Update the version of black used to 20.8b1.
|
@ -23,8 +23,7 @@ from twisted.web.http_headers import Headers
|
|||||||
|
|
||||||
|
|
||||||
class HttpClient:
|
class HttpClient:
|
||||||
""" Interface for talking json over http
|
"""Interface for talking json over http"""
|
||||||
"""
|
|
||||||
|
|
||||||
def put_json(self, url, data):
|
def put_json(self, url, data):
|
||||||
"""Sends the specifed json data using PUT
|
"""Sends the specifed json data using PUT
|
||||||
@ -87,8 +86,7 @@ class TwistedHttpClient(HttpClient):
|
|||||||
defer.returnValue(json.loads(body))
|
defer.returnValue(json.loads(body))
|
||||||
|
|
||||||
def _create_put_request(self, url, json_data, headers_dict={}):
|
def _create_put_request(self, url, json_data, headers_dict={}):
|
||||||
""" Wrapper of _create_request to issue a PUT request
|
"""Wrapper of _create_request to issue a PUT request"""
|
||||||
"""
|
|
||||||
|
|
||||||
if "Content-Type" not in headers_dict:
|
if "Content-Type" not in headers_dict:
|
||||||
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
|
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
|
||||||
@ -98,8 +96,7 @@ class TwistedHttpClient(HttpClient):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _create_get_request(self, url, headers_dict={}):
|
def _create_get_request(self, url, headers_dict={}):
|
||||||
""" Wrapper of _create_request to issue a GET request
|
"""Wrapper of _create_request to issue a GET request"""
|
||||||
"""
|
|
||||||
return self._create_request("GET", url, headers_dict=headers_dict)
|
return self._create_request("GET", url, headers_dict=headers_dict)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -127,8 +124,7 @@ class TwistedHttpClient(HttpClient):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_request(self, method, url, producer=None, headers_dict={}):
|
def _create_request(self, method, url, producer=None, headers_dict={}):
|
||||||
""" Creates and sends a request to the given url
|
"""Creates and sends a request to the given url"""
|
||||||
"""
|
|
||||||
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
|
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
|
||||||
|
|
||||||
retries_left = 5
|
retries_left = 5
|
||||||
@ -185,8 +181,7 @@ class _RawProducer:
|
|||||||
|
|
||||||
|
|
||||||
class _JsonProducer:
|
class _JsonProducer:
|
||||||
""" Used by the twisted http client to create the HTTP body from json
|
"""Used by the twisted http client to create the HTTP body from json"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, jsn):
|
def __init__(self, jsn):
|
||||||
self.data = jsn
|
self.data = jsn
|
||||||
|
@ -63,8 +63,7 @@ class CursesStdIO:
|
|||||||
self.redraw()
|
self.redraw()
|
||||||
|
|
||||||
def redraw(self):
|
def redraw(self):
|
||||||
""" method for redisplaying lines
|
"""method for redisplaying lines based on internal list of lines"""
|
||||||
based on internal list of lines """
|
|
||||||
|
|
||||||
self.stdscr.clear()
|
self.stdscr.clear()
|
||||||
self.paintStatus(self.statusText)
|
self.paintStatus(self.statusText)
|
||||||
|
@ -68,8 +68,7 @@ class InputOutput:
|
|||||||
self.server = server
|
self.server = server
|
||||||
|
|
||||||
def on_line(self, line):
|
def on_line(self, line):
|
||||||
""" This is where we process commands.
|
"""This is where we process commands."""
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
m = re.match(r"^join (\S+)$", line)
|
m = re.match(r"^join (\S+)$", line)
|
||||||
@ -148,8 +147,7 @@ class Room:
|
|||||||
self.have_got_metadata = False
|
self.have_got_metadata = False
|
||||||
|
|
||||||
def add_participant(self, participant):
|
def add_participant(self, participant):
|
||||||
""" Someone has joined the room
|
"""Someone has joined the room"""
|
||||||
"""
|
|
||||||
self.participants.add(participant)
|
self.participants.add(participant)
|
||||||
self.invited.discard(participant)
|
self.invited.discard(participant)
|
||||||
|
|
||||||
@ -160,8 +158,7 @@ class Room:
|
|||||||
self.oldest_server = server
|
self.oldest_server = server
|
||||||
|
|
||||||
def add_invited(self, invitee):
|
def add_invited(self, invitee):
|
||||||
""" Someone has been invited to the room
|
"""Someone has been invited to the room"""
|
||||||
"""
|
|
||||||
self.invited.add(invitee)
|
self.invited.add(invitee)
|
||||||
self.servers.add(origin_from_ucid(invitee))
|
self.servers.add(origin_from_ucid(invitee))
|
||||||
|
|
||||||
@ -181,8 +178,7 @@ class HomeServer(ReplicationHandler):
|
|||||||
self.output = output
|
self.output = output
|
||||||
|
|
||||||
def on_receive_pdu(self, pdu):
|
def on_receive_pdu(self, pdu):
|
||||||
""" We just received a PDU
|
"""We just received a PDU"""
|
||||||
"""
|
|
||||||
pdu_type = pdu.pdu_type
|
pdu_type = pdu.pdu_type
|
||||||
|
|
||||||
if pdu_type == "sy.room.message":
|
if pdu_type == "sy.room.message":
|
||||||
@ -199,23 +195,20 @@ class HomeServer(ReplicationHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _on_message(self, pdu):
|
def _on_message(self, pdu):
|
||||||
""" We received a message
|
"""We received a message"""
|
||||||
"""
|
|
||||||
self.output.print_line(
|
self.output.print_line(
|
||||||
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
|
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
|
||||||
)
|
)
|
||||||
|
|
||||||
def _on_join(self, context, joinee):
|
def _on_join(self, context, joinee):
|
||||||
""" Someone has joined a room, either a remote user or a local user
|
"""Someone has joined a room, either a remote user or a local user"""
|
||||||
"""
|
|
||||||
room = self._get_or_create_room(context)
|
room = self._get_or_create_room(context)
|
||||||
room.add_participant(joinee)
|
room.add_participant(joinee)
|
||||||
|
|
||||||
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
|
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
|
||||||
|
|
||||||
def _on_invite(self, origin, context, invitee):
|
def _on_invite(self, origin, context, invitee):
|
||||||
""" Someone has been invited
|
"""Someone has been invited"""
|
||||||
"""
|
|
||||||
room = self._get_or_create_room(context)
|
room = self._get_or_create_room(context)
|
||||||
room.add_invited(invitee)
|
room.add_invited(invitee)
|
||||||
|
|
||||||
@ -228,8 +221,7 @@ class HomeServer(ReplicationHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_message(self, room_name, sender, body):
|
def send_message(self, room_name, sender, body):
|
||||||
""" Send a message to a room!
|
"""Send a message to a room!"""
|
||||||
"""
|
|
||||||
destinations = yield self.get_servers_for_context(room_name)
|
destinations = yield self.get_servers_for_context(room_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -247,8 +239,7 @@ class HomeServer(ReplicationHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def join_room(self, room_name, sender, joinee):
|
def join_room(self, room_name, sender, joinee):
|
||||||
""" Join a room!
|
"""Join a room!"""
|
||||||
"""
|
|
||||||
self._on_join(room_name, joinee)
|
self._on_join(room_name, joinee)
|
||||||
|
|
||||||
destinations = yield self.get_servers_for_context(room_name)
|
destinations = yield self.get_servers_for_context(room_name)
|
||||||
@ -269,8 +260,7 @@ class HomeServer(ReplicationHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def invite_to_room(self, room_name, sender, invitee):
|
def invite_to_room(self, room_name, sender, invitee):
|
||||||
""" Invite someone to a room!
|
"""Invite someone to a room!"""
|
||||||
"""
|
|
||||||
self._on_invite(self.server_name, room_name, invitee)
|
self._on_invite(self.server_name, room_name, invitee)
|
||||||
|
|
||||||
destinations = yield self.get_servers_for_context(room_name)
|
destinations = yield self.get_servers_for_context(room_name)
|
||||||
|
@ -193,15 +193,12 @@ class TrivialXmppClient:
|
|||||||
time.sleep(7)
|
time.sleep(7)
|
||||||
print("SSRC spammer started")
|
print("SSRC spammer started")
|
||||||
while self.running:
|
while self.running:
|
||||||
ssrcMsg = (
|
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % {
|
||||||
"<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
|
|
||||||
% {
|
|
||||||
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
|
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
|
||||||
"nick": self.userId,
|
"nick": self.userId,
|
||||||
"assrc": self.ssrcs["audio"],
|
"assrc": self.ssrcs["audio"],
|
||||||
"vssrc": self.ssrcs["video"],
|
"vssrc": self.ssrcs["video"],
|
||||||
}
|
}
|
||||||
)
|
|
||||||
res = self.sendIq(ssrcMsg)
|
res = self.sendIq(ssrcMsg)
|
||||||
print("reply from ssrc announce: ", res)
|
print("reply from ssrc announce: ", res)
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
@ -8,16 +8,16 @@ errors in code.
|
|||||||
|
|
||||||
The necessary tools are detailed below.
|
The necessary tools are detailed below.
|
||||||
|
|
||||||
|
First install them with:
|
||||||
|
|
||||||
|
pip install -e ".[lint,mypy]"
|
||||||
|
|
||||||
- **black**
|
- **black**
|
||||||
|
|
||||||
The Synapse codebase uses [black](https://pypi.org/project/black/)
|
The Synapse codebase uses [black](https://pypi.org/project/black/)
|
||||||
as an opinionated code formatter, ensuring all comitted code is
|
as an opinionated code formatter, ensuring all comitted code is
|
||||||
properly formatted.
|
properly formatted.
|
||||||
|
|
||||||
First install `black` with:
|
|
||||||
|
|
||||||
pip install --upgrade black
|
|
||||||
|
|
||||||
Have `black` auto-format your code (it shouldn't change any
|
Have `black` auto-format your code (it shouldn't change any
|
||||||
functionality) with:
|
functionality) with:
|
||||||
|
|
||||||
@ -28,10 +28,6 @@ The necessary tools are detailed below.
|
|||||||
`flake8` is a code checking tool. We require code to pass `flake8`
|
`flake8` is a code checking tool. We require code to pass `flake8`
|
||||||
before being merged into the codebase.
|
before being merged into the codebase.
|
||||||
|
|
||||||
Install `flake8` with:
|
|
||||||
|
|
||||||
pip install --upgrade flake8 flake8-comprehensions
|
|
||||||
|
|
||||||
Check all application and test code with:
|
Check all application and test code with:
|
||||||
|
|
||||||
flake8 synapse tests
|
flake8 synapse tests
|
||||||
@ -41,10 +37,6 @@ The necessary tools are detailed below.
|
|||||||
`isort` ensures imports are nicely formatted, and can suggest and
|
`isort` ensures imports are nicely formatted, and can suggest and
|
||||||
auto-fix issues such as double-importing.
|
auto-fix issues such as double-importing.
|
||||||
|
|
||||||
Install `isort` with:
|
|
||||||
|
|
||||||
pip install --upgrade isort
|
|
||||||
|
|
||||||
Auto-fix imports with:
|
Auto-fix imports with:
|
||||||
|
|
||||||
isort -rc synapse tests
|
isort -rc synapse tests
|
||||||
|
@ -87,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
|||||||
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
|
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
|
||||||
|
|
||||||
signature = signature.copy_modified(
|
signature = signature.copy_modified(
|
||||||
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
|
arg_types=arg_types,
|
||||||
|
arg_names=arg_names,
|
||||||
|
arg_kinds=arg_kinds,
|
||||||
)
|
)
|
||||||
|
|
||||||
return signature
|
return signature
|
||||||
|
2
setup.py
2
setup.py
@ -97,7 +97,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
|
|||||||
# We pin black so that our tests don't start failing on new releases.
|
# We pin black so that our tests don't start failing on new releases.
|
||||||
CONDITIONAL_REQUIREMENTS["lint"] = [
|
CONDITIONAL_REQUIREMENTS["lint"] = [
|
||||||
"isort==5.7.0",
|
"isort==5.7.0",
|
||||||
"black==19.10b0",
|
"black==20.8b1",
|
||||||
"flake8-comprehensions",
|
"flake8-comprehensions",
|
||||||
"flake8",
|
"flake8",
|
||||||
]
|
]
|
||||||
|
@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]):
|
|||||||
def __reduce__(
|
def __reduce__(
|
||||||
self,
|
self,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
|
Type[SortedDict[_KT, _VT]],
|
||||||
|
Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
|
||||||
]: ...
|
]: ...
|
||||||
def __repr__(self) -> str: ...
|
def __repr__(self) -> str: ...
|
||||||
def _check(self) -> None: ...
|
def _check(self) -> None: ...
|
||||||
def islice(
|
def islice(
|
||||||
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
|
self,
|
||||||
|
start: Optional[int] = ...,
|
||||||
|
stop: Optional[int] = ...,
|
||||||
|
reverse=bool,
|
||||||
) -> Iterator[_KT]: ...
|
) -> Iterator[_KT]: ...
|
||||||
def bisect_left(self, value: _KT) -> int: ...
|
def bisect_left(self, value: _KT) -> int: ...
|
||||||
def bisect_right(self, value: _KT) -> int: ...
|
def bisect_right(self, value: _KT) -> int: ...
|
||||||
|
@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]):
|
|||||||
|
|
||||||
DEFAULT_LOAD_FACTOR: int = ...
|
DEFAULT_LOAD_FACTOR: int = ...
|
||||||
def __init__(
|
def __init__(
|
||||||
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
|
self,
|
||||||
|
iterable: Optional[Iterable[_T]] = ...,
|
||||||
|
key: Optional[_Key[_T]] = ...,
|
||||||
): ...
|
): ...
|
||||||
# NB: currently mypy does not honour return type, see mypy #3307
|
# NB: currently mypy does not honour return type, see mypy #3307
|
||||||
@overload
|
@overload
|
||||||
@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]):
|
|||||||
def __len__(self) -> int: ...
|
def __len__(self) -> int: ...
|
||||||
def reverse(self) -> None: ...
|
def reverse(self) -> None: ...
|
||||||
def islice(
|
def islice(
|
||||||
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
|
self,
|
||||||
|
start: Optional[int] = ...,
|
||||||
|
stop: Optional[int] = ...,
|
||||||
|
reverse=bool,
|
||||||
) -> Iterator[_T]: ...
|
) -> Iterator[_T]: ...
|
||||||
def _islice(
|
def _islice(
|
||||||
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
|
self,
|
||||||
|
min_pos: int,
|
||||||
|
min_idx: int,
|
||||||
|
max_pos: int,
|
||||||
|
max_idx: int,
|
||||||
|
reverse: bool,
|
||||||
) -> Iterator[_T]: ...
|
) -> Iterator[_T]: ...
|
||||||
def irange(
|
def irange(
|
||||||
self,
|
self,
|
||||||
|
@ -294,7 +294,10 @@ class Auth:
|
|||||||
return user_id, app_service
|
return user_id, app_service
|
||||||
|
|
||||||
async def get_user_by_access_token(
|
async def get_user_by_access_token(
|
||||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
self,
|
||||||
|
token: str,
|
||||||
|
rights: str = "access",
|
||||||
|
allow_expired: bool = False,
|
||||||
) -> TokenLookupResult:
|
) -> TokenLookupResult:
|
||||||
"""Validate access token and get user_id from it
|
"""Validate access token and get user_id from it
|
||||||
|
|
||||||
@ -500,7 +503,10 @@ class Auth:
|
|||||||
return await self.store.is_server_admin(user)
|
return await self.store.is_server_admin(user)
|
||||||
|
|
||||||
def compute_auth_events(
|
def compute_auth_events(
|
||||||
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
|
self,
|
||||||
|
event,
|
||||||
|
current_state_ids: StateMap[str],
|
||||||
|
for_verification: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Given an event and current state return the list of event IDs used
|
"""Given an event and current state return the list of event IDs used
|
||||||
to auth an event.
|
to auth an event.
|
||||||
|
@ -128,8 +128,7 @@ class UserTypes:
|
|||||||
|
|
||||||
|
|
||||||
class RelationTypes:
|
class RelationTypes:
|
||||||
"""The types of relations known to this server.
|
"""The types of relations known to this server."""
|
||||||
"""
|
|
||||||
|
|
||||||
ANNOTATION = "m.annotation"
|
ANNOTATION = "m.annotation"
|
||||||
REPLACE = "m.replace"
|
REPLACE = "m.replace"
|
||||||
|
@ -390,8 +390,7 @@ class InvalidCaptchaError(SynapseError):
|
|||||||
|
|
||||||
|
|
||||||
class LimitExceededError(SynapseError):
|
class LimitExceededError(SynapseError):
|
||||||
"""A client has sent too many requests and is being throttled.
|
"""A client has sent too many requests and is being throttled."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -408,8 +407,7 @@ class LimitExceededError(SynapseError):
|
|||||||
|
|
||||||
|
|
||||||
class RoomKeysVersionError(SynapseError):
|
class RoomKeysVersionError(SynapseError):
|
||||||
"""A client has tried to upload to a non-current version of the room_keys store
|
"""A client has tried to upload to a non-current version of the room_keys store"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, current_version: str):
|
def __init__(self, current_version: str):
|
||||||
"""
|
"""
|
||||||
@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError):
|
|||||||
|
|
||||||
def __init__(self, msg: str = "Homeserver does not support this room version"):
|
def __init__(self, msg: str = "Homeserver does not support this room version"):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
|
code=400,
|
||||||
|
msg=msg,
|
||||||
|
errcode=Codes.UNSUPPORTED_ROOM_VERSION,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -461,8 +461,7 @@ class IncompatibleRoomVersionError(SynapseError):
|
|||||||
|
|
||||||
|
|
||||||
class PasswordRefusedError(SynapseError):
|
class PasswordRefusedError(SynapseError):
|
||||||
"""A password has been refused, either during password reset/change or registration.
|
"""A password has been refused, either during password reset/change or registration."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -470,7 +469,9 @@ class PasswordRefusedError(SynapseError):
|
|||||||
errcode: str = Codes.WEAK_PASSWORD,
|
errcode: str = Codes.WEAK_PASSWORD,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
code=400, msg=msg, errcode=errcode,
|
code=400,
|
||||||
|
msg=msg,
|
||||||
|
errcode=errcode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,8 +56,7 @@ class UserPresenceState(
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls, user_id):
|
def default(cls, user_id):
|
||||||
"""Returns a default presence state.
|
"""Returns a default presence state."""
|
||||||
"""
|
|
||||||
return cls(
|
return cls(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
state=PresenceState.OFFLINE,
|
state=PresenceState.OFFLINE,
|
||||||
|
@ -313,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
|
|||||||
refresh_certificate(hs)
|
refresh_certificate(hs)
|
||||||
|
|
||||||
# Start the tracer
|
# Start the tracer
|
||||||
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
|
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
|
||||||
hs
|
|
||||||
)
|
|
||||||
|
|
||||||
# It is now safe to start your Synapse.
|
# It is now safe to start your Synapse.
|
||||||
hs.start_listening(listeners)
|
hs.start_listening(listeners)
|
||||||
@ -370,8 +368,7 @@ def setup_sentry(hs):
|
|||||||
|
|
||||||
|
|
||||||
def setup_sdnotify(hs):
|
def setup_sdnotify(hs):
|
||||||
"""Adds process state hooks to tell systemd what we are up to.
|
"""Adds process state hooks to tell systemd what we are up to."""
|
||||||
"""
|
|
||||||
|
|
||||||
# Tell systemd our state, if we're using it. This will silently fail if
|
# Tell systemd our state, if we're using it. This will silently fail if
|
||||||
# we're not using systemd.
|
# we're not using systemd.
|
||||||
@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
|
|||||||
|
|
||||||
|
|
||||||
class _LimitedHostnameResolver:
|
class _LimitedHostnameResolver:
|
||||||
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
|
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, resolver, max_dns_requests_in_flight):
|
def __init__(self, resolver, max_dns_requests_in_flight):
|
||||||
self._resolver = resolver
|
self._resolver = resolver
|
||||||
|
@ -421,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler):
|
|||||||
]
|
]
|
||||||
|
|
||||||
async def set_state(self, target_user, state, ignore_status_msg=False):
|
async def set_state(self, target_user, state, ignore_status_msg=False):
|
||||||
"""Set the presence state of the user.
|
"""Set the presence state of the user."""
|
||||||
"""
|
|
||||||
presence = state["presence"]
|
presence = state["presence"]
|
||||||
|
|
||||||
valid_presence = (
|
valid_presence = (
|
||||||
|
@ -166,7 +166,10 @@ class ApplicationService:
|
|||||||
|
|
||||||
@cached(num_args=1, cache_context=True)
|
@cached(num_args=1, cache_context=True)
|
||||||
async def matches_user_in_member_list(
|
async def matches_user_in_member_list(
|
||||||
self, room_id: str, store: "DataStore", cache_context: _CacheContext,
|
self,
|
||||||
|
room_id: str,
|
||||||
|
store: "DataStore",
|
||||||
|
cache_context: _CacheContext,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if this service is interested a room based upon it's membership
|
"""Check if this service is interested a room based upon it's membership
|
||||||
|
|
||||||
|
@ -227,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
await self.put_json(
|
await self.put_json(
|
||||||
uri=uri, json_body=body, args={"access_token": service.hs_token},
|
uri=uri,
|
||||||
|
json_body=body,
|
||||||
|
args={"access_token": service.hs_token},
|
||||||
)
|
)
|
||||||
sent_transactions_counter.labels(service.id).inc()
|
sent_transactions_counter.labels(service.id).inc()
|
||||||
sent_events_counter.labels(service.id).inc(len(events))
|
sent_events_counter.labels(service.id).inc(len(events))
|
||||||
|
@ -224,7 +224,9 @@ class Config:
|
|||||||
return self.read_templates([filename])[0]
|
return self.read_templates([filename])[0]
|
||||||
|
|
||||||
def read_templates(
|
def read_templates(
|
||||||
self, filenames: List[str], custom_template_directory: Optional[str] = None,
|
self,
|
||||||
|
filenames: List[str],
|
||||||
|
custom_template_directory: Optional[str] = None,
|
||||||
) -> List[jinja2.Template]:
|
) -> List[jinja2.Template]:
|
||||||
"""Load a list of template files from disk using the given variables.
|
"""Load a list of template files from disk using the given variables.
|
||||||
|
|
||||||
@ -264,7 +266,10 @@ class Config:
|
|||||||
|
|
||||||
# TODO: switch to synapse.util.templates.build_jinja_env
|
# TODO: switch to synapse.util.templates.build_jinja_env
|
||||||
loader = jinja2.FileSystemLoader(search_directories)
|
loader = jinja2.FileSystemLoader(search_directories)
|
||||||
env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
|
env = jinja2.Environment(
|
||||||
|
loader=loader,
|
||||||
|
autoescape=jinja2.select_autoescape(),
|
||||||
|
)
|
||||||
|
|
||||||
# Update the environment with our custom filters
|
# Update the environment with our custom filters
|
||||||
env.filters.update(
|
env.filters.update(
|
||||||
@ -825,8 +830,7 @@ class ShardedWorkerHandlingConfig:
|
|||||||
instances = attr.ib(type=List[str])
|
instances = attr.ib(type=List[str])
|
||||||
|
|
||||||
def should_handle(self, instance_name: str, key: str) -> bool:
|
def should_handle(self, instance_name: str, key: str) -> bool:
|
||||||
"""Whether this instance is responsible for handling the given key.
|
"""Whether this instance is responsible for handling the given key."""
|
||||||
"""
|
|
||||||
# If multiple instances are not defined we always return true
|
# If multiple instances are not defined we always return true
|
||||||
if not self.instances or len(self.instances) == 1:
|
if not self.instances or len(self.instances) == 1:
|
||||||
return True
|
return True
|
||||||
|
@ -18,8 +18,7 @@ from ._base import Config
|
|||||||
|
|
||||||
|
|
||||||
class AuthConfig(Config):
|
class AuthConfig(Config):
|
||||||
"""Password and login configuration
|
"""Password and login configuration"""
|
||||||
"""
|
|
||||||
|
|
||||||
section = "auth"
|
section = "auth"
|
||||||
|
|
||||||
|
@ -207,8 +207,7 @@ class DatabaseConfig(Config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_single_database(self) -> DatabaseConnectionConfig:
|
def get_single_database(self) -> DatabaseConnectionConfig:
|
||||||
"""Returns the database if there is only one, useful for e.g. tests
|
"""Returns the database if there is only one, useful for e.g. tests"""
|
||||||
"""
|
|
||||||
if not self.databases:
|
if not self.databases:
|
||||||
raise Exception("More than one database exists")
|
raise Exception("More than one database exists")
|
||||||
|
|
||||||
|
@ -289,7 +289,8 @@ class EmailConfig(Config):
|
|||||||
self.email_notif_template_html,
|
self.email_notif_template_html,
|
||||||
self.email_notif_template_text,
|
self.email_notif_template_text,
|
||||||
) = self.read_templates(
|
) = self.read_templates(
|
||||||
[notif_template_html, notif_template_text], template_dir,
|
[notif_template_html, notif_template_text],
|
||||||
|
template_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.email_notif_for_new_users = email_config.get(
|
self.email_notif_for_new_users = email_config.get(
|
||||||
@ -311,7 +312,8 @@ class EmailConfig(Config):
|
|||||||
self.account_validity_template_html,
|
self.account_validity_template_html,
|
||||||
self.account_validity_template_text,
|
self.account_validity_template_text,
|
||||||
) = self.read_templates(
|
) = self.read_templates(
|
||||||
[expiry_template_html, expiry_template_text], template_dir,
|
[expiry_template_html, expiry_template_text],
|
||||||
|
template_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
subjects_config = email_config.get("subjects", {})
|
subjects_config = email_config.get("subjects", {})
|
||||||
|
@ -162,7 +162,10 @@ class LoggingConfig(Config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging_group.add_argument(
|
logging_group.add_argument(
|
||||||
"-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
|
"-f",
|
||||||
|
"--log-file",
|
||||||
|
dest="log_file",
|
||||||
|
help=argparse.SUPPRESS,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_files(self, config, config_dir_path):
|
def generate_files(self, config, config_dir_path):
|
||||||
|
@ -355,9 +355,10 @@ def _parse_oidc_config_dict(
|
|||||||
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
||||||
ump_config.setdefault("config", {})
|
ump_config.setdefault("config", {})
|
||||||
|
|
||||||
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
|
(
|
||||||
ump_config, config_path + ("user_mapping_provider",)
|
user_mapping_provider_class,
|
||||||
)
|
user_mapping_provider_config,
|
||||||
|
) = load_module(ump_config, config_path + ("user_mapping_provider",))
|
||||||
|
|
||||||
# Ensure loaded user mapping module has defined all necessary methods
|
# Ensure loaded user mapping module has defined all necessary methods
|
||||||
required_methods = [
|
required_methods = [
|
||||||
@ -372,7 +373,11 @@ def _parse_oidc_config_dict(
|
|||||||
if missing_methods:
|
if missing_methods:
|
||||||
raise ConfigError(
|
raise ConfigError(
|
||||||
"Class %s is missing required "
|
"Class %s is missing required "
|
||||||
"methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
|
"methods: %s"
|
||||||
|
% (
|
||||||
|
user_mapping_provider_class,
|
||||||
|
", ".join(missing_methods),
|
||||||
|
),
|
||||||
config_path + ("user_mapping_provider", "module"),
|
config_path + ("user_mapping_provider", "module"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,7 +52,12 @@ def _6to4(network: IPNetwork) -> IPNetwork:
|
|||||||
hex_network = hex(network.first)[2:]
|
hex_network = hex(network.first)[2:]
|
||||||
hex_network = ("0" * (8 - len(hex_network))) + hex_network
|
hex_network = ("0" * (8 - len(hex_network))) + hex_network
|
||||||
return IPNetwork(
|
return IPNetwork(
|
||||||
"2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,)
|
"2002:%s:%s::/%d"
|
||||||
|
% (
|
||||||
|
hex_network[:4],
|
||||||
|
hex_network[4:],
|
||||||
|
16 + network.prefixlen,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -254,7 +259,8 @@ class ServerConfig(Config):
|
|||||||
# Whether to require sharing a room with a user to retrieve their
|
# Whether to require sharing a room with a user to retrieve their
|
||||||
# profile data
|
# profile data
|
||||||
self.limit_profile_requests_to_users_who_share_rooms = config.get(
|
self.limit_profile_requests_to_users_who_share_rooms = config.get(
|
||||||
"limit_profile_requests_to_users_who_share_rooms", False,
|
"limit_profile_requests_to_users_who_share_rooms",
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "restrict_public_rooms_to_local_users" in config and (
|
if "restrict_public_rooms_to_local_users" in config and (
|
||||||
@ -614,7 +620,9 @@ class ServerConfig(Config):
|
|||||||
if manhole:
|
if manhole:
|
||||||
self.listeners.append(
|
self.listeners.append(
|
||||||
ListenerConfig(
|
ListenerConfig(
|
||||||
port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
|
port=manhole,
|
||||||
|
bind_addresses=["127.0.0.1"],
|
||||||
|
type="manhole",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -650,7 +658,8 @@ class ServerConfig(Config):
|
|||||||
# and letting the client know which email address is bound to an account and
|
# and letting the client know which email address is bound to an account and
|
||||||
# which one isn't.
|
# which one isn't.
|
||||||
self.request_token_inhibit_3pid_errors = config.get(
|
self.request_token_inhibit_3pid_errors = config.get(
|
||||||
"request_token_inhibit_3pid_errors", False,
|
"request_token_inhibit_3pid_errors",
|
||||||
|
False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# List of users trialing the new experimental default push rules. This setting is
|
# List of users trialing the new experimental default push rules. This setting is
|
||||||
|
@ -35,8 +35,7 @@ class SsoAttributeRequirement:
|
|||||||
|
|
||||||
|
|
||||||
class SSOConfig(Config):
|
class SSOConfig(Config):
|
||||||
"""SSO Configuration
|
"""SSO Configuration"""
|
||||||
"""
|
|
||||||
|
|
||||||
section = "sso"
|
section = "sso"
|
||||||
|
|
||||||
|
@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
|||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class InstanceLocationConfig:
|
class InstanceLocationConfig:
|
||||||
"""The host and port to talk to an instance via HTTP replication.
|
"""The host and port to talk to an instance via HTTP replication."""
|
||||||
"""
|
|
||||||
|
|
||||||
host = attr.ib(type=str)
|
host = attr.ib(type=str)
|
||||||
port = attr.ib(type=int)
|
port = attr.ib(type=int)
|
||||||
@ -54,13 +53,19 @@ class WriterLocations:
|
|||||||
)
|
)
|
||||||
typing = attr.ib(default="master", type=str)
|
typing = attr.ib(default="master", type=str)
|
||||||
to_device = attr.ib(
|
to_device = attr.ib(
|
||||||
default=["master"], type=List[str], converter=_instance_to_list_converter,
|
default=["master"],
|
||||||
|
type=List[str],
|
||||||
|
converter=_instance_to_list_converter,
|
||||||
)
|
)
|
||||||
account_data = attr.ib(
|
account_data = attr.ib(
|
||||||
default=["master"], type=List[str], converter=_instance_to_list_converter,
|
default=["master"],
|
||||||
|
type=List[str],
|
||||||
|
converter=_instance_to_list_converter,
|
||||||
)
|
)
|
||||||
receipts = attr.ib(
|
receipts = attr.ib(
|
||||||
default=["master"], type=List[str], converter=_instance_to_list_converter,
|
default=["master"],
|
||||||
|
type=List[str],
|
||||||
|
converter=_instance_to_list_converter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -107,7 +112,9 @@ class WorkerConfig(Config):
|
|||||||
if manhole:
|
if manhole:
|
||||||
self.worker_listeners.append(
|
self.worker_listeners.append(
|
||||||
ListenerConfig(
|
ListenerConfig(
|
||||||
port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
|
port=manhole,
|
||||||
|
bind_addresses=["127.0.0.1"],
|
||||||
|
type="manhole",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def check_redaction(
|
def check_redaction(
|
||||||
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
|
room_version_obj: RoomVersion,
|
||||||
|
event: EventBase,
|
||||||
|
auth_events: StateMap[EventBase],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check whether the event sender is allowed to redact the target event.
|
"""Check whether the event sender is allowed to redact the target event.
|
||||||
|
|
||||||
@ -459,7 +461,9 @@ def check_redaction(
|
|||||||
|
|
||||||
|
|
||||||
def _check_power_levels(
|
def _check_power_levels(
|
||||||
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
|
room_version_obj: RoomVersion,
|
||||||
|
event: EventBase,
|
||||||
|
auth_events: StateMap[EventBase],
|
||||||
) -> None:
|
) -> None:
|
||||||
user_list = event.content.get("users", {})
|
user_list = event.content.get("users", {})
|
||||||
# Validate users
|
# Validate users
|
||||||
|
@ -98,7 +98,9 @@ class EventBuilder:
|
|||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
async def build(
|
async def build(
|
||||||
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
|
self,
|
||||||
|
prev_event_ids: List[str],
|
||||||
|
auth_event_ids: Optional[List[str]],
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Transform into a fully signed and hashed event
|
"""Transform into a fully signed and hashed event
|
||||||
|
|
||||||
|
@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
|
|||||||
|
|
||||||
|
|
||||||
def _decode_state_dict(input):
|
def _decode_state_dict(input):
|
||||||
"""Decodes a state dict encoded using `_encode_state_dict` above
|
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
||||||
"""
|
|
||||||
if input is None:
|
if input is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -40,7 +40,8 @@ class ThirdPartyEventRules:
|
|||||||
|
|
||||||
if module is not None:
|
if module is not None:
|
||||||
self.third_party_rules = module(
|
self.third_party_rules = module(
|
||||||
config=config, module_api=hs.get_module_api(),
|
config=config,
|
||||||
|
module_api=hs.get_module_api(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_event_allowed(
|
async def check_event_allowed(
|
||||||
|
@ -750,7 +750,11 @@ class FederationClient(FederationBase):
|
|||||||
return resp[1]
|
return resp[1]
|
||||||
|
|
||||||
async def send_invite(
|
async def send_invite(
|
||||||
self, destination: str, room_id: str, event_id: str, pdu: EventBase,
|
self,
|
||||||
|
destination: str,
|
||||||
|
room_id: str,
|
||||||
|
event_id: str,
|
||||||
|
pdu: EventBase,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
room_version = await self.store.get_room_version(room_id)
|
room_version = await self.store.get_room_version(room_id)
|
||||||
|
|
||||||
|
@ -85,7 +85,8 @@ received_queries_counter = Counter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
pdu_process_time = Histogram(
|
pdu_process_time = Histogram(
|
||||||
"synapse_federation_server_pdu_process_time", "Time taken to process an event",
|
"synapse_federation_server_pdu_process_time",
|
||||||
|
"Time taken to process an event",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -373,8 +374,7 @@ class FederationServer(FederationBase):
|
|||||||
return pdu_results
|
return pdu_results
|
||||||
|
|
||||||
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
|
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
|
||||||
"""Process the EDUs in a received transaction.
|
"""Process the EDUs in a received transaction."""
|
||||||
"""
|
|
||||||
|
|
||||||
async def _process_edu(edu_dict):
|
async def _process_edu(edu_dict):
|
||||||
received_edus_counter.inc()
|
received_edus_counter.inc()
|
||||||
@ -437,7 +437,10 @@ class FederationServer(FederationBase):
|
|||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
resp = await self._state_ids_resp_cache.wrap(
|
resp = await self._state_ids_resp_cache.wrap(
|
||||||
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
|
(room_id, event_id),
|
||||||
|
self._on_state_ids_request_compute,
|
||||||
|
room_id,
|
||||||
|
event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, resp
|
return 200, resp
|
||||||
@ -906,13 +909,11 @@ class FederationHandlerRegistry:
|
|||||||
self.query_handlers[query_type] = handler
|
self.query_handlers[query_type] = handler
|
||||||
|
|
||||||
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
def register_instance_for_edu(self, edu_type: str, instance_name: str):
|
||||||
"""Register that the EDU handler is on a different instance than master.
|
"""Register that the EDU handler is on a different instance than master."""
|
||||||
"""
|
|
||||||
self._edu_type_to_instance[edu_type] = [instance_name]
|
self._edu_type_to_instance[edu_type] = [instance_name]
|
||||||
|
|
||||||
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
|
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
|
||||||
"""Register that the EDU handler is on multiple instances.
|
"""Register that the EDU handler is on multiple instances."""
|
||||||
"""
|
|
||||||
self._edu_type_to_instance[edu_type] = instance_names
|
self._edu_type_to_instance[edu_type] = instance_names
|
||||||
|
|
||||||
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
async def on_edu(self, edu_type: str, origin: str, content: dict):
|
||||||
|
@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class TransactionActions:
|
class TransactionActions:
|
||||||
""" Defines persistence actions that relate to handling Transactions.
|
"""Defines persistence actions that relate to handling Transactions."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, datastore):
|
def __init__(self, datastore):
|
||||||
self.store = datastore
|
self.store = datastore
|
||||||
@ -57,8 +56,7 @@ class TransactionActions:
|
|||||||
async def set_response(
|
async def set_response(
|
||||||
self, origin: str, transaction: Transaction, code: int, response: JsonDict
|
self, origin: str, transaction: Transaction, code: int, response: JsonDict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Persist how we responded to a transaction.
|
"""Persist how we responded to a transaction."""
|
||||||
"""
|
|
||||||
transaction_id = transaction.transaction_id # type: ignore
|
transaction_id = transaction.transaction_id # type: ignore
|
||||||
if not transaction_id:
|
if not transaction_id:
|
||||||
raise RuntimeError("Cannot persist a transaction with no transaction_id")
|
raise RuntimeError("Cannot persist a transaction with no transaction_id")
|
||||||
|
@ -468,8 +468,7 @@ class KeyedEduRow(
|
|||||||
|
|
||||||
|
|
||||||
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
|
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
|
||||||
"""Streams EDUs that don't have keys. See KeyedEduRow
|
"""Streams EDUs that don't have keys. See KeyedEduRow"""
|
||||||
"""
|
|
||||||
|
|
||||||
TypeId = "e"
|
TypeId = "e"
|
||||||
|
|
||||||
@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows):
|
|||||||
# them into the appropriate collection and then send them off.
|
# them into the appropriate collection and then send them off.
|
||||||
|
|
||||||
buff = ParsedFederationStreamData(
|
buff = ParsedFederationStreamData(
|
||||||
presence=[], presence_destinations=[], keyed_edus={}, edus={},
|
presence=[],
|
||||||
|
presence_destinations=[],
|
||||||
|
keyed_edus={},
|
||||||
|
edus={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse the rows in the stream and add to the buffer
|
# Parse the rows in the stream and add to the buffer
|
||||||
|
@ -328,7 +328,9 @@ class FederationSender:
|
|||||||
# to allow us to perform catch-up later on if the remote is unreachable
|
# to allow us to perform catch-up later on if the remote is unreachable
|
||||||
# for a while.
|
# for a while.
|
||||||
await self.store.store_destination_rooms_entries(
|
await self.store.store_destination_rooms_entries(
|
||||||
destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
|
destinations,
|
||||||
|
pdu.room_id,
|
||||||
|
pdu.internal_metadata.stream_ordering,
|
||||||
)
|
)
|
||||||
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
@ -616,8 +618,8 @@ class FederationSender:
|
|||||||
last_processed = None # type: Optional[str]
|
last_processed = None # type: Optional[str]
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
|
destinations_to_wake = (
|
||||||
last_processed
|
await self.store.get_catch_up_outstanding_destinations(last_processed)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not destinations_to_wake:
|
if not destinations_to_wake:
|
||||||
|
@ -85,7 +85,8 @@ class PerDestinationQueue:
|
|||||||
# processing. We have a guard in `attempt_new_transaction` that
|
# processing. We have a guard in `attempt_new_transaction` that
|
||||||
# ensure we don't start sending stuff.
|
# ensure we don't start sending stuff.
|
||||||
logger.error(
|
logger.error(
|
||||||
"Create a per destination queue for %s on wrong worker", destination,
|
"Create a per destination queue for %s on wrong worker",
|
||||||
|
destination,
|
||||||
)
|
)
|
||||||
self._should_send_on_this_instance = False
|
self._should_send_on_this_instance = False
|
||||||
|
|
||||||
@ -440,9 +441,11 @@ class PerDestinationQueue:
|
|||||||
|
|
||||||
if first_catch_up_check:
|
if first_catch_up_check:
|
||||||
# first catchup so get last_successful_stream_ordering from database
|
# first catchup so get last_successful_stream_ordering from database
|
||||||
self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
|
self._last_successful_stream_ordering = (
|
||||||
|
await self._store.get_destination_last_successful_stream_ordering(
|
||||||
self._destination
|
self._destination
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self._last_successful_stream_ordering is None:
|
if self._last_successful_stream_ordering is None:
|
||||||
# if it's still None, then this means we don't have the information
|
# if it's still None, then this means we don't have the information
|
||||||
@ -457,7 +460,8 @@ class PerDestinationQueue:
|
|||||||
# get at most 50 catchup room/PDUs
|
# get at most 50 catchup room/PDUs
|
||||||
while True:
|
while True:
|
||||||
event_ids = await self._store.get_catch_up_room_event_ids(
|
event_ids = await self._store.get_catch_up_room_event_ids(
|
||||||
self._destination, self._last_successful_stream_ordering,
|
self._destination,
|
||||||
|
self._last_successful_stream_ordering,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
|
@ -65,7 +65,10 @@ class TransactionManager:
|
|||||||
|
|
||||||
@measure_func("_send_new_transaction")
|
@measure_func("_send_new_transaction")
|
||||||
async def send_new_transaction(
|
async def send_new_transaction(
|
||||||
self, destination: str, pdus: List[EventBase], edus: List[Edu],
|
self,
|
||||||
|
destination: str,
|
||||||
|
pdus: List[EventBase],
|
||||||
|
edus: List[Edu],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -551,8 +551,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_profile(self, destination, group_id, requester_user_id):
|
def get_group_profile(self, destination, group_id, requester_user_id):
|
||||||
"""Get a group profile
|
"""Get a group profile"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/profile", group_id)
|
path = _create_v1_path("/groups/%s/profile", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -584,8 +583,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_summary(self, destination, group_id, requester_user_id):
|
def get_group_summary(self, destination, group_id, requester_user_id):
|
||||||
"""Get a group summary
|
"""Get a group summary"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/summary", group_id)
|
path = _create_v1_path("/groups/%s/summary", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -597,8 +595,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
||||||
"""Get all rooms in a group
|
"""Get all rooms in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/rooms", group_id)
|
path = _create_v1_path("/groups/%s/rooms", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -611,8 +608,7 @@ class TransportLayerClient:
|
|||||||
def add_room_to_group(
|
def add_room_to_group(
|
||||||
self, destination, group_id, requester_user_id, room_id, content
|
self, destination, group_id, requester_user_id, room_id, content
|
||||||
):
|
):
|
||||||
"""Add a room to a group
|
"""Add a room to a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
|
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -626,8 +622,7 @@ class TransportLayerClient:
|
|||||||
def update_room_in_group(
|
def update_room_in_group(
|
||||||
self, destination, group_id, requester_user_id, room_id, config_key, content
|
self, destination, group_id, requester_user_id, room_id, config_key, content
|
||||||
):
|
):
|
||||||
"""Update room in group
|
"""Update room in group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path(
|
path = _create_v1_path(
|
||||||
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
|
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
|
||||||
)
|
)
|
||||||
@ -641,8 +636,7 @@ class TransportLayerClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||||
"""Remove a room from a group
|
"""Remove a room from a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
|
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
|
||||||
|
|
||||||
return self.client.delete_json(
|
return self.client.delete_json(
|
||||||
@ -654,8 +648,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_users_in_group(self, destination, group_id, requester_user_id):
|
def get_users_in_group(self, destination, group_id, requester_user_id):
|
||||||
"""Get users in a group
|
"""Get users in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/users", group_id)
|
path = _create_v1_path("/groups/%s/users", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -667,8 +660,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
||||||
"""Get users that have been invited to a group
|
"""Get users that have been invited to a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/invited_users", group_id)
|
path = _create_v1_path("/groups/%s/invited_users", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -680,8 +672,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def accept_group_invite(self, destination, group_id, user_id, content):
|
def accept_group_invite(self, destination, group_id, user_id, content):
|
||||||
"""Accept a group invite
|
"""Accept a group invite"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
|
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -690,8 +681,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def join_group(self, destination, group_id, user_id, content):
|
def join_group(self, destination, group_id, user_id, content):
|
||||||
"""Attempts to join a group
|
"""Attempts to join a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
|
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -702,8 +692,7 @@ class TransportLayerClient:
|
|||||||
def invite_to_group(
|
def invite_to_group(
|
||||||
self, destination, group_id, user_id, requester_user_id, content
|
self, destination, group_id, user_id, requester_user_id, content
|
||||||
):
|
):
|
||||||
"""Invite a user to a group
|
"""Invite a user to a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
|
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -730,8 +719,7 @@ class TransportLayerClient:
|
|||||||
def remove_user_from_group(
|
def remove_user_from_group(
|
||||||
self, destination, group_id, requester_user_id, user_id, content
|
self, destination, group_id, requester_user_id, user_id, content
|
||||||
):
|
):
|
||||||
"""Remove a user from a group
|
"""Remove a user from a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
|
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -772,8 +760,7 @@ class TransportLayerClient:
|
|||||||
def update_group_summary_room(
|
def update_group_summary_room(
|
||||||
self, destination, group_id, user_id, room_id, category_id, content
|
self, destination, group_id, user_id, room_id, category_id, content
|
||||||
):
|
):
|
||||||
"""Update a room entry in a group summary
|
"""Update a room entry in a group summary"""
|
||||||
"""
|
|
||||||
if category_id:
|
if category_id:
|
||||||
path = _create_v1_path(
|
path = _create_v1_path(
|
||||||
"/groups/%s/summary/categories/%s/rooms/%s",
|
"/groups/%s/summary/categories/%s/rooms/%s",
|
||||||
@ -796,8 +783,7 @@ class TransportLayerClient:
|
|||||||
def delete_group_summary_room(
|
def delete_group_summary_room(
|
||||||
self, destination, group_id, user_id, room_id, category_id
|
self, destination, group_id, user_id, room_id, category_id
|
||||||
):
|
):
|
||||||
"""Delete a room entry in a group summary
|
"""Delete a room entry in a group summary"""
|
||||||
"""
|
|
||||||
if category_id:
|
if category_id:
|
||||||
path = _create_v1_path(
|
path = _create_v1_path(
|
||||||
"/groups/%s/summary/categories/%s/rooms/%s",
|
"/groups/%s/summary/categories/%s/rooms/%s",
|
||||||
@ -817,8 +803,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_categories(self, destination, group_id, requester_user_id):
|
def get_group_categories(self, destination, group_id, requester_user_id):
|
||||||
"""Get all categories in a group
|
"""Get all categories in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/categories", group_id)
|
path = _create_v1_path("/groups/%s/categories", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -830,8 +815,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
||||||
"""Get category info in a group
|
"""Get category info in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -845,8 +829,7 @@ class TransportLayerClient:
|
|||||||
def update_group_category(
|
def update_group_category(
|
||||||
self, destination, group_id, requester_user_id, category_id, content
|
self, destination, group_id, requester_user_id, category_id, content
|
||||||
):
|
):
|
||||||
"""Update a category in a group
|
"""Update a category in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -861,8 +844,7 @@ class TransportLayerClient:
|
|||||||
def delete_group_category(
|
def delete_group_category(
|
||||||
self, destination, group_id, requester_user_id, category_id
|
self, destination, group_id, requester_user_id, category_id
|
||||||
):
|
):
|
||||||
"""Delete a category in a group
|
"""Delete a category in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
|
||||||
|
|
||||||
return self.client.delete_json(
|
return self.client.delete_json(
|
||||||
@ -874,8 +856,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_roles(self, destination, group_id, requester_user_id):
|
def get_group_roles(self, destination, group_id, requester_user_id):
|
||||||
"""Get all roles in a group
|
"""Get all roles in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/roles", group_id)
|
path = _create_v1_path("/groups/%s/roles", group_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -887,8 +868,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||||
"""Get a roles info
|
"""Get a roles info"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
||||||
|
|
||||||
return self.client.get_json(
|
return self.client.get_json(
|
||||||
@ -902,8 +882,7 @@ class TransportLayerClient:
|
|||||||
def update_group_role(
|
def update_group_role(
|
||||||
self, destination, group_id, requester_user_id, role_id, content
|
self, destination, group_id, requester_user_id, role_id, content
|
||||||
):
|
):
|
||||||
"""Update a role in a group
|
"""Update a role in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
||||||
|
|
||||||
return self.client.post_json(
|
return self.client.post_json(
|
||||||
@ -916,8 +895,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||||
"""Delete a role in a group
|
"""Delete a role in a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
|
||||||
|
|
||||||
return self.client.delete_json(
|
return self.client.delete_json(
|
||||||
@ -931,8 +909,7 @@ class TransportLayerClient:
|
|||||||
def update_group_summary_user(
|
def update_group_summary_user(
|
||||||
self, destination, group_id, requester_user_id, user_id, role_id, content
|
self, destination, group_id, requester_user_id, user_id, role_id, content
|
||||||
):
|
):
|
||||||
"""Update a users entry in a group
|
"""Update a users entry in a group"""
|
||||||
"""
|
|
||||||
if role_id:
|
if role_id:
|
||||||
path = _create_v1_path(
|
path = _create_v1_path(
|
||||||
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
|
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
|
||||||
@ -950,8 +927,7 @@ class TransportLayerClient:
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def set_group_join_policy(self, destination, group_id, requester_user_id, content):
|
def set_group_join_policy(self, destination, group_id, requester_user_id, content):
|
||||||
"""Sets the join policy for a group
|
"""Sets the join policy for a group"""
|
||||||
"""
|
|
||||||
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
|
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
|
||||||
|
|
||||||
return self.client.put_json(
|
return self.client.put_json(
|
||||||
@ -966,8 +942,7 @@ class TransportLayerClient:
|
|||||||
def delete_group_summary_user(
|
def delete_group_summary_user(
|
||||||
self, destination, group_id, requester_user_id, user_id, role_id
|
self, destination, group_id, requester_user_id, user_id, role_id
|
||||||
):
|
):
|
||||||
"""Delete a users entry in a group
|
"""Delete a users entry in a group"""
|
||||||
"""
|
|
||||||
if role_id:
|
if role_id:
|
||||||
path = _create_v1_path(
|
path = _create_v1_path(
|
||||||
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
|
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
|
||||||
@ -983,8 +958,7 @@ class TransportLayerClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def bulk_get_publicised_groups(self, destination, user_ids):
|
def bulk_get_publicised_groups(self, destination, user_ids):
|
||||||
"""Get the groups a list of users are publicising
|
"""Get the groups a list of users are publicising"""
|
||||||
"""
|
|
||||||
|
|
||||||
path = _create_v1_path("/get_groups_publicised")
|
path = _create_v1_path("/get_groups_publicised")
|
||||||
|
|
||||||
|
@ -364,7 +364,10 @@ class BaseFederationServlet:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
server.register_paths(
|
server.register_paths(
|
||||||
method, (pattern,), self._wrap(code), self.__class__.__name__,
|
method,
|
||||||
|
(pattern,),
|
||||||
|
self._wrap(code),
|
||||||
|
self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -855,8 +858,7 @@ class FederationVersionServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsProfileServlet(BaseFederationServlet):
|
class FederationGroupsProfileServlet(BaseFederationServlet):
|
||||||
"""Get/set the basic profile of a group on behalf of a user
|
"""Get/set the basic profile of a group on behalf of a user"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/profile"
|
PATH = "/groups/(?P<group_id>[^/]*)/profile"
|
||||||
|
|
||||||
@ -895,8 +897,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRoomsServlet(BaseFederationServlet):
|
class FederationGroupsRoomsServlet(BaseFederationServlet):
|
||||||
"""Get the rooms in a group on behalf of a user
|
"""Get the rooms in a group on behalf of a user"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
|
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
|
||||||
|
|
||||||
@ -911,8 +912,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||||
"""Add/remove room from group
|
"""Add/remove room from group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
||||||
|
|
||||||
@ -940,8 +940,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
|
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
|
||||||
"""Update room config in group
|
"""Update room config in group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = (
|
PATH = (
|
||||||
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
|
||||||
@ -961,8 +960,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsUsersServlet(BaseFederationServlet):
|
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||||
"""Get the users in a group on behalf of a user
|
"""Get the users in a group on behalf of a user"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/users"
|
PATH = "/groups/(?P<group_id>[^/]*)/users"
|
||||||
|
|
||||||
@ -977,8 +975,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
||||||
"""Get the users that have been invited to a group
|
"""Get the users that have been invited to a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
|
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
|
||||||
|
|
||||||
@ -995,8 +992,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsInviteServlet(BaseFederationServlet):
|
class FederationGroupsInviteServlet(BaseFederationServlet):
|
||||||
"""Ask a group server to invite someone to the group
|
"""Ask a group server to invite someone to the group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
|
||||||
|
|
||||||
@ -1013,8 +1009,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
||||||
"""Accept an invitation from the group server
|
"""Accept an invitation from the group server"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
|
||||||
|
|
||||||
@ -1028,8 +1023,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsJoinServlet(BaseFederationServlet):
|
class FederationGroupsJoinServlet(BaseFederationServlet):
|
||||||
"""Attempt to join a group
|
"""Attempt to join a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
|
||||||
|
|
||||||
@ -1043,8 +1037,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
||||||
"""Leave or kick a user from the group
|
"""Leave or kick a user from the group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
|
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
|
||||||
|
|
||||||
@ -1061,8 +1054,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
||||||
"""A group server has invited a local user
|
"""A group server has invited a local user"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
|
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
|
||||||
|
|
||||||
@ -1076,8 +1068,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
||||||
"""A group server has removed a local user
|
"""A group server has removed a local user"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
|
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
|
||||||
|
|
||||||
@ -1093,8 +1084,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
|
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
|
||||||
"""A group or user's server renews their attestation
|
"""A group or user's server renews their attestation"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
|
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
|
||||||
|
|
||||||
@ -1156,8 +1146,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
||||||
"""Get all categories for a group
|
"""Get all categories for a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
|
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
|
||||||
|
|
||||||
@ -1172,8 +1161,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsCategoryServlet(BaseFederationServlet):
|
class FederationGroupsCategoryServlet(BaseFederationServlet):
|
||||||
"""Add/remove/get a category in a group
|
"""Add/remove/get a category in a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
|
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
|
||||||
|
|
||||||
@ -1218,8 +1206,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRolesServlet(BaseFederationServlet):
|
class FederationGroupsRolesServlet(BaseFederationServlet):
|
||||||
"""Get roles in a group
|
"""Get roles in a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
|
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
|
||||||
|
|
||||||
@ -1234,8 +1221,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsRoleServlet(BaseFederationServlet):
|
class FederationGroupsRoleServlet(BaseFederationServlet):
|
||||||
"""Add/remove/get a role in a group
|
"""Add/remove/get a role in a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
|
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
|
||||||
|
|
||||||
@ -1325,8 +1311,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
||||||
"""Get roles in a group
|
"""Get roles in a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/get_groups_publicised"
|
PATH = "/get_groups_publicised"
|
||||||
|
|
||||||
@ -1339,8 +1324,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
|||||||
|
|
||||||
|
|
||||||
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
|
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
|
||||||
"""Sets whether a group is joinable without an invite or knock
|
"""Sets whether a group is joinable without an invite or knock"""
|
||||||
"""
|
|
||||||
|
|
||||||
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
|
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
|
||||||
|
|
||||||
|
@ -61,8 +61,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
|||||||
|
|
||||||
|
|
||||||
class GroupAttestationSigning:
|
class GroupAttestationSigning:
|
||||||
"""Creates and verifies group attestations.
|
"""Creates and verifies group attestations."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
@ -125,8 +124,7 @@ class GroupAttestationSigning:
|
|||||||
|
|
||||||
|
|
||||||
class GroupAttestionRenewer:
|
class GroupAttestionRenewer:
|
||||||
"""Responsible for sending and receiving attestation updates.
|
"""Responsible for sending and receiving attestation updates."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -142,8 +140,7 @@ class GroupAttestionRenewer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def on_renew_attestation(self, group_id, user_id, content):
|
async def on_renew_attestation(self, group_id, user_id, content):
|
||||||
"""When a remote updates an attestation
|
"""When a remote updates an attestation"""
|
||||||
"""
|
|
||||||
attestation = content["attestation"]
|
attestation = content["attestation"]
|
||||||
|
|
||||||
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
||||||
@ -161,8 +158,7 @@ class GroupAttestionRenewer:
|
|||||||
return run_as_background_process("renew_attestations", self._renew_attestations)
|
return run_as_background_process("renew_attestations", self._renew_attestations)
|
||||||
|
|
||||||
async def _renew_attestations(self):
|
async def _renew_attestations(self):
|
||||||
"""Called periodically to check if we need to update any of our attestations
|
"""Called periodically to check if we need to update any of our attestations"""
|
||||||
"""
|
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
@ -165,16 +165,14 @@ class GroupsServerWorkerHandler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_group_categories(self, group_id, requester_user_id):
|
async def get_group_categories(self, group_id, requester_user_id):
|
||||||
"""Get all categories in a group (as seen by user)
|
"""Get all categories in a group (as seen by user)"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
categories = await self.store.get_group_categories(group_id=group_id)
|
categories = await self.store.get_group_categories(group_id=group_id)
|
||||||
return {"categories": categories}
|
return {"categories": categories}
|
||||||
|
|
||||||
async def get_group_category(self, group_id, requester_user_id, category_id):
|
async def get_group_category(self, group_id, requester_user_id, category_id):
|
||||||
"""Get a specific category in a group (as seen by user)
|
"""Get a specific category in a group (as seen by user)"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
res = await self.store.get_group_category(
|
res = await self.store.get_group_category(
|
||||||
@ -186,24 +184,21 @@ class GroupsServerWorkerHandler:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_group_roles(self, group_id, requester_user_id):
|
async def get_group_roles(self, group_id, requester_user_id):
|
||||||
"""Get all roles in a group (as seen by user)
|
"""Get all roles in a group (as seen by user)"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
roles = await self.store.get_group_roles(group_id=group_id)
|
roles = await self.store.get_group_roles(group_id=group_id)
|
||||||
return {"roles": roles}
|
return {"roles": roles}
|
||||||
|
|
||||||
async def get_group_role(self, group_id, requester_user_id, role_id):
|
async def get_group_role(self, group_id, requester_user_id, role_id):
|
||||||
"""Get a specific role in a group (as seen by user)
|
"""Get a specific role in a group (as seen by user)"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
|
res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_group_profile(self, group_id, requester_user_id):
|
async def get_group_profile(self, group_id, requester_user_id):
|
||||||
"""Get the group profile as seen by requester_user_id
|
"""Get the group profile as seen by requester_user_id"""
|
||||||
"""
|
|
||||||
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id)
|
await self.check_group_is_ours(group_id, requester_user_id)
|
||||||
|
|
||||||
@ -350,8 +345,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def update_group_summary_room(
|
async def update_group_summary_room(
|
||||||
self, group_id, requester_user_id, room_id, category_id, content
|
self, group_id, requester_user_id, room_id, category_id, content
|
||||||
):
|
):
|
||||||
"""Add/update a room to the group summary
|
"""Add/update a room to the group summary"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -375,8 +369,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def delete_group_summary_room(
|
async def delete_group_summary_room(
|
||||||
self, group_id, requester_user_id, room_id, category_id
|
self, group_id, requester_user_id, room_id, category_id
|
||||||
):
|
):
|
||||||
"""Remove a room from the summary
|
"""Remove a room from the summary"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -409,8 +402,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def update_group_category(
|
async def update_group_category(
|
||||||
self, group_id, requester_user_id, category_id, content
|
self, group_id, requester_user_id, category_id, content
|
||||||
):
|
):
|
||||||
"""Add/Update a group category
|
"""Add/Update a group category"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -428,8 +420,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def delete_group_category(self, group_id, requester_user_id, category_id):
|
async def delete_group_category(self, group_id, requester_user_id, category_id):
|
||||||
"""Delete a group category
|
"""Delete a group category"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -441,8 +432,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def update_group_role(self, group_id, requester_user_id, role_id, content):
|
async def update_group_role(self, group_id, requester_user_id, role_id, content):
|
||||||
"""Add/update a role in a group
|
"""Add/update a role in a group"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -458,8 +448,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def delete_group_role(self, group_id, requester_user_id, role_id):
|
async def delete_group_role(self, group_id, requester_user_id, role_id):
|
||||||
"""Remove role from group
|
"""Remove role from group"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -471,8 +460,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def update_group_summary_user(
|
async def update_group_summary_user(
|
||||||
self, group_id, requester_user_id, user_id, role_id, content
|
self, group_id, requester_user_id, user_id, role_id, content
|
||||||
):
|
):
|
||||||
"""Add/update a users entry in the group summary
|
"""Add/update a users entry in the group summary"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -494,8 +482,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def delete_group_summary_user(
|
async def delete_group_summary_user(
|
||||||
self, group_id, requester_user_id, user_id, role_id
|
self, group_id, requester_user_id, user_id, role_id
|
||||||
):
|
):
|
||||||
"""Remove a user from the group summary
|
"""Remove a user from the group summary"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -507,8 +494,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def update_group_profile(self, group_id, requester_user_id, content):
|
async def update_group_profile(self, group_id, requester_user_id, content):
|
||||||
"""Update the group profile
|
"""Update the group profile"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -539,8 +525,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
await self.store.update_group_profile(group_id, profile)
|
await self.store.update_group_profile(group_id, profile)
|
||||||
|
|
||||||
async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
|
async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
|
||||||
"""Add room to group
|
"""Add room to group"""
|
||||||
"""
|
|
||||||
RoomID.from_string(room_id) # Ensure valid room id
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
@ -556,8 +541,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
async def update_room_in_group(
|
async def update_room_in_group(
|
||||||
self, group_id, requester_user_id, room_id, config_key, content
|
self, group_id, requester_user_id, room_id, config_key, content
|
||||||
):
|
):
|
||||||
"""Update room in group
|
"""Update room in group"""
|
||||||
"""
|
|
||||||
RoomID.from_string(room_id) # Ensure valid room id
|
RoomID.from_string(room_id) # Ensure valid room id
|
||||||
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
@ -576,8 +560,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
async def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||||
"""Remove room from group
|
"""Remove room from group"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(
|
await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
)
|
)
|
||||||
@ -587,8 +570,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def invite_to_group(self, group_id, user_id, requester_user_id, content):
|
async def invite_to_group(self, group_id, user_id, requester_user_id, content):
|
||||||
"""Invite user to group
|
"""Invite user to group"""
|
||||||
"""
|
|
||||||
|
|
||||||
group = await self.check_group_is_ours(
|
group = await self.check_group_is_ours(
|
||||||
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
|
||||||
@ -724,8 +706,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
return {"state": "join", "attestation": local_attestation}
|
return {"state": "join", "attestation": local_attestation}
|
||||||
|
|
||||||
async def knock(self, group_id, requester_user_id, content):
|
async def knock(self, group_id, requester_user_id, content):
|
||||||
"""A user requests becoming a member of the group
|
"""A user requests becoming a member of the group"""
|
||||||
"""
|
|
||||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -918,8 +899,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
|
|
||||||
|
|
||||||
def _parse_join_policy_from_contents(content):
|
def _parse_join_policy_from_contents(content):
|
||||||
"""Given a content for a request, return the specified join policy or None
|
"""Given a content for a request, return the specified join policy or None"""
|
||||||
"""
|
|
||||||
|
|
||||||
join_policy_dict = content.get("m.join_policy")
|
join_policy_dict = content.get("m.join_policy")
|
||||||
if join_policy_dict:
|
if join_policy_dict:
|
||||||
@ -929,8 +909,7 @@ def _parse_join_policy_from_contents(content):
|
|||||||
|
|
||||||
|
|
||||||
def _parse_join_policy_dict(join_policy_dict):
|
def _parse_join_policy_dict(join_policy_dict):
|
||||||
"""Given a dict for the "m.join_policy" config return the join policy specified
|
"""Given a dict for the "m.join_policy" config return the join policy specified"""
|
||||||
"""
|
|
||||||
join_policy_type = join_policy_dict.get("type")
|
join_policy_type = join_policy_dict.get("type")
|
||||||
if not join_policy_type:
|
if not join_policy_type:
|
||||||
return "invite"
|
return "invite"
|
||||||
|
@ -203,13 +203,11 @@ class AdminHandler(BaseHandler):
|
|||||||
|
|
||||||
|
|
||||||
class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
class ExfiltrationWriter(metaclass=abc.ABCMeta):
|
||||||
"""Interface used to specify how to write exported data.
|
"""Interface used to specify how to write exported data."""
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def write_events(self, room_id: str, events: List[EventBase]) -> None:
|
def write_events(self, room_id: str, events: List[EventBase]) -> None:
|
||||||
"""Write a batch of events for a room.
|
"""Write a batch of events for a room."""
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@ -290,7 +290,9 @@ class ApplicationServicesHandler:
|
|||||||
if not interested:
|
if not interested:
|
||||||
continue
|
continue
|
||||||
presence_events, _ = await presence_source.get_new_events(
|
presence_events, _ = await presence_source.get_new_events(
|
||||||
user=user, service=service, from_key=from_key,
|
user=user,
|
||||||
|
service=service,
|
||||||
|
from_key=from_key,
|
||||||
)
|
)
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
events.extend(
|
events.extend(
|
||||||
|
@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier(
|
|||||||
# Ensure the identifier has a type
|
# Ensure the identifier has a type
|
||||||
if "type" not in identifier:
|
if "type" not in identifier:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
|
400,
|
||||||
|
"'identifier' dict has no key 'type'",
|
||||||
|
errcode=Codes.MISSING_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
return identifier
|
return identifier
|
||||||
@ -351,7 +353,11 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result, params, session_id = await self.check_ui_auth(
|
result, params, session_id = await self.check_ui_auth(
|
||||||
flows, request, request_body, description, get_new_session_data,
|
flows,
|
||||||
|
request,
|
||||||
|
request_body,
|
||||||
|
description,
|
||||||
|
get_new_session_data,
|
||||||
)
|
)
|
||||||
except LoginError:
|
except LoginError:
|
||||||
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
|
||||||
@ -379,8 +385,7 @@ class AuthHandler(BaseHandler):
|
|||||||
return params, session_id
|
return params, session_id
|
||||||
|
|
||||||
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
||||||
"""Get a list of the authentication types this user can use
|
"""Get a list of the authentication types this user can use"""
|
||||||
"""
|
|
||||||
|
|
||||||
ui_auth_types = set()
|
ui_auth_types = set()
|
||||||
|
|
||||||
@ -723,7 +728,9 @@ class AuthHandler(BaseHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _auth_dict_for_flows(
|
def _auth_dict_for_flows(
|
||||||
self, flows: List[List[str]], session_id: str,
|
self,
|
||||||
|
flows: List[List[str]],
|
||||||
|
session_id: str,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
public_flows = []
|
public_flows = []
|
||||||
for f in flows:
|
for f in flows:
|
||||||
@ -880,7 +887,9 @@ class AuthHandler(BaseHandler):
|
|||||||
return self._supported_login_types
|
return self._supported_login_types
|
||||||
|
|
||||||
async def validate_login(
|
async def validate_login(
|
||||||
self, login_submission: Dict[str, Any], ratelimit: bool = False,
|
self,
|
||||||
|
login_submission: Dict[str, Any],
|
||||||
|
ratelimit: bool = False,
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||||
"""Authenticates the user for the /login API
|
"""Authenticates the user for the /login API
|
||||||
|
|
||||||
@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _validate_userid_login(
|
async def _validate_userid_login(
|
||||||
self, username: str, login_submission: Dict[str, Any],
|
self,
|
||||||
|
username: str,
|
||||||
|
login_submission: Dict[str, Any],
|
||||||
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
|
||||||
"""Helper for validate_login
|
"""Helper for validate_login
|
||||||
|
|
||||||
@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler):
|
|||||||
# is considered OK since the newest SSO attributes should be most valid.
|
# is considered OK since the newest SSO attributes should be most valid.
|
||||||
if extra_attributes:
|
if extra_attributes:
|
||||||
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
|
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
|
||||||
self._clock.time_msec(), extra_attributes,
|
self._clock.time_msec(),
|
||||||
|
extra_attributes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a login token
|
# Create a login token
|
||||||
@ -1702,5 +1714,9 @@ class PasswordProvider:
|
|||||||
# This might return an awaitable, if it does block the log out
|
# This might return an awaitable, if it does block the log out
|
||||||
# until it completes.
|
# until it completes.
|
||||||
await maybe_awaitable(
|
await maybe_awaitable(
|
||||||
g(user_id=user_id, device_id=device_id, access_token=access_token,)
|
g(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class CasError(Exception):
|
class CasError(Exception):
|
||||||
"""Used to catch errors when validating the CAS ticket.
|
"""Used to catch errors when validating the CAS ticket."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, error, error_description=None):
|
def __init__(self, error, error_description=None):
|
||||||
self.error = error
|
self.error = error
|
||||||
@ -100,7 +99,10 @@ class CasHandler:
|
|||||||
Returns:
|
Returns:
|
||||||
The URL to use as a "service" parameter.
|
The URL to use as a "service" parameter.
|
||||||
"""
|
"""
|
||||||
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
|
return "%s?%s" % (
|
||||||
|
self._cas_service_url,
|
||||||
|
urllib.parse.urlencode(args),
|
||||||
|
)
|
||||||
|
|
||||||
async def _validate_ticket(
|
async def _validate_ticket(
|
||||||
self, ticket: str, service_args: Dict[str, str]
|
self, ticket: str, service_args: Dict[str, str]
|
||||||
@ -296,7 +298,10 @@ class CasHandler:
|
|||||||
# first check if we're doing a UIA
|
# first check if we're doing a UIA
|
||||||
if session:
|
if session:
|
||||||
return await self._sso_handler.complete_sso_ui_auth_request(
|
return await self._sso_handler.complete_sso_ui_auth_request(
|
||||||
self.idp_id, cas_response.username, session, request,
|
self.idp_id,
|
||||||
|
cas_response.username,
|
||||||
|
session,
|
||||||
|
request,
|
||||||
)
|
)
|
||||||
|
|
||||||
# otherwise, we're handling a login request.
|
# otherwise, we're handling a login request.
|
||||||
@ -366,7 +371,8 @@ class CasHandler:
|
|||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Looking for existing account based on mapped %s", user_id,
|
"Looking for existing account based on mapped %s",
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
users = await self._store.get_users_by_id_case_insensitive(user_id)
|
users = await self._store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
run_as_background_process("user_parter_loop", self._user_parter_loop)
|
run_as_background_process("user_parter_loop", self._user_parter_loop)
|
||||||
|
|
||||||
async def _user_parter_loop(self) -> None:
|
async def _user_parter_loop(self) -> None:
|
||||||
"""Loop that parts deactivated users from rooms
|
"""Loop that parts deactivated users from rooms"""
|
||||||
"""
|
|
||||||
self._user_parter_running = True
|
self._user_parter_running = True
|
||||||
logger.info("Starting user parter")
|
logger.info("Starting user parter")
|
||||||
try:
|
try:
|
||||||
@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
self._user_parter_running = False
|
self._user_parter_running = False
|
||||||
|
|
||||||
async def _part_user(self, user_id: str) -> None:
|
async def _part_user(self, user_id: str) -> None:
|
||||||
"""Causes the given user_id to leave all the rooms they're joined to
|
"""Causes the given user_id to leave all the rooms they're joined to"""
|
||||||
"""
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
rooms_for_user = await self.store.get_rooms_for_user(user_id)
|
rooms_for_user = await self.store.get_rooms_for_user(user_id)
|
||||||
|
@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
|
|||||||
device id of the dehydrated device
|
device id of the dehydrated device
|
||||||
"""
|
"""
|
||||||
device_id = await self.check_device_registered(
|
device_id = await self.check_device_registered(
|
||||||
user_id, None, initial_device_display_name,
|
user_id,
|
||||||
|
None,
|
||||||
|
initial_device_display_name,
|
||||||
)
|
)
|
||||||
old_device_id = await self.store.store_dehydrated_device(
|
old_device_id = await self.store.store_dehydrated_device(
|
||||||
user_id, device_id, device_data
|
user_id, device_id, device_data
|
||||||
@ -803,7 +805,8 @@ class DeviceListUpdater:
|
|||||||
try:
|
try:
|
||||||
# Try to resync the current user's devices list.
|
# Try to resync the current user's devices list.
|
||||||
result = await self.user_device_resync(
|
result = await self.user_device_resync(
|
||||||
user_id=user_id, mark_failed_as_stale=False,
|
user_id=user_id,
|
||||||
|
mark_failed_as_stale=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# user_device_resync only returns a result if it managed to
|
# user_device_resync only returns a result if it managed to
|
||||||
@ -813,14 +816,17 @@ class DeviceListUpdater:
|
|||||||
# self.store.update_remote_device_list_cache).
|
# self.store.update_remote_device_list_cache).
|
||||||
if result:
|
if result:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Successfully resynced the device list for %s", user_id,
|
"Successfully resynced the device list for %s",
|
||||||
|
user_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If there was an issue resyncing this user, e.g. if the remote
|
# If there was an issue resyncing this user, e.g. if the remote
|
||||||
# server sent a malformed result, just log the error instead of
|
# server sent a malformed result, just log the error instead of
|
||||||
# aborting all the subsequent resyncs.
|
# aborting all the subsequent resyncs.
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Could not resync the device list for %s: %s", user_id, e,
|
"Could not resync the device list for %s: %s",
|
||||||
|
user_id,
|
||||||
|
e,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Allow future calls to retry resyncinc out of sync device lists.
|
# Allow future calls to retry resyncinc out of sync device lists.
|
||||||
@ -855,7 +861,9 @@ class DeviceListUpdater:
|
|||||||
return None
|
return None
|
||||||
except (RequestSendFailed, HttpResponseException) as e:
|
except (RequestSendFailed, HttpResponseException) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to handle device list update for %s: %s", user_id, e,
|
"Failed to handle device list update for %s: %s",
|
||||||
|
user_id,
|
||||||
|
e,
|
||||||
)
|
)
|
||||||
|
|
||||||
if mark_failed_as_stale:
|
if mark_failed_as_stale:
|
||||||
@ -931,7 +939,9 @@ class DeviceListUpdater:
|
|||||||
|
|
||||||
# Handle cross-signing keys.
|
# Handle cross-signing keys.
|
||||||
cross_signing_device_ids = await self.process_cross_signing_key_update(
|
cross_signing_device_ids = await self.process_cross_signing_key_update(
|
||||||
user_id, master_key, self_signing_key,
|
user_id,
|
||||||
|
master_key,
|
||||||
|
self_signing_key,
|
||||||
)
|
)
|
||||||
device_ids = device_ids + cross_signing_device_ids
|
device_ids = device_ids + cross_signing_device_ids
|
||||||
|
|
||||||
|
@ -62,7 +62,8 @@ class DeviceMessageHandler:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hs.get_federation_registry().register_instances_for_edu(
|
hs.get_federation_registry().register_instances_for_edu(
|
||||||
"m.direct_to_device", hs.config.worker.writers.to_device,
|
"m.direct_to_device",
|
||||||
|
hs.config.worker.writers.to_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The handler to call when we think a user's device list might be out of
|
# The handler to call when we think a user's device list might be out of
|
||||||
@ -73,8 +74,8 @@ class DeviceMessageHandler:
|
|||||||
hs.get_device_handler().device_list_updater.user_device_resync
|
hs.get_device_handler().device_list_updater.user_device_resync
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
|
self._user_device_resync = (
|
||||||
hs
|
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
|
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
|
||||||
|
@ -61,8 +61,8 @@ class E2eKeysHandler:
|
|||||||
|
|
||||||
self._is_master = hs.config.worker_app is None
|
self._is_master = hs.config.worker_app is None
|
||||||
if not self._is_master:
|
if not self._is_master:
|
||||||
self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
|
self._user_device_resync_client = (
|
||||||
hs
|
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Only register this edu handler on master as it requires writing
|
# Only register this edu handler on master as it requires writing
|
||||||
@ -391,8 +391,7 @@ class E2eKeysHandler:
|
|||||||
async def on_federation_query_client_keys(
|
async def on_federation_query_client_keys(
|
||||||
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
""" Handle a device key query from a federated server
|
"""Handle a device key query from a federated server"""
|
||||||
"""
|
|
||||||
device_keys_query = query_body.get(
|
device_keys_query = query_body.get(
|
||||||
"device_keys", {}
|
"device_keys", {}
|
||||||
) # type: Dict[str, Optional[List[str]]]
|
) # type: Dict[str, Optional[List[str]]]
|
||||||
@ -1065,7 +1064,9 @@ class E2eKeysHandler:
|
|||||||
return key, key_id, verify_key
|
return key, key_id, verify_key
|
||||||
|
|
||||||
async def _retrieve_cross_signing_keys_for_remote_user(
|
async def _retrieve_cross_signing_keys_for_remote_user(
|
||||||
self, user: UserID, desired_key_type: str,
|
self,
|
||||||
|
user: UserID,
|
||||||
|
desired_key_type: str,
|
||||||
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
|
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
|
||||||
"""Queries cross-signing keys for a remote user and saves them to the database
|
"""Queries cross-signing keys for a remote user and saves them to the database
|
||||||
|
|
||||||
@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
|
|||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class SignatureListItem:
|
class SignatureListItem:
|
||||||
"""An item in the signature list as used by upload_signatures_for_device_keys.
|
"""An item in the signature list as used by upload_signatures_for_device_keys."""
|
||||||
"""
|
|
||||||
|
|
||||||
signing_key_id = attr.ib(type=str)
|
signing_key_id = attr.ib(type=str)
|
||||||
target_user_id = attr.ib(type=str)
|
target_user_id = attr.ib(type=str)
|
||||||
@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater:
|
|||||||
logger.info("pending updates: %r", pending_updates)
|
logger.info("pending updates: %r", pending_updates)
|
||||||
|
|
||||||
for master_key, self_signing_key in pending_updates:
|
for master_key, self_signing_key in pending_updates:
|
||||||
new_device_ids = await device_list_updater.process_cross_signing_key_update(
|
new_device_ids = (
|
||||||
user_id, master_key, self_signing_key,
|
await device_list_updater.process_cross_signing_key_update(
|
||||||
|
user_id,
|
||||||
|
master_key,
|
||||||
|
self_signing_key,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
device_ids = device_ids + new_device_ids
|
device_ids = device_ids + new_device_ids
|
||||||
|
|
||||||
|
@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
|
|||||||
room_id: Optional[str] = None,
|
room_id: Optional[str] = None,
|
||||||
is_guest: bool = False,
|
is_guest: bool = False,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Fetches the events stream for a given user.
|
"""Fetches the events stream for a given user."""
|
||||||
"""
|
|
||||||
|
|
||||||
if room_id:
|
if room_id:
|
||||||
blocked = await self.store.is_room_blocked(room_id)
|
blocked = await self.store.is_room_blocked(room_id)
|
||||||
|
@ -150,11 +150,11 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if hs.config.worker_app:
|
if hs.config.worker_app:
|
||||||
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
|
self._user_device_resync = (
|
||||||
hs
|
ReplicationUserDevicesResyncRestServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
|
self._maybe_store_room_on_outlier_membership = (
|
||||||
hs
|
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._device_list_updater = hs.get_device_handler().device_list_updater
|
self._device_list_updater = hs.get_device_handler().device_list_updater
|
||||||
@ -368,7 +368,8 @@ class FederationHandler(BaseHandler):
|
|||||||
# know about
|
# know about
|
||||||
for p in prevs - seen:
|
for p in prevs - seen:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Requesting state at missing prev_event %s", event_id,
|
"Requesting state at missing prev_event %s",
|
||||||
|
event_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
with nested_logging_context(p):
|
with nested_logging_context(p):
|
||||||
@ -388,13 +389,15 @@ class FederationHandler(BaseHandler):
|
|||||||
event_map[x.event_id] = x
|
event_map[x.event_id] = x
|
||||||
|
|
||||||
room_version = await self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
state_map = (
|
||||||
|
await self._state_resolution_handler.resolve_events_with_store(
|
||||||
room_id,
|
room_id,
|
||||||
room_version,
|
room_version,
|
||||||
state_maps,
|
state_maps,
|
||||||
event_map,
|
event_map,
|
||||||
state_res_store=StateResolutionStore(self.store),
|
state_res_store=StateResolutionStore(self.store),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# We need to give _process_received_pdu the actual state events
|
# We need to give _process_received_pdu the actual state events
|
||||||
# rather than event ids, so generate that now.
|
# rather than event ids, so generate that now.
|
||||||
@ -687,7 +690,10 @@ class FederationHandler(BaseHandler):
|
|||||||
return fetched_events
|
return fetched_events
|
||||||
|
|
||||||
async def _process_received_pdu(
|
async def _process_received_pdu(
|
||||||
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
|
self,
|
||||||
|
origin: str,
|
||||||
|
event: EventBase,
|
||||||
|
state: Optional[Iterable[EventBase]],
|
||||||
):
|
):
|
||||||
"""Called when we have a new pdu. We need to do auth checks and put it
|
"""Called when we have a new pdu. We need to do auth checks and put it
|
||||||
through the StateHandler.
|
through the StateHandler.
|
||||||
@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler):
|
|||||||
with nested_logging_context(event_id):
|
with nested_logging_context(event_id):
|
||||||
try:
|
try:
|
||||||
event = await self.federation_client.get_pdu(
|
event = await self.federation_client.get_pdu(
|
||||||
[destination], event_id, room_version, outlier=True,
|
[destination],
|
||||||
|
event_id,
|
||||||
|
room_version,
|
||||||
|
outlier=True,
|
||||||
)
|
)
|
||||||
if event is None:
|
if event is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Server %s didn't return event %s", destination, event_id,
|
"Server %s didn't return event %s",
|
||||||
|
destination,
|
||||||
|
event_id,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler):
|
|||||||
if aid not in event_map
|
if aid not in event_map
|
||||||
]
|
]
|
||||||
persisted_events = await self.store.get_events(
|
persisted_events = await self.store.get_events(
|
||||||
auth_events, allow_rejected=True,
|
auth_events,
|
||||||
|
allow_rejected=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
event_infos = []
|
event_infos = []
|
||||||
@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler):
|
|||||||
event_infos.append(_NewEventInfo(event, None, auth))
|
event_infos.append(_NewEventInfo(event, None, auth))
|
||||||
|
|
||||||
await self._handle_new_events(
|
await self._handle_new_events(
|
||||||
destination, room_id, event_infos,
|
destination,
|
||||||
|
room_id,
|
||||||
|
event_infos,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _sanity_check_event(self, ev):
|
def _sanity_check_event(self, ev):
|
||||||
@ -1388,7 +1402,8 @@ class FederationHandler(BaseHandler):
|
|||||||
# so we can rely on it now.
|
# so we can rely on it now.
|
||||||
#
|
#
|
||||||
await self.store.upsert_room_on_join(
|
await self.store.upsert_room_on_join(
|
||||||
room_id=room_id, room_version=room_version_obj,
|
room_id=room_id,
|
||||||
|
room_version=room_version_obj,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_stream_id = await self._persist_auth_tree(
|
max_stream_id = await self._persist_auth_tree(
|
||||||
@ -1483,7 +1498,8 @@ class FederationHandler(BaseHandler):
|
|||||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
||||||
if not is_in_room:
|
if not is_in_room:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Got /make_join request for room %s we are no longer in", room_id,
|
"Got /make_join request for room %s we are no longer in",
|
||||||
|
room_id,
|
||||||
)
|
)
|
||||||
raise NotFoundError("Not an active room on this server")
|
raise NotFoundError("Not an active room on this server")
|
||||||
|
|
||||||
@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
|
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
|
||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event."""
|
||||||
"""
|
|
||||||
|
|
||||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
|
|
||||||
@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event."""
|
||||||
"""
|
|
||||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
|
|
||||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||||
@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
for e_id in missing_auth_events:
|
for e_id in missing_auth_events:
|
||||||
m_ev = await self.federation_client.get_pdu(
|
m_ev = await self.federation_client.get_pdu(
|
||||||
[origin], e_id, room_version=room_version, outlier=True, timeout=10000,
|
[origin],
|
||||||
|
e_id,
|
||||||
|
room_version=room_version,
|
||||||
|
outlier=True,
|
||||||
|
timeout=10000,
|
||||||
)
|
)
|
||||||
if m_ev and m_ev.event_id == e_id:
|
if m_ev and m_ev.event_id == e_id:
|
||||||
event_map[e_id] = m_ev
|
event_map[e_id] = m_ev
|
||||||
@ -2160,7 +2178,9 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
|
"Doing soft-fail check for %s: state %s",
|
||||||
|
event.event_id,
|
||||||
|
current_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now check if event pass auth against said current state
|
# Now check if event pass auth against said current state
|
||||||
|
@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler:
|
|||||||
async def get_users_in_group(
|
async def get_users_in_group(
|
||||||
self, group_id: str, requester_user_id: str
|
self, group_id: str, requester_user_id: str
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Get users in a group
|
"""Get users in a group"""
|
||||||
"""
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
return await self.groups_server_handler.get_users_in_group(
|
return await self.groups_server_handler.get_users_in_group(
|
||||||
group_id, requester_user_id
|
group_id, requester_user_id
|
||||||
@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def create_group(
|
async def create_group(
|
||||||
self, group_id: str, user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Create a group
|
"""Create a group"""
|
||||||
"""
|
|
||||||
|
|
||||||
logger.info("Asking to create group with ID: %r", group_id)
|
logger.info("Asking to create group with ID: %r", group_id)
|
||||||
|
|
||||||
@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def join_group(
|
async def join_group(
|
||||||
self, group_id: str, user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Request to join a group
|
"""Request to join a group"""
|
||||||
"""
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
await self.groups_server_handler.join_group(group_id, user_id, content)
|
await self.groups_server_handler.join_group(group_id, user_id, content)
|
||||||
local_attestation = None
|
local_attestation = None
|
||||||
@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def accept_invite(
|
async def accept_invite(
|
||||||
self, group_id: str, user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Accept an invite to a group
|
"""Accept an invite to a group"""
|
||||||
"""
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
await self.groups_server_handler.accept_invite(group_id, user_id, content)
|
await self.groups_server_handler.accept_invite(group_id, user_id, content)
|
||||||
local_attestation = None
|
local_attestation = None
|
||||||
@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def invite(
|
async def invite(
|
||||||
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
|
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Invite a user to a group
|
"""Invite a user to a group"""
|
||||||
"""
|
|
||||||
content = {"requester_user_id": requester_user_id, "config": config}
|
content = {"requester_user_id": requester_user_id, "config": config}
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
res = await self.groups_server_handler.invite_to_group(
|
res = await self.groups_server_handler.invite_to_group(
|
||||||
@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def on_invite(
|
async def on_invite(
|
||||||
self, group_id: str, user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""One of our users were invited to a group
|
"""One of our users were invited to a group"""
|
||||||
"""
|
|
||||||
# TODO: Support auto join and rejection
|
# TODO: Support auto join and rejection
|
||||||
|
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def remove_user_from_group(
|
async def remove_user_from_group(
|
||||||
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
"""Remove a user from a group
|
"""Remove a user from a group"""
|
||||||
"""
|
|
||||||
if user_id == requester_user_id:
|
if user_id == requester_user_id:
|
||||||
token = await self.store.register_user_group_membership(
|
token = await self.store.register_user_group_membership(
|
||||||
group_id, user_id, membership="leave"
|
group_id, user_id, membership="leave"
|
||||||
@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
async def user_removed_from_group(
|
async def user_removed_from_group(
|
||||||
self, group_id: str, user_id: str, content: JsonDict
|
self, group_id: str, user_id: str, content: JsonDict
|
||||||
) -> None:
|
) -> None:
|
||||||
"""One of our users was removed/kicked from a group
|
"""One of our users was removed/kicked from a group"""
|
||||||
"""
|
|
||||||
# TODO: Check if user in group
|
# TODO: Check if user in group
|
||||||
token = await self.store.register_user_group_membership(
|
token = await self.store.register_user_group_membership(
|
||||||
group_id, user_id, membership="leave"
|
group_id, user_id, membership="leave"
|
||||||
|
@ -72,7 +72,10 @@ class IdentityHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def ratelimit_request_token_requests(
|
def ratelimit_request_token_requests(
|
||||||
self, request: SynapseRequest, medium: str, address: str,
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
medium: str,
|
||||||
|
address: str,
|
||||||
):
|
):
|
||||||
"""Used to ratelimit requests to `/requestToken` by IP and address.
|
"""Used to ratelimit requests to `/requestToken` by IP and address.
|
||||||
|
|
||||||
|
@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
|
|
||||||
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
|
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
|
||||||
receipt = await self.store.get_linearized_receipts_for_rooms(
|
receipt = await self.store.get_linearized_receipts_for_rooms(
|
||||||
joined_rooms, to_key=int(now_token.receipt_key),
|
joined_rooms,
|
||||||
|
to_key=int(now_token.receipt_key),
|
||||||
)
|
)
|
||||||
|
|
||||||
tags_by_room = await self.store.get_tags_for_user(user_id)
|
tags_by_room = await self.store.get_tags_for_user(user_id)
|
||||||
@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
self.state_handler.get_current_state, event.room_id
|
self.state_handler.get_current_state, event.room_id
|
||||||
)
|
)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
room_end_token = RoomStreamToken(None, event.stream_ordering,)
|
room_end_token = RoomStreamToken(
|
||||||
|
None,
|
||||||
|
event.stream_ordering,
|
||||||
|
)
|
||||||
deferred_room_state = run_in_background(
|
deferred_room_state = run_in_background(
|
||||||
self.state_store.get_state_for_events, [event.event_id]
|
self.state_store.get_state_for_events, [event.event_id]
|
||||||
)
|
)
|
||||||
@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
|
|||||||
membership,
|
membership,
|
||||||
member_event_id,
|
member_event_id,
|
||||||
) = await self.auth.check_user_in_room_or_world_readable(
|
) = await self.auth.check_user_in_room_or_world_readable(
|
||||||
room_id, user_id, allow_departed_users=True,
|
room_id,
|
||||||
|
user_id,
|
||||||
|
allow_departed_users=True,
|
||||||
)
|
)
|
||||||
is_peeking = member_event_id is None
|
is_peeking = member_event_id is None
|
||||||
|
|
||||||
|
@ -65,8 +65,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class MessageHandler:
|
class MessageHandler:
|
||||||
"""Contains some read only APIs to get state about a room
|
"""Contains some read only APIs to get state about a room"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
@ -88,7 +87,11 @@ class MessageHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def get_room_data(
|
async def get_room_data(
|
||||||
self, user_id: str, room_id: str, event_type: str, state_key: str,
|
self,
|
||||||
|
user_id: str,
|
||||||
|
room_id: str,
|
||||||
|
event_type: str,
|
||||||
|
state_key: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Get data from a room.
|
"""Get data from a room.
|
||||||
|
|
||||||
@ -174,7 +177,10 @@ class MessageHandler:
|
|||||||
raise NotFoundError("Can't find event for token %s" % (at_token,))
|
raise NotFoundError("Can't find event for token %s" % (at_token,))
|
||||||
|
|
||||||
visible_events = await filter_events_for_client(
|
visible_events = await filter_events_for_client(
|
||||||
self.storage, user_id, last_events, filter_send_to_client=False,
|
self.storage,
|
||||||
|
user_id,
|
||||||
|
last_events,
|
||||||
|
filter_send_to_client=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
event = last_events[0]
|
event = last_events[0]
|
||||||
@ -793,9 +799,10 @@ class EventCreationHandler:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if prev_event_ids is not None:
|
if prev_event_ids is not None:
|
||||||
assert len(prev_event_ids) <= 10, (
|
assert (
|
||||||
"Attempting to create an event with %i prev_events"
|
len(prev_event_ids) <= 10
|
||||||
% (len(prev_event_ids),)
|
), "Attempting to create an event with %i prev_events" % (
|
||||||
|
len(prev_event_ids),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
||||||
@ -821,7 +828,8 @@ class EventCreationHandler:
|
|||||||
)
|
)
|
||||||
if not third_party_result:
|
if not third_party_result:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Event %s forbidden by third-party rules", event,
|
"Event %s forbidden by third-party rules",
|
||||||
|
event,
|
||||||
)
|
)
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
||||||
@ -1316,7 +1324,11 @@ class EventCreationHandler:
|
|||||||
# Since this is a dummy-event it is OK if it is sent by a
|
# Since this is a dummy-event it is OK if it is sent by a
|
||||||
# shadow-banned user.
|
# shadow-banned user.
|
||||||
await self.handle_new_client_event(
|
await self.handle_new_client_event(
|
||||||
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
|
requester,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
ratelimit=False,
|
||||||
|
ignore_shadow_ban=True,
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
except AuthError:
|
except AuthError:
|
||||||
|
@ -73,8 +73,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]})
|
|||||||
|
|
||||||
|
|
||||||
class OidcHandler:
|
class OidcHandler:
|
||||||
"""Handles requests related to the OpenID Connect login flow.
|
"""Handles requests related to the OpenID Connect login flow."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
@ -216,8 +215,7 @@ class OidcHandler:
|
|||||||
|
|
||||||
|
|
||||||
class OidcError(Exception):
|
class OidcError(Exception):
|
||||||
"""Used to catch errors when calling the token_endpoint
|
"""Used to catch errors when calling the token_endpoint"""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, error, error_description=None):
|
def __init__(self, error, error_description=None):
|
||||||
self.error = error
|
self.error = error
|
||||||
@ -252,7 +250,9 @@ class OidcProvider:
|
|||||||
self._scopes = provider.scopes
|
self._scopes = provider.scopes
|
||||||
self._user_profile_method = provider.user_profile_method
|
self._user_profile_method = provider.user_profile_method
|
||||||
self._client_auth = ClientAuth(
|
self._client_auth = ClientAuth(
|
||||||
provider.client_id, provider.client_secret, provider.client_auth_method,
|
provider.client_id,
|
||||||
|
provider.client_secret,
|
||||||
|
provider.client_auth_method,
|
||||||
) # type: ClientAuth
|
) # type: ClientAuth
|
||||||
self._client_auth_method = provider.client_auth_method
|
self._client_auth_method = provider.client_auth_method
|
||||||
|
|
||||||
@ -509,7 +509,10 @@ class OidcProvider:
|
|||||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||||
# check the HTTP status code and we do the body encoding ourself.
|
# check the HTTP status code and we do the body encoding ourself.
|
||||||
response = await self._http_client.request(
|
response = await self._http_client.request(
|
||||||
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
|
method="POST",
|
||||||
|
uri=uri,
|
||||||
|
data=body.encode("utf-8"),
|
||||||
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is used in multiple error messages below
|
# This is used in multiple error messages below
|
||||||
@ -966,7 +969,9 @@ class OidcSessionTokenGenerator:
|
|||||||
A signed macaroon token with the session information.
|
A signed macaroon token with the session information.
|
||||||
"""
|
"""
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
|
location=self._server_name,
|
||||||
|
identifier="key",
|
||||||
|
key=self._macaroon_secret_key,
|
||||||
)
|
)
|
||||||
macaroon.add_first_party_caveat("gen = 1")
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
macaroon.add_first_party_caveat("type = session")
|
macaroon.add_first_party_caveat("type = session")
|
||||||
|
@ -197,7 +197,8 @@ class PaginationHandler:
|
|||||||
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
|
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
|
||||||
|
|
||||||
r = await self.store.get_room_event_before_stream_ordering(
|
r = await self.store.get_room_event_before_stream_ordering(
|
||||||
room_id, stream_ordering,
|
room_id,
|
||||||
|
stream_ordering,
|
||||||
)
|
)
|
||||||
if not r:
|
if not r:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -223,7 +224,12 @@ class PaginationHandler:
|
|||||||
# the background so that it's not blocking any other operation apart from
|
# the background so that it's not blocking any other operation apart from
|
||||||
# other purges in the same room.
|
# other purges in the same room.
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
"_purge_history", self._purge_history, purge_id, room_id, token, True,
|
"_purge_history",
|
||||||
|
self._purge_history,
|
||||||
|
purge_id,
|
||||||
|
room_id,
|
||||||
|
token,
|
||||||
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_purge_history(
|
def start_purge_history(
|
||||||
@ -389,7 +395,9 @@ class PaginationHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await self.hs.get_federation_handler().maybe_backfill(
|
await self.hs.get_federation_handler().maybe_backfill(
|
||||||
room_id, curr_topo, limit=pagin_config.limit,
|
room_id,
|
||||||
|
curr_topo,
|
||||||
|
limit=pagin_config.limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
to_room_key = None
|
to_room_key = None
|
||||||
|
@ -635,8 +635,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||||||
self.external_process_last_updated_ms.pop(process_id, None)
|
self.external_process_last_updated_ms.pop(process_id, None)
|
||||||
|
|
||||||
async def current_state_for_user(self, user_id):
|
async def current_state_for_user(self, user_id):
|
||||||
"""Get the current presence state for a user.
|
"""Get the current presence state for a user."""
|
||||||
"""
|
|
||||||
res = await self.current_state_for_users([user_id])
|
res = await self.current_state_for_users([user_id])
|
||||||
return res[user_id]
|
return res[user_id]
|
||||||
|
|
||||||
@ -678,8 +677,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||||||
self.federation.send_presence(states)
|
self.federation.send_presence(states)
|
||||||
|
|
||||||
async def incoming_presence(self, origin, content):
|
async def incoming_presence(self, origin, content):
|
||||||
"""Called when we receive a `m.presence` EDU from a remote server.
|
"""Called when we receive a `m.presence` EDU from a remote server."""
|
||||||
"""
|
|
||||||
if not self._presence_enabled:
|
if not self._presence_enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -729,8 +727,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||||||
await self._update_states(updates)
|
await self._update_states(updates)
|
||||||
|
|
||||||
async def set_state(self, target_user, state, ignore_status_msg=False):
|
async def set_state(self, target_user, state, ignore_status_msg=False):
|
||||||
"""Set the presence state of the user.
|
"""Set the presence state of the user."""
|
||||||
"""
|
|
||||||
status_msg = state.get("status_msg", None)
|
status_msg = state.get("status_msg", None)
|
||||||
presence = state["presence"]
|
presence = state["presence"]
|
||||||
|
|
||||||
@ -758,8 +755,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||||||
await self._update_states([prev_state.copy_and_replace(**new_fields)])
|
await self._update_states([prev_state.copy_and_replace(**new_fields)])
|
||||||
|
|
||||||
async def is_visible(self, observed_user, observer_user):
|
async def is_visible(self, observed_user, observer_user):
|
||||||
"""Returns whether a user can see another user's presence.
|
"""Returns whether a user can see another user's presence."""
|
||||||
"""
|
|
||||||
observer_room_ids = await self.store.get_rooms_for_user(
|
observer_room_ids = await self.store.get_rooms_for_user(
|
||||||
observer_user.to_string()
|
observer_user.to_string()
|
||||||
)
|
)
|
||||||
@ -953,8 +949,7 @@ class PresenceHandler(BasePresenceHandler):
|
|||||||
|
|
||||||
|
|
||||||
def should_notify(old_state, new_state):
|
def should_notify(old_state, new_state):
|
||||||
"""Decides if a presence state change should be sent to interested parties.
|
"""Decides if a presence state change should be sent to interested parties."""
|
||||||
"""
|
|
||||||
if old_state == new_state:
|
if old_state == new_state:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler):
|
|||||||
# This must be done by the target user himself.
|
# This must be done by the target user himself.
|
||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(
|
requester = create_requester(
|
||||||
target_user, authenticated_entity=requester.authenticated_entity,
|
target_user,
|
||||||
|
authenticated_entity=requester.authenticated_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.store.set_profile_displayname(
|
await self.store.set_profile_displayname(
|
||||||
|
@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hs.get_federation_registry().register_instances_for_edu(
|
hs.get_federation_registry().register_instances_for_edu(
|
||||||
"m.receipt", hs.config.worker.writers.receipts,
|
"m.receipt",
|
||||||
|
hs.config.worker.writers.receipts,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
|
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
|
||||||
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
"""Called when we receive an EDU of type m.receipt from a remote HS."""
|
||||||
"""
|
|
||||||
receipts = []
|
receipts = []
|
||||||
for room_id, room_values in content.items():
|
for room_id, room_values in content.items():
|
||||||
for receipt_type, users in room_values.items():
|
for receipt_type, users in room_values.items():
|
||||||
@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler):
|
|||||||
await self._handle_new_receipts(receipts)
|
await self._handle_new_receipts(receipts)
|
||||||
|
|
||||||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
||||||
"""Takes a list of receipts, stores them and informs the notifier.
|
"""Takes a list of receipts, stores them and informs the notifier."""
|
||||||
"""
|
|
||||||
min_batch_id = None # type: Optional[int]
|
min_batch_id = None # type: Optional[int]
|
||||||
max_batch_id = None # type: Optional[int]
|
max_batch_id = None # type: Optional[int]
|
||||||
|
|
||||||
|
@ -62,8 +62,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
|
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
|
||||||
hs
|
hs
|
||||||
)
|
)
|
||||||
self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
|
self._post_registration_client = (
|
||||||
hs
|
ReplicationPostRegisterActionsServlet.make_client(hs)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
@ -189,12 +189,15 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self.check_registration_ratelimit(address)
|
self.check_registration_ratelimit(address)
|
||||||
|
|
||||||
result = await self.spam_checker.check_registration_for_spam(
|
result = await self.spam_checker.check_registration_for_spam(
|
||||||
threepid, localpart, user_agent_ips or [],
|
threepid,
|
||||||
|
localpart,
|
||||||
|
user_agent_ips or [],
|
||||||
)
|
)
|
||||||
|
|
||||||
if result == RegistrationBehaviour.DENY:
|
if result == RegistrationBehaviour.DENY:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Blocked registration of %r", localpart,
|
"Blocked registration of %r",
|
||||||
|
localpart,
|
||||||
)
|
)
|
||||||
# We return a 429 to make it not obvious that they've been
|
# We return a 429 to make it not obvious that they've been
|
||||||
# denied.
|
# denied.
|
||||||
@ -203,7 +206,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
|
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
|
||||||
if shadow_banned:
|
if shadow_banned:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Shadow banning registration of %r", localpart,
|
"Shadow banning registration of %r",
|
||||||
|
localpart,
|
||||||
)
|
)
|
||||||
|
|
||||||
# do not check_auth_blocking if the call is coming through the Admin API
|
# do not check_auth_blocking if the call is coming through the Admin API
|
||||||
@ -369,7 +373,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
config["room_alias_name"] = room_alias.localpart
|
config["room_alias_name"] = room_alias.localpart
|
||||||
|
|
||||||
info, _ = await room_creation_handler.create_room(
|
info, _ = await room_creation_handler.create_room(
|
||||||
fake_requester, config=config, ratelimit=False,
|
fake_requester,
|
||||||
|
config=config,
|
||||||
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the room does not require an invite, but another user
|
# If the room does not require an invite, but another user
|
||||||
@ -753,7 +759,10 @@ class RegistrationHandler(BaseHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
await self._auth_handler.add_threepid(
|
await self._auth_handler.add_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
|
user_id,
|
||||||
|
threepid["medium"],
|
||||||
|
threepid["address"],
|
||||||
|
threepid["validated_at"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# And we add an email pusher for them by default, but only
|
# And we add an email pusher for them by default, but only
|
||||||
@ -805,5 +814,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
await self._auth_handler.add_threepid(
|
await self._auth_handler.add_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
|
user_id,
|
||||||
|
threepid["medium"],
|
||||||
|
threepid["address"],
|
||||||
|
threepid["validated_at"],
|
||||||
)
|
)
|
||||||
|
@ -198,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
if r is None:
|
if r is None:
|
||||||
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
||||||
new_room_id = await self._generate_room_id(
|
new_room_id = await self._generate_room_id(
|
||||||
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
|
creator_id=user_id,
|
||||||
|
is_public=r["is_public"],
|
||||||
|
room_version=new_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
|
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
|
||||||
@ -236,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
# now send the tombstone
|
# now send the tombstone
|
||||||
await self.event_creation_handler.handle_new_client_event(
|
await self.event_creation_handler.handle_new_client_event(
|
||||||
requester=requester, event=tombstone_event, context=tombstone_context,
|
requester=requester,
|
||||||
|
event=tombstone_event,
|
||||||
|
context=tombstone_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
old_room_state = await tombstone_context.get_current_state_ids()
|
old_room_state = await tombstone_context.get_current_state_ids()
|
||||||
@ -257,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# finally, shut down the PLs in the old room, and update them in the new
|
# finally, shut down the PLs in the old room, and update them in the new
|
||||||
# room.
|
# room.
|
||||||
await self._update_upgraded_room_pls(
|
await self._update_upgraded_room_pls(
|
||||||
requester, old_room_id, new_room_id, old_room_state,
|
requester,
|
||||||
|
old_room_id,
|
||||||
|
new_room_id,
|
||||||
|
old_room_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_room_id
|
return new_room_id
|
||||||
@ -691,7 +698,9 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
is_public = visibility == "public"
|
is_public = visibility == "public"
|
||||||
|
|
||||||
room_id = await self._generate_room_id(
|
room_id = await self._generate_room_id(
|
||||||
creator_id=user_id, is_public=is_public, room_version=room_version,
|
creator_id=user_id,
|
||||||
|
is_public=is_public,
|
||||||
|
room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check whether this visibility value is blocked by a third party module
|
# Check whether this visibility value is blocked by a third party module
|
||||||
@ -884,7 +893,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
_,
|
_,
|
||||||
last_stream_id,
|
last_stream_id,
|
||||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
creator, event, ratelimit=False, ignore_shadow_ban=True,
|
creator,
|
||||||
|
event,
|
||||||
|
ratelimit=False,
|
||||||
|
ignore_shadow_ban=True,
|
||||||
)
|
)
|
||||||
return last_stream_id
|
return last_stream_id
|
||||||
|
|
||||||
@ -984,7 +996,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
return last_sent_stream_id
|
return last_sent_stream_id
|
||||||
|
|
||||||
async def _generate_room_id(
|
async def _generate_room_id(
|
||||||
self, creator_id: str, is_public: bool, room_version: RoomVersion,
|
self,
|
||||||
|
creator_id: str,
|
||||||
|
is_public: bool,
|
||||||
|
room_version: RoomVersion,
|
||||||
):
|
):
|
||||||
# autogen room IDs and try to create it. We may clash, so just
|
# autogen room IDs and try to create it. We may clash, so just
|
||||||
# try a few times till one goes through, giving up eventually.
|
# try a few times till one goes through, giving up eventually.
|
||||||
|
@ -191,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||||||
# do it up front for efficiency.)
|
# do it up front for efficiency.)
|
||||||
if txn_id and requester.access_token_id:
|
if txn_id and requester.access_token_id:
|
||||||
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
||||||
room_id, requester.user.to_string(), requester.access_token_id, txn_id,
|
room_id,
|
||||||
|
requester.user.to_string(),
|
||||||
|
requester.access_token_id,
|
||||||
|
txn_id,
|
||||||
)
|
)
|
||||||
if existing_event_id:
|
if existing_event_id:
|
||||||
event_pos = await self.store.get_position_for_event(existing_event_id)
|
event_pos = await self.store.get_position_for_event(existing_event_id)
|
||||||
@ -238,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result_event = await self.event_creation_handler.handle_new_client_event(
|
result_event = await self.event_creation_handler.handle_new_client_event(
|
||||||
requester, event, context, extra_users=[target], ratelimit=ratelimit,
|
requester,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
extra_users=[target],
|
||||||
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.membership == Membership.LEAVE:
|
if event.membership == Membership.LEAVE:
|
||||||
@ -583,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||||||
# send the rejection to the inviter's HS (with fallback to
|
# send the rejection to the inviter's HS (with fallback to
|
||||||
# local event)
|
# local event)
|
||||||
return await self.remote_reject_invite(
|
return await self.remote_reject_invite(
|
||||||
invite.event_id, txn_id, requester, content,
|
invite.event_id,
|
||||||
|
txn_id,
|
||||||
|
requester,
|
||||||
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the inviter was on our server, but has now left. Carry on
|
# the inviter was on our server, but has now left. Carry on
|
||||||
@ -1056,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
user: UserID,
|
user: UserID,
|
||||||
content: dict,
|
content: dict,
|
||||||
) -> Tuple[str, int]:
|
) -> Tuple[str, int]:
|
||||||
"""Implements RoomMemberHandler._remote_join
|
"""Implements RoomMemberHandler._remote_join"""
|
||||||
"""
|
|
||||||
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
|
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
|
||||||
# and if it is the only entry we'd like to return a 404 rather than a
|
# and if it is the only entry we'd like to return a 404 rather than a
|
||||||
# 500.
|
# 500.
|
||||||
@ -1211,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
event.internal_metadata.out_of_band_membership = True
|
event.internal_metadata.out_of_band_membership = True
|
||||||
|
|
||||||
result_event = await self.event_creation_handler.handle_new_client_event(
|
result_event = await self.event_creation_handler.handle_new_client_event(
|
||||||
requester, event, context, extra_users=[UserID.from_string(target_user)],
|
requester,
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
extra_users=[UserID.from_string(target_user)],
|
||||||
)
|
)
|
||||||
# we know it was persisted, so must have a stream ordering
|
# we know it was persisted, so must have a stream ordering
|
||||||
assert result_event.internal_metadata.stream_ordering
|
assert result_event.internal_metadata.stream_ordering
|
||||||
@ -1219,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
return result_event.event_id, result_event.internal_metadata.stream_ordering
|
return result_event.event_id, result_event.internal_metadata.stream_ordering
|
||||||
|
|
||||||
async def _user_left_room(self, target: UserID, room_id: str) -> None:
|
async def _user_left_room(self, target: UserID, room_id: str) -> None:
|
||||||
"""Implements RoomMemberHandler._user_left_room
|
"""Implements RoomMemberHandler._user_left_room"""
|
||||||
"""
|
|
||||||
user_left_room(self.distributor, target, room_id)
|
user_left_room(self.distributor, target, room_id)
|
||||||
|
|
||||||
async def forget(self, user: UserID, room_id: str) -> None:
|
async def forget(self, user: UserID, room_id: str) -> None:
|
||||||
|
@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
|
|||||||
user: UserID,
|
user: UserID,
|
||||||
content: dict,
|
content: dict,
|
||||||
) -> Tuple[str, int]:
|
) -> Tuple[str, int]:
|
||||||
"""Implements RoomMemberHandler._remote_join
|
"""Implements RoomMemberHandler._remote_join"""
|
||||||
"""
|
|
||||||
if len(remote_room_hosts) == 0:
|
if len(remote_room_hosts) == 0:
|
||||||
raise SynapseError(404, "No known servers")
|
raise SynapseError(404, "No known servers")
|
||||||
|
|
||||||
@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
|
|||||||
return ret["event_id"], ret["stream_id"]
|
return ret["event_id"], ret["stream_id"]
|
||||||
|
|
||||||
async def _user_left_room(self, target: UserID, room_id: str) -> None:
|
async def _user_left_room(self, target: UserID, room_id: str) -> None:
|
||||||
"""Implements RoomMemberHandler._user_left_room
|
"""Implements RoomMemberHandler._user_left_room"""
|
||||||
"""
|
|
||||||
await self._notify_change_client(
|
await self._notify_change_client(
|
||||||
user_id=target.to_string(), room_id=room_id, change="left"
|
user_id=target.to_string(), room_id=room_id, change="left"
|
||||||
)
|
)
|
||||||
|
@ -121,7 +121,8 @@ class SamlHandler(BaseHandler):
|
|||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
self._outstanding_requests_dict[reqid] = Saml2SessionData(
|
||||||
creation_time=now, ui_auth_session_id=ui_auth_session_id,
|
creation_time=now,
|
||||||
|
ui_auth_session_id=ui_auth_session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in info["headers"]:
|
for key, value in info["headers"]:
|
||||||
@ -450,7 +451,8 @@ class DefaultSamlMappingProvider:
|
|||||||
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
|
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
|
"SAML2 response lacks a '%s' attestation",
|
||||||
|
self._mxid_source_attribute,
|
||||||
)
|
)
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
||||||
|
@ -327,7 +327,8 @@ class SsoHandler:
|
|||||||
|
|
||||||
# Check if we already have a mapping for this user.
|
# Check if we already have a mapping for this user.
|
||||||
previously_registered_user_id = await self._store.get_user_by_external_id(
|
previously_registered_user_id = await self._store.get_user_by_external_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A match was found, return the user ID.
|
# A match was found, return the user ID.
|
||||||
@ -416,7 +417,8 @@ class SsoHandler:
|
|||||||
with await self._mapping_lock.queue(auth_provider_id):
|
with await self._mapping_lock.queue(auth_provider_id):
|
||||||
# first of all, check if we already have a mapping for this user
|
# first of all, check if we already have a mapping for this user
|
||||||
user_id = await self.get_sso_user_by_remote_user_id(
|
user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for grandfathering of users.
|
# Check for grandfathering of users.
|
||||||
@ -461,7 +463,8 @@ class SsoHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _call_attribute_mapper(
|
async def _call_attribute_mapper(
|
||||||
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
self,
|
||||||
|
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
|
||||||
) -> UserAttributes:
|
) -> UserAttributes:
|
||||||
"""Call the attribute mapper function in a loop, until we get a unique userid"""
|
"""Call the attribute mapper function in a loop, until we get a unique userid"""
|
||||||
for i in range(self._MAP_USERNAME_RETRIES):
|
for i in range(self._MAP_USERNAME_RETRIES):
|
||||||
@ -632,7 +635,8 @@ class SsoHandler:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
user_id = await self.get_sso_user_by_remote_user_id(
|
user_id = await self.get_sso_user_by_remote_user_id(
|
||||||
auth_provider_id, remote_user_id,
|
auth_provider_id,
|
||||||
|
remote_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id_to_verify = await self._auth_handler.get_session_data(
|
user_id_to_verify = await self._auth_handler.get_session_data(
|
||||||
@ -671,7 +675,8 @@ class SsoHandler:
|
|||||||
|
|
||||||
# render an error page.
|
# render an error page.
|
||||||
html = self._bad_user_template.render(
|
html = self._bad_user_template.render(
|
||||||
server_name=self._server_name, user_id_to_verify=user_id_to_verify,
|
server_name=self._server_name,
|
||||||
|
user_id_to_verify=user_id_to_verify,
|
||||||
)
|
)
|
||||||
respond_with_html(request, 200, html)
|
respond_with_html(request, 200, html)
|
||||||
|
|
||||||
@ -695,7 +700,9 @@ class SsoHandler:
|
|||||||
raise SynapseError(400, "unknown session")
|
raise SynapseError(400, "unknown session")
|
||||||
|
|
||||||
async def check_username_availability(
|
async def check_username_availability(
|
||||||
self, localpart: str, session_id: str,
|
self,
|
||||||
|
localpart: str,
|
||||||
|
session_id: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Handle an "is username available" callback check
|
"""Handle an "is username available" callback check
|
||||||
|
|
||||||
@ -833,7 +840,8 @@ class SsoHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
attributes = UserAttributes(
|
attributes = UserAttributes(
|
||||||
localpart=session.chosen_localpart, emails=session.emails_to_use,
|
localpart=session.chosen_localpart,
|
||||||
|
emails=session.emails_to_use,
|
||||||
)
|
)
|
||||||
|
|
||||||
if session.use_display_name:
|
if session.use_display_name:
|
||||||
|
@ -63,8 +63,7 @@ class StatsHandler:
|
|||||||
self.clock.call_later(0, self.notify_new_event)
|
self.clock.call_later(0, self.notify_new_event)
|
||||||
|
|
||||||
def notify_new_event(self) -> None:
|
def notify_new_event(self) -> None:
|
||||||
"""Called when there may be more deltas to process
|
"""Called when there may be more deltas to process"""
|
||||||
"""
|
|
||||||
if not self.stats_enabled or self._is_processing:
|
if not self.stats_enabled or self._is_processing:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -339,8 +339,7 @@ class SyncHandler:
|
|||||||
since_token: Optional[StreamToken] = None,
|
since_token: Optional[StreamToken] = None,
|
||||||
full_state: bool = False,
|
full_state: bool = False,
|
||||||
) -> SyncResult:
|
) -> SyncResult:
|
||||||
"""Get the sync for client needed to match what the server has now.
|
"""Get the sync for client needed to match what the server has now."""
|
||||||
"""
|
|
||||||
return await self.generate_sync_result(sync_config, since_token, full_state)
|
return await self.generate_sync_result(sync_config, since_token, full_state)
|
||||||
|
|
||||||
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
||||||
@ -820,9 +819,11 @@ class SyncHandler:
|
|||||||
)
|
)
|
||||||
elif batch.limited:
|
elif batch.limited:
|
||||||
if batch:
|
if batch:
|
||||||
state_at_timeline_start = await self.state_store.get_state_ids_for_event(
|
state_at_timeline_start = (
|
||||||
|
await self.state_store.get_state_ids_for_event(
|
||||||
batch.events[0].event_id, state_filter=state_filter
|
batch.events[0].event_id, state_filter=state_filter
|
||||||
)
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# We can get here if the user has ignored the senders of all
|
# We can get here if the user has ignored the senders of all
|
||||||
# the recent events.
|
# the recent events.
|
||||||
@ -955,8 +956,7 @@ class SyncHandler:
|
|||||||
since_token: Optional[StreamToken] = None,
|
since_token: Optional[StreamToken] = None,
|
||||||
full_state: bool = False,
|
full_state: bool = False,
|
||||||
) -> SyncResult:
|
) -> SyncResult:
|
||||||
"""Generates a sync result.
|
"""Generates a sync result."""
|
||||||
"""
|
|
||||||
# NB: The now_token gets changed by some of the generate_sync_* methods,
|
# NB: The now_token gets changed by some of the generate_sync_* methods,
|
||||||
# this is due to some of the underlying streams not supporting the ability
|
# this is due to some of the underlying streams not supporting the ability
|
||||||
# to query up to a given point.
|
# to query up to a given point.
|
||||||
@ -1030,8 +1030,8 @@ class SyncHandler:
|
|||||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||||
user_id, device_id
|
user_id, device_id
|
||||||
)
|
)
|
||||||
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
|
unused_fallback_key_types = (
|
||||||
user_id, device_id
|
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Fetching group data")
|
logger.debug("Fetching group data")
|
||||||
@ -1176,9 +1176,11 @@ class SyncHandler:
|
|||||||
# weren't in the previous sync *or* they left and rejoined.
|
# weren't in the previous sync *or* they left and rejoined.
|
||||||
users_that_have_changed.update(newly_joined_or_invited_users)
|
users_that_have_changed.update(newly_joined_or_invited_users)
|
||||||
|
|
||||||
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
|
user_signatures_changed = (
|
||||||
|
await self.store.get_users_whose_signatures_changed(
|
||||||
user_id, since_token.device_list_key
|
user_id, since_token.device_list_key
|
||||||
)
|
)
|
||||||
|
)
|
||||||
users_that_have_changed.update(user_signatures_changed)
|
users_that_have_changed.update(user_signatures_changed)
|
||||||
|
|
||||||
# Now find users that we no longer track
|
# Now find users that we no longer track
|
||||||
@ -1393,9 +1395,11 @@ class SyncHandler:
|
|||||||
logger.debug("no-oping sync")
|
logger.debug("no-oping sync")
|
||||||
return set(), set(), set(), set()
|
return set(), set(), set(), set()
|
||||||
|
|
||||||
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
|
ignored_account_data = (
|
||||||
|
await self.store.get_global_account_data_by_type_for_user(
|
||||||
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
|
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# If there is ignored users account data and it matches the proper type,
|
# If there is ignored users account data and it matches the proper type,
|
||||||
# then use it.
|
# then use it.
|
||||||
@ -1499,8 +1503,7 @@ class SyncHandler:
|
|||||||
async def _get_rooms_changed(
|
async def _get_rooms_changed(
|
||||||
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
|
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
|
||||||
) -> _RoomChanges:
|
) -> _RoomChanges:
|
||||||
"""Gets the the changes that have happened since the last sync.
|
"""Gets the the changes that have happened since the last sync."""
|
||||||
"""
|
|
||||||
user_id = sync_result_builder.sync_config.user.to_string()
|
user_id = sync_result_builder.sync_config.user.to_string()
|
||||||
since_token = sync_result_builder.since_token
|
since_token = sync_result_builder.since_token
|
||||||
now_token = sync_result_builder.now_token
|
now_token = sync_result_builder.now_token
|
||||||
|
@ -61,7 +61,8 @@ class FollowerTypingHandler:
|
|||||||
|
|
||||||
if hs.config.worker.writers.typing != hs.get_instance_name():
|
if hs.config.worker.writers.typing != hs.get_instance_name():
|
||||||
hs.get_federation_registry().register_instance_for_edu(
|
hs.get_federation_registry().register_instance_for_edu(
|
||||||
"m.typing", hs.config.worker.writers.typing,
|
"m.typing",
|
||||||
|
hs.config.worker.writers.typing,
|
||||||
)
|
)
|
||||||
|
|
||||||
# map room IDs to serial numbers
|
# map room IDs to serial numbers
|
||||||
@ -76,8 +77,7 @@ class FollowerTypingHandler:
|
|||||||
self.clock.looping_call(self._handle_timeouts, 5000)
|
self.clock.looping_call(self._handle_timeouts, 5000)
|
||||||
|
|
||||||
def _reset(self) -> None:
|
def _reset(self) -> None:
|
||||||
"""Reset the typing handler's data caches.
|
"""Reset the typing handler's data caches."""
|
||||||
"""
|
|
||||||
# map room IDs to serial numbers
|
# map room IDs to serial numbers
|
||||||
self._room_serials = {}
|
self._room_serials = {}
|
||||||
# map room IDs to sets of users currently typing
|
# map room IDs to sets of users currently typing
|
||||||
@ -149,8 +149,7 @@ class FollowerTypingHandler:
|
|||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Should be called whenever we receive updates for typing stream.
|
"""Should be called whenever we receive updates for typing stream."""
|
||||||
"""
|
|
||||||
|
|
||||||
if self._latest_room_serial > token:
|
if self._latest_room_serial > token:
|
||||||
# The master has gone backwards. To prevent inconsistent data, just
|
# The master has gone backwards. To prevent inconsistent data, just
|
||||||
|
@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def notify_new_event(self) -> None:
|
def notify_new_event(self) -> None:
|
||||||
"""Called when there may be more deltas to process
|
"""Called when there may be more deltas to process"""
|
||||||
"""
|
|
||||||
if not self.update_user_directory:
|
if not self.update_user_directory:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def handle_user_deactivated(self, user_id: str) -> None:
|
async def handle_user_deactivated(self, user_id: str) -> None:
|
||||||
"""Called when a user ID is deactivated
|
"""Called when a user ID is deactivated"""
|
||||||
"""
|
|
||||||
# FIXME(#3714): We should probably do this in the same worker as all
|
# FIXME(#3714): We should probably do this in the same worker as all
|
||||||
# the other changes.
|
# the other changes.
|
||||||
await self.store.remove_from_user_dir(user_id)
|
await self.store.remove_from_user_dir(user_id)
|
||||||
@ -172,8 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||||||
await self.store.update_user_directory_stream_pos(max_pos)
|
await self.store.update_user_directory_stream_pos(max_pos)
|
||||||
|
|
||||||
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
|
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
|
||||||
"""Called with the state deltas to process
|
"""Called with the state deltas to process"""
|
||||||
"""
|
|
||||||
for delta in deltas:
|
for delta in deltas:
|
||||||
typ = delta["type"]
|
typ = delta["type"]
|
||||||
state_key = delta["state_key"]
|
state_key = delta["state_key"]
|
||||||
|
@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
|
|||||||
|
|
||||||
|
|
||||||
def get_request_user_agent(request: IRequest, default: str = "") -> str:
|
def get_request_user_agent(request: IRequest, default: str = "") -> str:
|
||||||
"""Return the last User-Agent header, or the given default.
|
"""Return the last User-Agent header, or the given default."""
|
||||||
"""
|
|
||||||
# There could be raw utf-8 bytes in the User-Agent header.
|
# There could be raw utf-8 bytes in the User-Agent header.
|
||||||
|
|
||||||
# N.B. if you don't do this, the logger explodes cryptically
|
# N.B. if you don't do this, the logger explodes cryptically
|
||||||
|
@ -398,7 +398,8 @@ class SimpleHttpClient:
|
|||||||
body_producer = None
|
body_producer = None
|
||||||
if data is not None:
|
if data is not None:
|
||||||
body_producer = QuieterFileBodyProducer(
|
body_producer = QuieterFileBodyProducer(
|
||||||
BytesIO(data), cooperator=self._cooperator,
|
BytesIO(data),
|
||||||
|
cooperator=self._cooperator,
|
||||||
)
|
)
|
||||||
|
|
||||||
request_deferred = treq.request(
|
request_deferred = treq.request(
|
||||||
@ -413,7 +414,9 @@ class SimpleHttpClient:
|
|||||||
# we use our own timeout mechanism rather than treq's as a workaround
|
# we use our own timeout mechanism rather than treq's as a workaround
|
||||||
# for https://twistedmatrix.com/trac/ticket/9534.
|
# for https://twistedmatrix.com/trac/ticket/9534.
|
||||||
request_deferred = timeout_deferred(
|
request_deferred = timeout_deferred(
|
||||||
request_deferred, 60, self.hs.get_reactor(),
|
request_deferred,
|
||||||
|
60,
|
||||||
|
self.hs.get_reactor(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# turn timeouts into RequestTimedOutErrors
|
# turn timeouts into RequestTimedOutErrors
|
||||||
|
@ -195,8 +195,7 @@ class MatrixFederationAgent:
|
|||||||
|
|
||||||
@implementer(IAgentEndpointFactory)
|
@implementer(IAgentEndpointFactory)
|
||||||
class MatrixHostnameEndpointFactory:
|
class MatrixHostnameEndpointFactory:
|
||||||
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
|
"""Factory for MatrixHostnameEndpoint for parsing to an Agent."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
|
|||||||
self._srv_resolver = srv_resolver
|
self._srv_resolver = srv_resolver
|
||||||
|
|
||||||
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
|
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
|
||||||
"""Implements IStreamClientEndpoint interface
|
"""Implements IStreamClientEndpoint interface"""
|
||||||
"""
|
|
||||||
|
|
||||||
return run_in_background(self._do_connect, protocol_factory)
|
return run_in_background(self._do_connect, protocol_factory)
|
||||||
|
|
||||||
|
@ -81,8 +81,7 @@ class WellKnownLookupResult:
|
|||||||
|
|
||||||
|
|
||||||
class WellKnownResolver:
|
class WellKnownResolver:
|
||||||
"""Handles well-known lookups for matrix servers.
|
"""Handles well-known lookups for matrix servers."""
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
|
|||||||
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
|
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
|
||||||
# blacklist via IP literals in server names
|
# blacklist via IP literals in server names
|
||||||
self.agent = BlacklistingAgentWrapper(
|
self.agent = BlacklistingAgentWrapper(
|
||||||
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
|
self.agent,
|
||||||
|
ip_blacklist=hs.config.federation_ip_range_blacklist,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
|
|||||||
_sec_timeout = self.default_timeout
|
_sec_timeout = self.default_timeout
|
||||||
|
|
||||||
body = await _handle_json_response(
|
body = await _handle_json_response(
|
||||||
self.reactor, _sec_timeout, request, response, start_ms,
|
self.reactor,
|
||||||
|
_sec_timeout,
|
||||||
|
request,
|
||||||
|
response,
|
||||||
|
start_ms,
|
||||||
)
|
)
|
||||||
return body
|
return body
|
||||||
|
|
||||||
@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
|
|||||||
except BodyExceededMaxSize:
|
except BodyExceededMaxSize:
|
||||||
msg = "Requested file is too large > %r bytes" % (max_size,)
|
msg = "Requested file is too large > %r bytes" % (max_size,)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"{%s} [%s] %s", request.txn_id, request.destination, msg,
|
"{%s} [%s] %s",
|
||||||
|
request.txn_id,
|
||||||
|
request.destination,
|
||||||
|
msg,
|
||||||
)
|
)
|
||||||
raise SynapseError(502, msg, Codes.TOO_LARGE)
|
raise SynapseError(502, msg, Codes.TOO_LARGE)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -213,8 +213,7 @@ class RequestMetrics:
|
|||||||
self.update_metrics()
|
self.update_metrics()
|
||||||
|
|
||||||
def update_metrics(self):
|
def update_metrics(self):
|
||||||
"""Updates the in flight metrics with values from this request.
|
"""Updates the in flight metrics with values from this request."""
|
||||||
"""
|
|
||||||
new_stats = self.start_context.get_resource_usage()
|
new_stats = self.start_context.get_resource_usage()
|
||||||
|
|
||||||
diff = new_stats - self._request_stats
|
diff = new_stats - self._request_stats
|
||||||
|
@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
|||||||
|
|
||||||
|
|
||||||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||||
"""Sends a JSON error response to clients.
|
"""Sends a JSON error response to clients."""
|
||||||
"""
|
|
||||||
|
|
||||||
if f.check(SynapseError):
|
if f.check(SynapseError):
|
||||||
error_code = f.value.code
|
error_code = f.value.code
|
||||||
@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request, error_code, error_dict, send_cors=True,
|
request,
|
||||||
|
error_code,
|
||||||
|
error_dict,
|
||||||
|
send_cors=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def return_html_error(
|
def return_html_error(
|
||||||
f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
|
f: failure.Failure,
|
||||||
|
request: Request,
|
||||||
|
error_template: Union[str, jinja2.Template],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Sends an HTML error page corresponding to the given failure.
|
"""Sends an HTML error page corresponding to the given failure.
|
||||||
|
|
||||||
@ -189,8 +193,7 @@ ServletCallback = Callable[
|
|||||||
|
|
||||||
|
|
||||||
class HttpServer(Protocol):
|
class HttpServer(Protocol):
|
||||||
""" Interface for registering callbacks on a HTTP server
|
"""Interface for registering callbacks on a HTTP server"""
|
||||||
"""
|
|
||||||
|
|
||||||
def register_paths(
|
def register_paths(
|
||||||
self,
|
self,
|
||||||
@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
|||||||
self._extract_context = extract_context
|
self._extract_context = extract_context
|
||||||
|
|
||||||
def render(self, request):
|
def render(self, request):
|
||||||
""" This gets called by twisted every time someone sends us a request.
|
"""This gets called by twisted every time someone sends us a request."""
|
||||||
"""
|
|
||||||
defer.ensureDeferred(self._async_render_wrapper(request))
|
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _send_response(
|
def _send_response(
|
||||||
self, request: SynapseRequest, code: int, response_object: Any,
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
code: int,
|
||||||
|
response_object: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _send_error_response(
|
def _send_error_response(
|
||||||
self, f: failure.Failure, request: SynapseRequest,
|
self,
|
||||||
|
f: failure.Failure,
|
||||||
|
request: SynapseRequest,
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
|
|||||||
self.canonical_json = canonical_json
|
self.canonical_json = canonical_json
|
||||||
|
|
||||||
def _send_response(
|
def _send_response(
|
||||||
self, request: Request, code: int, response_object: Any,
|
self,
|
||||||
|
request: Request,
|
||||||
|
code: int,
|
||||||
|
response_object: Any,
|
||||||
):
|
):
|
||||||
"""Implements _AsyncResource._send_response
|
"""Implements _AsyncResource._send_response"""
|
||||||
"""
|
|
||||||
# TODO: Only enable CORS for the requests that need it.
|
# TODO: Only enable CORS for the requests that need it.
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request,
|
request,
|
||||||
@ -322,10 +331,11 @@ class DirectServeJsonResource(_AsyncResource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _send_error_response(
|
def _send_error_response(
|
||||||
self, f: failure.Failure, request: SynapseRequest,
|
self,
|
||||||
|
f: failure.Failure,
|
||||||
|
request: SynapseRequest,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Implements _AsyncResource._send_error_response
|
"""Implements _AsyncResource._send_error_response"""
|
||||||
"""
|
|
||||||
return_json_error(f, request)
|
return_json_error(f, request)
|
||||||
|
|
||||||
|
|
||||||
@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
|
|||||||
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
|
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
|
||||||
|
|
||||||
def _send_response(
|
def _send_response(
|
||||||
self, request: SynapseRequest, code: int, response_object: Any,
|
self,
|
||||||
|
request: SynapseRequest,
|
||||||
|
code: int,
|
||||||
|
response_object: Any,
|
||||||
):
|
):
|
||||||
"""Implements _AsyncResource._send_response
|
"""Implements _AsyncResource._send_response"""
|
||||||
"""
|
|
||||||
# We expect to get bytes for us to write
|
# We expect to get bytes for us to write
|
||||||
assert isinstance(response_object, bytes)
|
assert isinstance(response_object, bytes)
|
||||||
html_bytes = response_object
|
html_bytes = response_object
|
||||||
@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
|
|||||||
respond_with_html_bytes(request, 200, html_bytes)
|
respond_with_html_bytes(request, 200, html_bytes)
|
||||||
|
|
||||||
def _send_error_response(
|
def _send_error_response(
|
||||||
self, f: failure.Failure, request: SynapseRequest,
|
self,
|
||||||
|
f: failure.Failure,
|
||||||
|
request: SynapseRequest,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Implements _AsyncResource._send_error_response
|
"""Implements _AsyncResource._send_error_response"""
|
||||||
"""
|
|
||||||
return_html_error(f, request, self.ERROR_TEMPLATE)
|
return_html_error(f, request, self.ERROR_TEMPLATE)
|
||||||
|
|
||||||
|
|
||||||
@ -534,7 +547,9 @@ class _ByteProducer:
|
|||||||
min_chunk_size = 1024
|
min_chunk_size = 1024
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, request: Request, iterator: Iterator[bytes],
|
self,
|
||||||
|
request: Request,
|
||||||
|
iterator: Iterator[bytes],
|
||||||
):
|
):
|
||||||
self._request = request
|
self._request = request
|
||||||
self._iterator = iterator
|
self._iterator = iterator
|
||||||
@ -654,7 +669,10 @@ def respond_with_json(
|
|||||||
|
|
||||||
|
|
||||||
def respond_with_json_bytes(
|
def respond_with_json_bytes(
|
||||||
request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
|
request: Request,
|
||||||
|
code: int,
|
||||||
|
json_bytes: bytes,
|
||||||
|
send_cors: bool = False,
|
||||||
):
|
):
|
||||||
"""Sends encoded JSON in response to the given request.
|
"""Sends encoded JSON in response to the given request.
|
||||||
|
|
||||||
|
@ -249,8 +249,7 @@ class SynapseRequest(Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _finished_processing(self):
|
def _finished_processing(self):
|
||||||
"""Log the completion of this request and update the metrics
|
"""Log the completion of this request and update the metrics"""
|
||||||
"""
|
|
||||||
assert self.logcontext is not None
|
assert self.logcontext is not None
|
||||||
usage = self.logcontext.get_resource_usage()
|
usage = self.logcontext.get_resource_usage()
|
||||||
|
|
||||||
@ -276,7 +275,8 @@ class SynapseRequest(Request):
|
|||||||
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
# authenticated (e.g. and admin is puppetting a user) then we log both.
|
||||||
if self.requester.user.to_string() != authenticated_entity:
|
if self.requester.user.to_string() != authenticated_entity:
|
||||||
authenticated_entity = "{},{}".format(
|
authenticated_entity = "{},{}".format(
|
||||||
authenticated_entity, self.requester.user.to_string(),
|
authenticated_entity,
|
||||||
|
self.requester.user.to_string(),
|
||||||
)
|
)
|
||||||
elif self.requester is not None:
|
elif self.requester is not None:
|
||||||
# This shouldn't happen, but we log it so we don't lose information
|
# This shouldn't happen, but we log it so we don't lose information
|
||||||
@ -322,8 +322,7 @@ class SynapseRequest(Request):
|
|||||||
logger.warning("Failed to stop metrics: %r", e)
|
logger.warning("Failed to stop metrics: %r", e)
|
||||||
|
|
||||||
def _should_log_request(self) -> bool:
|
def _should_log_request(self) -> bool:
|
||||||
"""Whether we should log at INFO that we processed the request.
|
"""Whether we should log at INFO that we processed the request."""
|
||||||
"""
|
|
||||||
if self.path == b"/health":
|
if self.path == b"/health":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler):
|
|||||||
|
|
||||||
# Make a new producer and start it.
|
# Make a new producer and start it.
|
||||||
self._producer = LogProducer(
|
self._producer = LogProducer(
|
||||||
buffer=self._buffer, transport=result.transport, format=self.format,
|
buffer=self._buffer,
|
||||||
|
transport=result.transport,
|
||||||
|
format=self.format,
|
||||||
)
|
)
|
||||||
result.transport.registerProducer(self._producer, True)
|
result.transport.registerProducer(self._producer, True)
|
||||||
self._producer.resumeProducing()
|
self._producer.resumeProducing()
|
||||||
|
@ -60,7 +60,10 @@ def parse_drain_configs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Either use the default formatter or the tersejson one.
|
# Either use the default formatter or the tersejson one.
|
||||||
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
|
if logging_type in (
|
||||||
|
DrainType.CONSOLE_JSON,
|
||||||
|
DrainType.FILE_JSON,
|
||||||
|
):
|
||||||
formatter = "json" # type: Optional[str]
|
formatter = "json" # type: Optional[str]
|
||||||
elif logging_type in (
|
elif logging_type in (
|
||||||
DrainType.CONSOLE_JSON_TERSE,
|
DrainType.CONSOLE_JSON_TERSE,
|
||||||
@ -131,7 +134,9 @@ def parse_drain_configs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup_structured_logging(log_config: dict,) -> dict:
|
def setup_structured_logging(
|
||||||
|
log_config: dict,
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
|
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
|
||||||
to one compatible with the new standard library handlers.
|
to one compatible with the new standard library handlers.
|
||||||
|
@ -338,7 +338,10 @@ class LoggingContext:
|
|||||||
if self.previous_context != old_context:
|
if self.previous_context != old_context:
|
||||||
logcontext_error(
|
logcontext_error(
|
||||||
"Expected previous context %r, found %r"
|
"Expected previous context %r, found %r"
|
||||||
% (self.previous_context, old_context,)
|
% (
|
||||||
|
self.previous_context,
|
||||||
|
old_context,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -585,7 +588,10 @@ class PreserveLoggingContext:
|
|||||||
else:
|
else:
|
||||||
logcontext_error(
|
logcontext_error(
|
||||||
"Expected logging context %s but found %s"
|
"Expected logging context %s but found %s"
|
||||||
% (self._new_context, context,)
|
% (
|
||||||
|
self._new_context,
|
||||||
|
context,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -238,8 +238,7 @@ try:
|
|||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True)
|
||||||
class _WrappedRustReporter:
|
class _WrappedRustReporter:
|
||||||
"""Wrap the reporter to ensure `report_span` never throws.
|
"""Wrap the reporter to ensure `report_span` never throws."""
|
||||||
"""
|
|
||||||
|
|
||||||
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
|
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
|
||||||
|
|
||||||
@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def init_tracer(hs: "HomeServer"):
|
def init_tracer(hs: "HomeServer"):
|
||||||
"""Set the whitelists and initialise the JaegerClient tracer
|
"""Set the whitelists and initialise the JaegerClient tracer"""
|
||||||
"""
|
|
||||||
global opentracing
|
global opentracing
|
||||||
if not hs.config.opentracer_enabled:
|
if not hs.config.opentracer_enabled:
|
||||||
# We don't have a tracer
|
# We don't have a tracer
|
||||||
|
@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
|
|||||||
|
|
||||||
|
|
||||||
def log_function(f):
|
def log_function(f):
|
||||||
""" Function decorator that logs every call to that function.
|
"""Function decorator that logs every call to that function."""
|
||||||
"""
|
|
||||||
func_name = f.__name__
|
func_name = f.__name__
|
||||||
|
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
|
@ -155,8 +155,7 @@ class InFlightGauge:
|
|||||||
self._registrations.setdefault(key, set()).add(callback)
|
self._registrations.setdefault(key, set()).add(callback)
|
||||||
|
|
||||||
def unregister(self, key, callback):
|
def unregister(self, key, callback):
|
||||||
"""Registers that we've exited a block with labels `key`.
|
"""Registers that we've exited a block with labels `key`."""
|
||||||
"""
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._registrations.setdefault(key, set()).discard(callback)
|
self._registrations.setdefault(key, set()).discard(callback)
|
||||||
@ -402,7 +401,9 @@ class PyPyGCStats:
|
|||||||
# Total time spent in GC: 0.073 # s.total_gc_time
|
# Total time spent in GC: 0.073 # s.total_gc_time
|
||||||
|
|
||||||
pypy_gc_time = CounterMetricFamily(
|
pypy_gc_time = CounterMetricFamily(
|
||||||
"pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
|
"pypy_gc_time_seconds_total",
|
||||||
|
"Total time spent in PyPy GC",
|
||||||
|
labels=[],
|
||||||
)
|
)
|
||||||
pypy_gc_time.add_metric([], s.total_gc_time / 1000)
|
pypy_gc_time.add_metric([], s.total_gc_time / 1000)
|
||||||
yield pypy_gc_time
|
yield pypy_gc_time
|
||||||
|
@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
|
|||||||
return await maybe_awaitable(func(*args, **kwargs))
|
return await maybe_awaitable(func(*args, **kwargs))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Background process '%s' threw an exception", desc,
|
"Background process '%s' threw an exception",
|
||||||
|
desc,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
_background_process_in_flight_count.labels(desc).dec()
|
_background_process_in_flight_count.labels(desc).dec()
|
||||||
@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
|||||||
self._proc = _BackgroundProcess(name, self)
|
self._proc = _BackgroundProcess(name, self)
|
||||||
|
|
||||||
def start(self, rusage: "Optional[resource._RUsage]"):
|
def start(self, rusage: "Optional[resource._RUsage]"):
|
||||||
"""Log context has started running (again).
|
"""Log context has started running (again)."""
|
||||||
"""
|
|
||||||
|
|
||||||
super().start(rusage)
|
super().start(rusage)
|
||||||
|
|
||||||
@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
|
|||||||
_background_processes_active_since_last_scrape.add(self._proc)
|
_background_processes_active_since_last_scrape.add(self._proc)
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback) -> None:
|
def __exit__(self, type, value, traceback) -> None:
|
||||||
"""Log context has finished.
|
"""Log context has finished."""
|
||||||
"""
|
|
||||||
|
|
||||||
super().__exit__(type, value, traceback)
|
super().__exit__(type, value, traceback)
|
||||||
|
|
||||||
|
@ -275,7 +275,9 @@ class ModuleApi:
|
|||||||
redirect them directly if whitelisted).
|
redirect them directly if whitelisted).
|
||||||
"""
|
"""
|
||||||
self._auth_handler._complete_sso_login(
|
self._auth_handler._complete_sso_login(
|
||||||
registered_user_id, request, client_redirect_url,
|
registered_user_id,
|
||||||
|
request,
|
||||||
|
client_redirect_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def complete_sso_login_async(
|
async def complete_sso_login_async(
|
||||||
@ -352,7 +354,10 @@ class ModuleApi:
|
|||||||
event,
|
event,
|
||||||
_,
|
_,
|
||||||
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
|
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
|
||||||
requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
|
requester,
|
||||||
|
event_dict,
|
||||||
|
ratelimit=False,
|
||||||
|
ignore_shadow_ban=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
@ -119,7 +119,10 @@ class _NotifierUserStream:
|
|||||||
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
self.notify_deferred = ObservableDeferred(defer.Deferred())
|
||||||
|
|
||||||
def notify(
|
def notify(
|
||||||
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
|
self,
|
||||||
|
stream_key: str,
|
||||||
|
stream_id: Union[int, RoomStreamToken],
|
||||||
|
time_now_ms: int,
|
||||||
):
|
):
|
||||||
"""Notify any listeners for this user of a new event from an
|
"""Notify any listeners for this user of a new event from an
|
||||||
event source.
|
event source.
|
||||||
@ -265,8 +268,7 @@ class Notifier:
|
|||||||
max_room_stream_token: RoomStreamToken,
|
max_room_stream_token: RoomStreamToken,
|
||||||
extra_users: Collection[UserID] = [],
|
extra_users: Collection[UserID] = [],
|
||||||
):
|
):
|
||||||
"""Unwraps event and calls `on_new_room_event_args`.
|
"""Unwraps event and calls `on_new_room_event_args`."""
|
||||||
"""
|
|
||||||
self.on_new_room_event_args(
|
self.on_new_room_event_args(
|
||||||
event_pos=event_pos,
|
event_pos=event_pos,
|
||||||
room_id=event.room_id,
|
room_id=event.room_id,
|
||||||
@ -341,7 +343,10 @@ class Notifier:
|
|||||||
|
|
||||||
if users or rooms:
|
if users or rooms:
|
||||||
self.on_new_event(
|
self.on_new_event(
|
||||||
"room_key", max_room_stream_token, users=users, rooms=rooms,
|
"room_key",
|
||||||
|
max_room_stream_token,
|
||||||
|
users=users,
|
||||||
|
rooms=rooms,
|
||||||
)
|
)
|
||||||
self._on_updated_room_token(max_room_stream_token)
|
self._on_updated_room_token(max_room_stream_token)
|
||||||
|
|
||||||
@ -418,7 +423,9 @@ class Notifier:
|
|||||||
|
|
||||||
# Notify appservices
|
# Notify appservices
|
||||||
self._notify_app_services_ephemeral(
|
self._notify_app_services_ephemeral(
|
||||||
stream_key, new_token, users,
|
stream_key,
|
||||||
|
new_token,
|
||||||
|
users,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_new_replication_data(self) -> None:
|
def on_new_replication_data(self) -> None:
|
||||||
@ -651,8 +658,7 @@ class Notifier:
|
|||||||
cb()
|
cb()
|
||||||
|
|
||||||
def notify_remote_server_up(self, server: str):
|
def notify_remote_server_up(self, server: str):
|
||||||
"""Notify any replication that a remote server has come back up
|
"""Notify any replication that a remote server has come back up"""
|
||||||
"""
|
|
||||||
# We call federation_sender directly rather than registering as a
|
# We call federation_sender directly rather than registering as a
|
||||||
# callback as a) we already have a reference to it and b) it introduces
|
# callback as a) we already have a reference to it and b) it introduces
|
||||||
# circular dependencies.
|
# circular dependencies.
|
||||||
|
@ -144,8 +144,7 @@ class BulkPushRuleEvaluator:
|
|||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
||||||
"""Get the current RulesForRoom object for the given room id
|
"""Get the current RulesForRoom object for the given room id"""
|
||||||
"""
|
|
||||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||||
# before any lookup methods get called on it as otherwise there may be
|
# before any lookup methods get called on it as otherwise there may be
|
||||||
# a race if invalidate_all gets called (which assumes its in the cache)
|
# a race if invalidate_all gets called (which assumes its in the cache)
|
||||||
@ -252,7 +251,9 @@ class BulkPushRuleEvaluator:
|
|||||||
# notified for this event. (This will then get handled when we persist
|
# notified for this event. (This will then get handled when we persist
|
||||||
# the event)
|
# the event)
|
||||||
await self.store.add_push_actions_to_staging(
|
await self.store.add_push_actions_to_staging(
|
||||||
event.event_id, actions_by_user, count_as_unread,
|
event.event_id,
|
||||||
|
actions_by_user,
|
||||||
|
count_as_unread,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,8 +116,7 @@ class EmailPusher(Pusher):
|
|||||||
self._is_processing = True
|
self._is_processing = True
|
||||||
|
|
||||||
def _resume_processing(self) -> None:
|
def _resume_processing(self) -> None:
|
||||||
"""Used by tests to resume processing of events after pausing.
|
"""Used by tests to resume processing of events after pausing."""
|
||||||
"""
|
|
||||||
assert self._is_processing
|
assert self._is_processing
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
self._start_processing()
|
self._start_processing()
|
||||||
@ -157,9 +156,11 @@ class EmailPusher(Pusher):
|
|||||||
being run.
|
being run.
|
||||||
"""
|
"""
|
||||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
unprocessed = (
|
||||||
|
await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||||
self.user_id, start, self.max_stream_ordering
|
self.user_id, start, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
soonest_due_at = None # type: Optional[int]
|
soonest_due_at = None # type: Optional[int]
|
||||||
|
|
||||||
@ -222,13 +223,15 @@ class EmailPusher(Pusher):
|
|||||||
self, last_stream_ordering: int
|
self, last_stream_ordering: int
|
||||||
) -> None:
|
) -> None:
|
||||||
self.last_stream_ordering = last_stream_ordering
|
self.last_stream_ordering = last_stream_ordering
|
||||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
|
pusher_still_exists = (
|
||||||
|
await self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.email,
|
self.email,
|
||||||
self.user_id,
|
self.user_id,
|
||||||
last_stream_ordering,
|
last_stream_ordering,
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if not pusher_still_exists:
|
if not pusher_still_exists:
|
||||||
# The pusher has been deleted while we were processing, so
|
# The pusher has been deleted while we were processing, so
|
||||||
# lets just stop and return.
|
# lets just stop and return.
|
||||||
@ -298,7 +301,8 @@ class EmailPusher(Pusher):
|
|||||||
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
|
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
|
||||||
)
|
)
|
||||||
self.throttle_params[room_id] = ThrottleParams(
|
self.throttle_params[room_id] = ThrottleParams(
|
||||||
self.clock.time_msec(), new_throttle_ms,
|
self.clock.time_msec(),
|
||||||
|
new_throttle_ms,
|
||||||
)
|
)
|
||||||
assert self.pusher_id is not None
|
assert self.pusher_id is not None
|
||||||
await self.store.set_throttle_params(
|
await self.store.set_throttle_params(
|
||||||
|
@ -176,9 +176,11 @@ class HttpPusher(Pusher):
|
|||||||
Never call this directly: use _process which will only allow this to
|
Never call this directly: use _process which will only allow this to
|
||||||
run once per pusher.
|
run once per pusher.
|
||||||
"""
|
"""
|
||||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
|
unprocessed = (
|
||||||
|
await self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Processing %i unprocessed push actions for %s starting at "
|
"Processing %i unprocessed push actions for %s starting at "
|
||||||
@ -204,13 +206,15 @@ class HttpPusher(Pusher):
|
|||||||
http_push_processed_counter.inc()
|
http_push_processed_counter.inc()
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.last_stream_ordering = push_action["stream_ordering"]
|
self.last_stream_ordering = push_action["stream_ordering"]
|
||||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
|
pusher_still_exists = (
|
||||||
|
await self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_id,
|
self.user_id,
|
||||||
self.last_stream_ordering,
|
self.last_stream_ordering,
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
if not pusher_still_exists:
|
if not pusher_still_exists:
|
||||||
# The pusher has been deleted while we were processing, so
|
# The pusher has been deleted while we were processing, so
|
||||||
# lets just stop and return.
|
# lets just stop and return.
|
||||||
@ -290,7 +294,8 @@ class HttpPusher(Pusher):
|
|||||||
# for sanity, we only remove the pushkey if it
|
# for sanity, we only remove the pushkey if it
|
||||||
# was the one we actually sent...
|
# was the one we actually sent...
|
||||||
logger.warning(
|
logger.warning(
|
||||||
("Ignoring rejected pushkey %s because we didn't send it"), pk,
|
("Ignoring rejected pushkey %s because we didn't send it"),
|
||||||
|
pk,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Pushkey %s was rejected: removing", pk)
|
logger.info("Pushkey %s was rejected: removing", pk)
|
||||||
|
@ -78,8 +78,7 @@ class PusherPool:
|
|||||||
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
|
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Starts the pushers off in a background process.
|
"""Starts the pushers off in a background process."""
|
||||||
"""
|
|
||||||
if not self._should_start_pushers:
|
if not self._should_start_pushers:
|
||||||
logger.info("Not starting pushers because they are disabled in the config")
|
logger.info("Not starting pushers because they are disabled in the config")
|
||||||
return
|
return
|
||||||
@ -297,8 +296,7 @@ class PusherPool:
|
|||||||
return pusher
|
return pusher
|
||||||
|
|
||||||
async def _start_pushers(self) -> None:
|
async def _start_pushers(self) -> None:
|
||||||
"""Start all the pushers
|
"""Start all the pushers"""
|
||||||
"""
|
|
||||||
pushers = await self.store.get_all_pushers()
|
pushers = await self.store.get_all_pushers()
|
||||||
|
|
||||||
# Stagger starting up the pushers so we don't completely drown the
|
# Stagger starting up the pushers so we don't completely drown the
|
||||||
@ -335,7 +333,8 @@ class PusherPool:
|
|||||||
return None
|
return None
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Couldn't start pusher id %i: caught Exception", pusher_config.id,
|
"Couldn't start pusher id %i: caught Exception",
|
||||||
|
pusher_config.id,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|||||||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
||||||
|
|
||||||
http_server.register_paths(
|
http_server.register_paths(
|
||||||
method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
|
method,
|
||||||
|
[pattern],
|
||||||
|
self._check_auth_and_handle,
|
||||||
|
self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_auth_and_handle(self, request, **kwargs):
|
def _check_auth_and_handle(self, request, **kwargs):
|
||||||
|
@ -175,7 +175,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _handle_request(self, request, user_id, room_id, tag):
|
async def _handle_request(self, request, user_id, room_id, tag):
|
||||||
max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
|
max_stream_id = await self.handler.remove_tag_from_room(
|
||||||
|
user_id,
|
||||||
|
room_id,
|
||||||
|
tag,
|
||||||
|
)
|
||||||
|
|
||||||
return 200, {"max_stream_id": max_stream_id}
|
return 200, {"max_stream_id": max_stream_id}
|
||||||
|
|
||||||
|
@ -160,7 +160,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
# hopefully we're now on the master, so this won't recurse!
|
# hopefully we're now on the master, so this won't recurse!
|
||||||
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
||||||
invite_event_id, txn_id, requester, event_content,
|
invite_event_id,
|
||||||
|
txn_id,
|
||||||
|
requester,
|
||||||
|
event_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {"event_id": event_id, "stream_id": stream_id}
|
return 200, {"event_id": event_id, "stream_id": stream_id}
|
||||||
|
@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ReplicationRegisterServlet(ReplicationEndpoint):
|
class ReplicationRegisterServlet(ReplicationEndpoint):
|
||||||
"""Register a new user
|
"""Register a new user"""
|
||||||
"""
|
|
||||||
|
|
||||||
NAME = "register_user"
|
NAME = "register_user"
|
||||||
PATH_ARGS = ("user_id",)
|
PATH_ARGS = ("user_id",)
|
||||||
@ -97,8 +96,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
|
|
||||||
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
|
||||||
"""Run any post registration actions
|
"""Run any post registration actions"""
|
||||||
"""
|
|
||||||
|
|
||||||
NAME = "post_register"
|
NAME = "post_register"
|
||||||
PATH_ARGS = ("user_id",)
|
PATH_ARGS = ("user_id",)
|
||||||
|
@ -196,8 +196,7 @@ class ErrorCommand(_SimpleCommand):
|
|||||||
|
|
||||||
|
|
||||||
class PingCommand(_SimpleCommand):
|
class PingCommand(_SimpleCommand):
|
||||||
"""Sent by either side as a keep alive. The data is arbitrary (often timestamp)
|
"""Sent by either side as a keep alive. The data is arbitrary (often timestamp)"""
|
||||||
"""
|
|
||||||
|
|
||||||
NAME = "PING"
|
NAME = "PING"
|
||||||
|
|
||||||
|
@ -60,8 +60,7 @@ class ExternalCache:
|
|||||||
return self._redis_connection is not None
|
return self._redis_connection is not None
|
||||||
|
|
||||||
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
|
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
|
||||||
"""Add the key/value to the named cache, with the expiry time given.
|
"""Add the key/value to the named cache, with the expiry time given."""
|
||||||
"""
|
|
||||||
|
|
||||||
if self._redis_connection is None:
|
if self._redis_connection is None:
|
||||||
return
|
return
|
||||||
@ -76,13 +75,14 @@ class ExternalCache:
|
|||||||
|
|
||||||
return await make_deferred_yieldable(
|
return await make_deferred_yieldable(
|
||||||
self._redis_connection.set(
|
self._redis_connection.set(
|
||||||
self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
|
self._get_redis_key(cache_name, key),
|
||||||
|
encoded_value,
|
||||||
|
pexpire=expiry_ms,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get(self, cache_name: str, key: str) -> Optional[Any]:
|
async def get(self, cache_name: str, key: str) -> Optional[Any]:
|
||||||
"""Look up a key/value in the named cache.
|
"""Look up a key/value in the named cache."""
|
||||||
"""
|
|
||||||
|
|
||||||
if self._redis_connection is None:
|
if self._redis_connection is None:
|
||||||
return None
|
return None
|
||||||
|
@ -303,7 +303,9 @@ class ReplicationCommandHandler:
|
|||||||
hs, outbound_redis_connection
|
hs, outbound_redis_connection
|
||||||
)
|
)
|
||||||
hs.get_reactor().connectTCP(
|
hs.get_reactor().connectTCP(
|
||||||
hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
|
hs.config.redis.redis_host,
|
||||||
|
hs.config.redis.redis_port,
|
||||||
|
self._factory,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
client_name = hs.get_instance_name()
|
client_name = hs.get_instance_name()
|
||||||
@ -313,13 +315,11 @@ class ReplicationCommandHandler:
|
|||||||
hs.get_reactor().connectTCP(host, port, self._factory)
|
hs.get_reactor().connectTCP(host, port, self._factory)
|
||||||
|
|
||||||
def get_streams(self) -> Dict[str, Stream]:
|
def get_streams(self) -> Dict[str, Stream]:
|
||||||
"""Get a map from stream name to all streams.
|
"""Get a map from stream name to all streams."""
|
||||||
"""
|
|
||||||
return self._streams
|
return self._streams
|
||||||
|
|
||||||
def get_streams_to_replicate(self) -> List[Stream]:
|
def get_streams_to_replicate(self) -> List[Stream]:
|
||||||
"""Get a list of streams that this instances replicates.
|
"""Get a list of streams that this instances replicates."""
|
||||||
"""
|
|
||||||
return self._streams_to_replicate
|
return self._streams_to_replicate
|
||||||
|
|
||||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||||
@ -340,7 +340,10 @@ class ReplicationCommandHandler:
|
|||||||
current_token = stream.current_token(self._instance_name)
|
current_token = stream.current_token(self._instance_name)
|
||||||
self.send_command(
|
self.send_command(
|
||||||
PositionCommand(
|
PositionCommand(
|
||||||
stream.NAME, self._instance_name, current_token, current_token,
|
stream.NAME,
|
||||||
|
self._instance_name,
|
||||||
|
current_token,
|
||||||
|
current_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -592,8 +595,7 @@ class ReplicationCommandHandler:
|
|||||||
self.send_command(cmd, ignore_conn=conn)
|
self.send_command(cmd, ignore_conn=conn)
|
||||||
|
|
||||||
def new_connection(self, connection: AbstractConnection):
|
def new_connection(self, connection: AbstractConnection):
|
||||||
"""Called when we have a new connection.
|
"""Called when we have a new connection."""
|
||||||
"""
|
|
||||||
self._connections.append(connection)
|
self._connections.append(connection)
|
||||||
|
|
||||||
# If we are connected to replication as a client (rather than a server)
|
# If we are connected to replication as a client (rather than a server)
|
||||||
@ -620,8 +622,7 @@ class ReplicationCommandHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def lost_connection(self, connection: AbstractConnection):
|
def lost_connection(self, connection: AbstractConnection):
|
||||||
"""Called when a connection is closed/lost.
|
"""Called when a connection is closed/lost."""
|
||||||
"""
|
|
||||||
# we no longer need _streams_by_connection for this connection.
|
# we no longer need _streams_by_connection for this connection.
|
||||||
streams = self._streams_by_connection.pop(connection, None)
|
streams = self._streams_by_connection.pop(connection, None)
|
||||||
if streams:
|
if streams:
|
||||||
@ -678,15 +679,13 @@ class ReplicationCommandHandler:
|
|||||||
def send_user_sync(
|
def send_user_sync(
|
||||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||||
):
|
):
|
||||||
"""Poke the master that a user has started/stopped syncing.
|
"""Poke the master that a user has started/stopped syncing."""
|
||||||
"""
|
|
||||||
self.send_command(
|
self.send_command(
|
||||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||||
"""Poke the master to remove a pusher for a user
|
"""Poke the master to remove a pusher for a user"""
|
||||||
"""
|
|
||||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
@ -699,8 +698,7 @@ class ReplicationCommandHandler:
|
|||||||
device_id: str,
|
device_id: str,
|
||||||
last_seen: int,
|
last_seen: int,
|
||||||
):
|
):
|
||||||
"""Tell the master that the user made a request.
|
"""Tell the master that the user made a request."""
|
||||||
"""
|
|
||||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
|
@ -222,8 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.send_error("ping timeout")
|
self.send_error("ping timeout")
|
||||||
|
|
||||||
def lineReceived(self, line: bytes):
|
def lineReceived(self, line: bytes):
|
||||||
"""Called when we've received a line
|
"""Called when we've received a line"""
|
||||||
"""
|
|
||||||
with PreserveLoggingContext(self._logging_context):
|
with PreserveLoggingContext(self._logging_context):
|
||||||
self._parse_and_dispatch_line(line)
|
self._parse_and_dispatch_line(line)
|
||||||
|
|
||||||
@ -299,8 +298,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
def send_error(self, error_string, *args):
|
def send_error(self, error_string, *args):
|
||||||
"""Send an error to remote and close the connection.
|
"""Send an error to remote and close the connection."""
|
||||||
"""
|
|
||||||
self.send_command(ErrorCommand(error_string % args))
|
self.send_command(ErrorCommand(error_string % args))
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
@ -341,8 +339,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.last_sent_command = self.clock.time_msec()
|
self.last_sent_command = self.clock.time_msec()
|
||||||
|
|
||||||
def _queue_command(self, cmd):
|
def _queue_command(self, cmd):
|
||||||
"""Queue the command until the connection is ready to write to again.
|
"""Queue the command until the connection is ready to write to again."""
|
||||||
"""
|
|
||||||
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
||||||
self.pending_commands.append(cmd)
|
self.pending_commands.append(cmd)
|
||||||
|
|
||||||
@ -355,8 +352,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def _send_pending_commands(self):
|
def _send_pending_commands(self):
|
||||||
"""Send any queued commandes
|
"""Send any queued commandes"""
|
||||||
"""
|
|
||||||
pending = self.pending_commands
|
pending = self.pending_commands
|
||||||
self.pending_commands = []
|
self.pending_commands = []
|
||||||
for cmd in pending:
|
for cmd in pending:
|
||||||
@ -380,8 +376,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.state = ConnectionStates.PAUSED
|
self.state = ConnectionStates.PAUSED
|
||||||
|
|
||||||
def resumeProducing(self):
|
def resumeProducing(self):
|
||||||
"""The remote has caught up after we started buffering!
|
"""The remote has caught up after we started buffering!"""
|
||||||
"""
|
|
||||||
logger.info("[%s] Resume producing", self.id())
|
logger.info("[%s] Resume producing", self.id())
|
||||||
self.state = ConnectionStates.ESTABLISHED
|
self.state = ConnectionStates.ESTABLISHED
|
||||||
self._send_pending_commands()
|
self._send_pending_commands()
|
||||||
@ -440,8 +435,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
return "%s-%s" % (self.name, self.conn_id)
|
return "%s-%s" % (self.name, self.conn_id)
|
||||||
|
|
||||||
def lineLengthExceeded(self, line):
|
def lineLengthExceeded(self, line):
|
||||||
"""Called when we receive a line that is above the maximum line length
|
"""Called when we receive a line that is above the maximum line length"""
|
||||||
"""
|
|
||||||
self.send_error("Line length exceeded")
|
self.send_error("Line length exceeded")
|
||||||
|
|
||||||
|
|
||||||
@ -495,21 +489,18 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
self.send_error("Wrong remote")
|
self.send_error("Wrong remote")
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self):
|
||||||
"""Send the subscription request to the server
|
"""Send the subscription request to the server"""
|
||||||
"""
|
|
||||||
logger.info("[%s] Subscribing to replication streams", self.id())
|
logger.info("[%s] Subscribing to replication streams", self.id())
|
||||||
|
|
||||||
self.send_command(ReplicateCommand())
|
self.send_command(ReplicateCommand())
|
||||||
|
|
||||||
|
|
||||||
class AbstractConnection(abc.ABC):
|
class AbstractConnection(abc.ABC):
|
||||||
"""An interface for replication connections.
|
"""An interface for replication connections."""
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def send_command(self, cmd: Command):
|
def send_command(self, cmd: Command):
|
||||||
"""Send the command down the connection
|
"""Send the command down the connection"""
|
||||||
"""
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user