mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 21:54:18 -05:00
Convert synapse.api to async/await (#8031)
This commit is contained in:
parent
c36228c403
commit
d4a7829b12
1
changelog.d/8031.misc
Normal file
1
changelog.d/8031.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -13,12 +13,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
from netaddr import IPAddress
|
from netaddr import IPAddress
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
|
||||||
import synapse.types
|
import synapse.types
|
||||||
@ -80,13 +79,14 @@ class Auth(object):
|
|||||||
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
||||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_from_context(
|
||||||
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
self, room_version: str, event, context, do_sig_check=True
|
||||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
):
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
|
auth_events_ids = self.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = await self.store.get_events(auth_events_ids)
|
||||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||||
|
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
@ -94,14 +94,13 @@ class Auth(object):
|
|||||||
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
|
room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_user_in_room(
|
||||||
def check_user_in_room(
|
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
current_state: Optional[StateMap[EventBase]] = None,
|
current_state: Optional[StateMap[EventBase]] = None,
|
||||||
allow_departed_users: bool = False,
|
allow_departed_users: bool = False,
|
||||||
):
|
) -> EventBase:
|
||||||
"""Check if the user is in the room, or was at some point.
|
"""Check if the user is in the room, or was at some point.
|
||||||
Args:
|
Args:
|
||||||
room_id: The room to check.
|
room_id: The room to check.
|
||||||
@ -119,37 +118,35 @@ class Auth(object):
|
|||||||
Raises:
|
Raises:
|
||||||
AuthError if the user is/was not in the room.
|
AuthError if the user is/was not in the room.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Optional[EventBase]]:
|
Membership event for the user if the user was in the
|
||||||
Membership event for the user if the user was in the
|
room. This will be the join event if they are currently joined to
|
||||||
room. This will be the join event if they are currently joined to
|
the room. This will be the leave event if they have left the room.
|
||||||
the room. This will be the leave event if they have left the room.
|
|
||||||
"""
|
"""
|
||||||
if current_state:
|
if current_state:
|
||||||
member = current_state.get((EventTypes.Member, user_id), None)
|
member = current_state.get((EventTypes.Member, user_id), None)
|
||||||
else:
|
else:
|
||||||
member = yield defer.ensureDeferred(
|
member = await self.state.get_current_state(
|
||||||
self.state.get_current_state(
|
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
membership = member.membership if member else None
|
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if member:
|
||||||
return member
|
membership = member.membership
|
||||||
|
|
||||||
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
if membership == Membership.JOIN:
|
||||||
# or those who were members previously and have been re-invited?
|
|
||||||
if allow_departed_users and membership == Membership.LEAVE:
|
|
||||||
forgot = yield self.store.did_forget(user_id, room_id)
|
|
||||||
if not forgot:
|
|
||||||
return member
|
return member
|
||||||
|
|
||||||
|
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
||||||
|
# or those who were members previously and have been re-invited?
|
||||||
|
if allow_departed_users and membership == Membership.LEAVE:
|
||||||
|
forgot = await self.store.did_forget(user_id, room_id)
|
||||||
|
if not forgot:
|
||||||
|
return member
|
||||||
|
|
||||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_host_in_room(self, room_id, host):
|
||||||
def check_host_in_room(self, room_id, host):
|
|
||||||
with Measure(self.clock, "check_host_in_room"):
|
with Measure(self.clock, "check_host_in_room"):
|
||||||
latest_event_ids = yield self.store.is_host_joined(room_id, host)
|
latest_event_ids = await self.store.is_host_joined(room_id, host)
|
||||||
return latest_event_ids
|
return latest_event_ids
|
||||||
|
|
||||||
def can_federate(self, event, auth_events):
|
def can_federate(self, event, auth_events):
|
||||||
@ -160,14 +157,13 @@ class Auth(object):
|
|||||||
def get_public_keys(self, invite_event):
|
def get_public_keys(self, invite_event):
|
||||||
return event_auth.get_public_keys(invite_event)
|
return event_auth.get_public_keys(invite_event)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_user_by_req(
|
||||||
def get_user_by_req(
|
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
allow_guest: bool = False,
|
allow_guest: bool = False,
|
||||||
rights: str = "access",
|
rights: str = "access",
|
||||||
allow_expired: bool = False,
|
allow_expired: bool = False,
|
||||||
):
|
) -> synapse.types.Requester:
|
||||||
""" Get a registered user's ID.
|
""" Get a registered user's ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -180,7 +176,7 @@ class Auth(object):
|
|||||||
/login will deliver access tokens regardless of expiration.
|
/login will deliver access tokens regardless of expiration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to a `synapse.types.Requester` object
|
Resolves to the requester
|
||||||
Raises:
|
Raises:
|
||||||
InvalidClientCredentialsError if no user by that token exists or the token
|
InvalidClientCredentialsError if no user by that token exists or the token
|
||||||
is invalid.
|
is invalid.
|
||||||
@ -194,14 +190,14 @@ class Auth(object):
|
|||||||
|
|
||||||
access_token = self.get_access_token_from_request(request)
|
access_token = self.get_access_token_from_request(request)
|
||||||
|
|
||||||
user_id, app_service = yield self._get_appservice_user_id(request)
|
user_id, app_service = await self._get_appservice_user_id(request)
|
||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
opentracing.set_tag("authenticated_entity", user_id)
|
opentracing.set_tag("authenticated_entity", user_id)
|
||||||
opentracing.set_tag("appservice_id", app_service.id)
|
opentracing.set_tag("appservice_id", app_service.id)
|
||||||
|
|
||||||
if ip_addr and self._track_appservice_user_ips:
|
if ip_addr and self._track_appservice_user_ips:
|
||||||
yield self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
@ -211,7 +207,7 @@ class Auth(object):
|
|||||||
|
|
||||||
return synapse.types.create_requester(user_id, app_service=app_service)
|
return synapse.types.create_requester(user_id, app_service=app_service)
|
||||||
|
|
||||||
user_info = yield self.get_user_by_access_token(
|
user_info = await self.get_user_by_access_token(
|
||||||
access_token, rights, allow_expired=allow_expired
|
access_token, rights, allow_expired=allow_expired
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
@ -221,7 +217,7 @@ class Auth(object):
|
|||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
if self._account_validity.enabled and not allow_expired:
|
if self._account_validity.enabled and not allow_expired:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||||
if (
|
if (
|
||||||
expiration_ts is not None
|
expiration_ts is not None
|
||||||
and self.clock.time_msec() >= expiration_ts
|
and self.clock.time_msec() >= expiration_ts
|
||||||
@ -235,7 +231,7 @@ class Auth(object):
|
|||||||
device_id = user_info.get("device_id")
|
device_id = user_info.get("device_id")
|
||||||
|
|
||||||
if user and access_token and ip_addr:
|
if user and access_token and ip_addr:
|
||||||
yield self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
@ -261,8 +257,7 @@ class Auth(object):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise MissingClientTokenError()
|
raise MissingClientTokenError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_appservice_user_id(self, request):
|
||||||
def _get_appservice_user_id(self, request):
|
|
||||||
app_service = self.store.get_app_service_by_token(
|
app_service = self.store.get_app_service_by_token(
|
||||||
self.get_access_token_from_request(request)
|
self.get_access_token_from_request(request)
|
||||||
)
|
)
|
||||||
@ -283,14 +278,13 @@ class Auth(object):
|
|||||||
|
|
||||||
if not app_service.is_interested_in_user(user_id):
|
if not app_service.is_interested_in_user(user_id):
|
||||||
raise AuthError(403, "Application service cannot masquerade as this user.")
|
raise AuthError(403, "Application service cannot masquerade as this user.")
|
||||||
if not (yield self.store.get_user_by_id(user_id)):
|
if not (await self.store.get_user_by_id(user_id)):
|
||||||
raise AuthError(403, "Application service has not registered this user")
|
raise AuthError(403, "Application service has not registered this user")
|
||||||
return user_id, app_service
|
return user_id, app_service
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_user_by_access_token(
|
||||||
def get_user_by_access_token(
|
|
||||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
self, token: str, rights: str = "access", allow_expired: bool = False,
|
||||||
):
|
) -> dict:
|
||||||
""" Validate access token and get user_id from it
|
""" Validate access token and get user_id from it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -300,7 +294,7 @@ class Auth(object):
|
|||||||
allow_expired: If False, raises an InvalidClientTokenError
|
allow_expired: If False, raises an InvalidClientTokenError
|
||||||
if the token is expired
|
if the token is expired
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: dict that includes:
|
dict that includes:
|
||||||
`user` (UserID)
|
`user` (UserID)
|
||||||
`is_guest` (bool)
|
`is_guest` (bool)
|
||||||
`token_id` (int|None): access token id. May be None if guest
|
`token_id` (int|None): access token id. May be None if guest
|
||||||
@ -314,7 +308,7 @@ class Auth(object):
|
|||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
# first look in the database
|
# first look in the database
|
||||||
r = yield self._look_up_user_by_access_token(token)
|
r = await self._look_up_user_by_access_token(token)
|
||||||
if r:
|
if r:
|
||||||
valid_until_ms = r["valid_until_ms"]
|
valid_until_ms = r["valid_until_ms"]
|
||||||
if (
|
if (
|
||||||
@ -352,7 +346,7 @@ class Auth(object):
|
|||||||
# It would of course be much easier to store guest access
|
# It would of course be much easier to store guest access
|
||||||
# tokens in the database as well, but that would break existing
|
# tokens in the database as well, but that would break existing
|
||||||
# guest tokens.
|
# guest tokens.
|
||||||
stored_user = yield self.store.get_user_by_id(user_id)
|
stored_user = await self.store.get_user_by_id(user_id)
|
||||||
if not stored_user:
|
if not stored_user:
|
||||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||||
if not stored_user["is_guest"]:
|
if not stored_user["is_guest"]:
|
||||||
@ -482,9 +476,8 @@ class Auth(object):
|
|||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _look_up_user_by_access_token(self, token):
|
||||||
def _look_up_user_by_access_token(self, token):
|
ret = await self.store.get_user_by_access_token(token)
|
||||||
ret = yield self.store.get_user_by_access_token(token)
|
|
||||||
if not ret:
|
if not ret:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -507,7 +500,7 @@ class Auth(object):
|
|||||||
logger.warning("Unrecognised appservice access token.")
|
logger.warning("Unrecognised appservice access token.")
|
||||||
raise InvalidClientTokenError()
|
raise InvalidClientTokenError()
|
||||||
request.authenticated_entity = service.sender
|
request.authenticated_entity = service.sender
|
||||||
return defer.succeed(service)
|
return service
|
||||||
|
|
||||||
async def is_server_admin(self, user: UserID) -> bool:
|
async def is_server_admin(self, user: UserID) -> bool:
|
||||||
""" Check if the given user is a local server admin.
|
""" Check if the given user is a local server admin.
|
||||||
@ -522,7 +515,7 @@ class Auth(object):
|
|||||||
|
|
||||||
def compute_auth_events(
|
def compute_auth_events(
|
||||||
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
|
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
|
||||||
):
|
) -> List[str]:
|
||||||
"""Given an event and current state return the list of event IDs used
|
"""Given an event and current state return the list of event IDs used
|
||||||
to auth an event.
|
to auth an event.
|
||||||
|
|
||||||
@ -530,11 +523,11 @@ class Auth(object):
|
|||||||
should be added to the event's `auth_events`.
|
should be added to the event's `auth_events`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred(list[str]): List of event IDs.
|
List of event IDs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
return defer.succeed([])
|
return []
|
||||||
|
|
||||||
# Currently we ignore the `for_verification` flag even though there are
|
# Currently we ignore the `for_verification` flag even though there are
|
||||||
# some situations where we can drop particular auth events when adding
|
# some situations where we can drop particular auth events when adding
|
||||||
@ -553,7 +546,7 @@ class Auth(object):
|
|||||||
if auth_ev_id:
|
if auth_ev_id:
|
||||||
auth_ids.append(auth_ev_id)
|
auth_ids.append(auth_ev_id)
|
||||||
|
|
||||||
return defer.succeed(auth_ids)
|
return auth_ids
|
||||||
|
|
||||||
async def check_can_change_room_list(self, room_id: str, user: UserID):
|
async def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||||
"""Determine whether the user is allowed to edit the room's entry in the
|
"""Determine whether the user is allowed to edit the room's entry in the
|
||||||
@ -636,10 +629,9 @@ class Auth(object):
|
|||||||
|
|
||||||
return query_params[0].decode("ascii")
|
return query_params[0].decode("ascii")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_user_in_room_or_world_readable(
|
||||||
def check_user_in_room_or_world_readable(
|
|
||||||
self, room_id: str, user_id: str, allow_departed_users: bool = False
|
self, room_id: str, user_id: str, allow_departed_users: bool = False
|
||||||
):
|
) -> Tuple[str, Optional[str]]:
|
||||||
"""Checks that the user is or was in the room or the room is world
|
"""Checks that the user is or was in the room or the room is world
|
||||||
readable. If it isn't then an exception is raised.
|
readable. If it isn't then an exception is raised.
|
||||||
|
|
||||||
@ -650,10 +642,9 @@ class Auth(object):
|
|||||||
members but have now departed
|
members but have now departed
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[str, str|None]]: Resolves to the current membership of
|
Resolves to the current membership of the user in the room and the
|
||||||
the user in the room and the membership event ID of the user. If
|
membership event ID of the user. If the user is not in the room and
|
||||||
the user is not in the room and never has been, then
|
never has been, then `(Membership.JOIN, None)` is returned.
|
||||||
`(Membership.JOIN, None)` is returned.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -662,15 +653,13 @@ class Auth(object):
|
|||||||
# * The user is a non-guest user, and was ever in the room
|
# * The user is a non-guest user, and was ever in the room
|
||||||
# * The user is a guest user, and has joined the room
|
# * The user is a guest user, and has joined the room
|
||||||
# else it will throw.
|
# else it will throw.
|
||||||
member_event = yield self.check_user_in_room(
|
member_event = await self.check_user_in_room(
|
||||||
room_id, user_id, allow_departed_users=allow_departed_users
|
room_id, user_id, allow_departed_users=allow_departed_users
|
||||||
)
|
)
|
||||||
return member_event.membership, member_event.event_id
|
return member_event.membership, member_event.event_id
|
||||||
except AuthError:
|
except AuthError:
|
||||||
visibility = yield defer.ensureDeferred(
|
visibility = await self.state.get_current_state(
|
||||||
self.state.get_current_state(
|
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
visibility
|
visibility
|
||||||
|
@ -15,8 +15,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
||||||
from synapse.api.errors import Codes, ResourceLimitError
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.config.server import is_threepid_reserved
|
from synapse.config.server import is_threepid_reserved
|
||||||
@ -36,8 +34,7 @@ class AuthBlocking(object):
|
|||||||
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
|
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
|
||||||
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
|
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
|
||||||
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
|
|
||||||
"""Checks if the user should be rejected for some external reason,
|
"""Checks if the user should be rejected for some external reason,
|
||||||
such as monthly active user limiting or global disable flag
|
such as monthly active user limiting or global disable flag
|
||||||
|
|
||||||
@ -60,7 +57,7 @@ class AuthBlocking(object):
|
|||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
if user_id == self._server_notices_mxid:
|
if user_id == self._server_notices_mxid:
|
||||||
return
|
return
|
||||||
if (yield self.store.is_support_user(user_id)):
|
if await self.store.is_support_user(user_id):
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._hs_disabled:
|
if self._hs_disabled:
|
||||||
@ -76,11 +73,11 @@ class AuthBlocking(object):
|
|||||||
|
|
||||||
# If the user is already part of the MAU cohort or a trial user
|
# If the user is already part of the MAU cohort or a trial user
|
||||||
if user_id:
|
if user_id:
|
||||||
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
|
timestamp = await self.store.user_last_seen_monthly_active(user_id)
|
||||||
if timestamp:
|
if timestamp:
|
||||||
return
|
return
|
||||||
|
|
||||||
is_trial = yield self.store.is_trial_user(user_id)
|
is_trial = await self.store.is_trial_user(user_id)
|
||||||
if is_trial:
|
if is_trial:
|
||||||
return
|
return
|
||||||
elif threepid:
|
elif threepid:
|
||||||
@ -93,7 +90,7 @@ class AuthBlocking(object):
|
|||||||
# allow registration. Support users are excluded from MAU checks.
|
# allow registration. Support users are excluded from MAU checks.
|
||||||
return
|
return
|
||||||
# Else if there is no room in the MAU bucket, bail
|
# Else if there is no room in the MAU bucket, bail
|
||||||
current_mau = yield self.store.get_monthly_active_count()
|
current_mau = await self.store.get_monthly_active_count()
|
||||||
if current_mau >= self._max_mau_value:
|
if current_mau >= self._max_mau_value:
|
||||||
raise ResourceLimitError(
|
raise ResourceLimitError(
|
||||||
403,
|
403,
|
||||||
|
@ -21,8 +21,6 @@ import jsonschema
|
|||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
from jsonschema import FormatChecker
|
from jsonschema import FormatChecker
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields
|
from synapse.api.constants import EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage.presence import UserPresenceState
|
from synapse.storage.presence import UserPresenceState
|
||||||
@ -137,9 +135,8 @@ class Filtering(object):
|
|||||||
super(Filtering, self).__init__()
|
super(Filtering, self).__init__()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_user_filter(self, user_localpart, filter_id):
|
||||||
def get_user_filter(self, user_localpart, filter_id):
|
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||||
result = yield self.store.get_user_filter(user_localpart, filter_id)
|
|
||||||
return FilterCollection(result)
|
return FilterCollection(result)
|
||||||
|
|
||||||
def add_user_filter(self, user_localpart, user_filter):
|
def add_user_filter(self, user_localpart, user_filter):
|
||||||
|
@ -106,7 +106,7 @@ class EventBuilder(object):
|
|||||||
state_ids = await self._state.get_current_state_ids(
|
state_ids = await self._state.get_current_state_ids(
|
||||||
self.room_id, prev_event_ids
|
self.room_id, prev_event_ids
|
||||||
)
|
)
|
||||||
auth_ids = await self._auth.compute_auth_events(self, state_ids)
|
auth_ids = self._auth.compute_auth_events(self, state_ids)
|
||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
if format_version == EventFormatVersions.V1:
|
if format_version == EventFormatVersions.V1:
|
||||||
|
@ -2064,7 +2064,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
if not auth_events:
|
if not auth_events:
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
auth_events_ids = await self.auth.compute_auth_events(
|
auth_events_ids = self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events_x = await self.store.get_events(auth_events_ids)
|
auth_events_x = await self.store.get_events(auth_events_ids)
|
||||||
|
@ -1061,7 +1061,7 @@ class EventCreationHandler(object):
|
|||||||
raise SynapseError(400, "Cannot redact event from a different room")
|
raise SynapseError(400, "Cannot redact event from a different room")
|
||||||
|
|
||||||
prev_state_ids = await context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
auth_events_ids = await self.auth.compute_auth_events(
|
auth_events_ids = self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events = await self.store.get_events(auth_events_ids)
|
auth_events = await self.store.get_events(auth_events_ids)
|
||||||
|
@ -194,12 +194,16 @@ class ModuleApi(object):
|
|||||||
synapse.api.errors.AuthError: the access token is invalid
|
synapse.api.errors.AuthError: the access token is invalid
|
||||||
"""
|
"""
|
||||||
# see if the access token corresponds to a device
|
# see if the access token corresponds to a device
|
||||||
user_info = yield self._auth.get_user_by_access_token(access_token)
|
user_info = yield defer.ensureDeferred(
|
||||||
|
self._auth.get_user_by_access_token(access_token)
|
||||||
|
)
|
||||||
device_id = user_info.get("device_id")
|
device_id = user_info.get("device_id")
|
||||||
user_id = user_info["user"].to_string()
|
user_id = user_info["user"].to_string()
|
||||||
if device_id:
|
if device_id:
|
||||||
# delete the device, which will also delete its access tokens
|
# delete the device, which will also delete its access tokens
|
||||||
yield self._hs.get_device_handler().delete_device(user_id, device_id)
|
yield defer.ensureDeferred(
|
||||||
|
self._hs.get_device_handler().delete_device(user_id, device_id)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# no associated device. Just delete the access token.
|
# no associated device. Just delete the access token.
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -120,7 +120,7 @@ class BulkPushRuleEvaluator(object):
|
|||||||
pl_event = await self.store.get_event(pl_event_id)
|
pl_event = await self.store.get_event(pl_event_id)
|
||||||
auth_events = {POWER_KEY: pl_event}
|
auth_events = {POWER_KEY: pl_event}
|
||||||
else:
|
else:
|
||||||
auth_events_ids = await self.auth.compute_auth_events(
|
auth_events_ids = self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=False
|
event, prev_state_ids, for_verification=False
|
||||||
)
|
)
|
||||||
auth_events = await self.store.get_events(auth_events_ids)
|
auth_events = await self.store.get_events(auth_events_ids)
|
||||||
|
@ -28,7 +28,7 @@ class SlavedClientIpStore(BaseSlavedStore):
|
|||||||
name="client_ip_last_seen", keylen=4, max_entries=50000
|
name="client_ip_last_seen", keylen=4, max_entries=50000
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
key = (user_id, access_token, ip)
|
key = (user_id, access_token, ip)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet):
|
|||||||
dir_handler = self.handlers.directory_handler
|
dir_handler = self.handlers.directory_handler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
service = await self.auth.get_appservice_by_req(request)
|
service = self.auth.get_appservice_by_req(request)
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
await dir_handler.delete_appservice_association(service, room_alias)
|
await dir_handler.delete_appservice_association(service, room_alias)
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -424,7 +424,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
|
|
||||||
appservice = None
|
appservice = None
|
||||||
if self.auth.has_access_token(request):
|
if self.auth.has_access_token(request):
|
||||||
appservice = await self.auth.get_appservice_by_req(request)
|
appservice = self.auth.get_appservice_by_req(request)
|
||||||
|
|
||||||
# fork off as soon as possible for ASes which have completely
|
# fork off as soon as possible for ASes which have completely
|
||||||
# different registration flows to normal users
|
# different registration flows to normal users
|
||||||
|
@ -380,8 +380,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
if self.user_ips_max_age:
|
if self.user_ips_max_age:
|
||||||
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def insert_client_ip(
|
||||||
def insert_client_ip(
|
|
||||||
self, user_id, access_token, ip, user_agent, device_id, now=None
|
self, user_id, access_token, ip, user_agent, device_id, now=None
|
||||||
):
|
):
|
||||||
if not now:
|
if not now:
|
||||||
@ -392,7 +391,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||||||
last_seen = self.client_ip_last_seen.get(key)
|
last_seen = self.client_ip_last_seen.get(key)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
last_seen = None
|
last_seen = None
|
||||||
yield self.populate_monthly_active_users(user_id)
|
await self.populate_monthly_active_users(user_id)
|
||||||
# Rate-limited inserts
|
# Rate-limited inserts
|
||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
return
|
return
|
||||||
|
@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
# this is overridden for the appservice tests
|
# this is overridden for the appservice tests
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
|
||||||
|
self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
|
||||||
self.store.is_support_user = Mock(return_value=defer.succeed(False))
|
self.store.is_support_user = Mock(return_value=defer.succeed(False))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(
|
||||||
|
return_value=defer.succeed(user_info)
|
||||||
|
)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(requester.user.to_string(), self.test_user)
|
self.assertEquals(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self):
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
f = self.failureResultOf(d, InvalidClientTokenError).value
|
f = self.failureResultOf(d, InvalidClientTokenError).value
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
user_info = {"name": self.test_user, "token_id": "ditto"}
|
||||||
self.store.get_user_by_access_token = Mock(return_value=user_info)
|
self.store.get_user_by_access_token = Mock(
|
||||||
|
return_value=defer.succeed(user_info)
|
||||||
|
)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
f = self.failureResultOf(d, MissingClientTokenError).value
|
f = self.failureResultOf(d, MissingClientTokenError).value
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
||||||
@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "192.168.10.10"
|
request.getClientIP.return_value = "192.168.10.10"
|
||||||
@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "131.111.8.42"
|
request.getClientIP.return_value = "131.111.8.42"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
f = self.failureResultOf(d, InvalidClientTokenError).value
|
f = self.failureResultOf(d, InvalidClientTokenError).value
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self):
|
def test_get_user_by_req_appservice_bad_token(self):
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
f = self.failureResultOf(d, InvalidClientTokenError).value
|
f = self.failureResultOf(d, InvalidClientTokenError).value
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def test_get_user_by_req_appservice_missing_token(self):
|
def test_get_user_by_req_appservice_missing_token(self):
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
f = self.failureResultOf(d, MissingClientTokenError).value
|
f = self.failureResultOf(d, MissingClientTokenError).value
|
||||||
self.assertEqual(f.code, 401)
|
self.assertEqual(f.code, 401)
|
||||||
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
|
||||||
@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
# This just needs to return a truth-y value.
|
||||||
|
self.store.get_user_by_id = Mock(
|
||||||
|
return_value=defer.succeed({"is_guest": False})
|
||||||
|
)
|
||||||
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=False)
|
app_service.is_interested_in_user = Mock(return_value=False)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientIP.return_value = "127.0.0.1"
|
request.getClientIP.return_value = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
request.args[b"user_id"] = [masquerading_user_id]
|
request.args[b"user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
d = self.auth.get_user_by_req(request)
|
d = defer.ensureDeferred(self.auth.get_user_by_req(request))
|
||||||
self.failureResultOf(d, AuthError)
|
self.failureResultOf(d, AuthError)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_from_macaroon(self):
|
def test_get_user_from_macaroon(self):
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
|
return_value=defer.succeed(
|
||||||
|
{"name": "@baldrick:matrix.org", "device_id": "device"}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self):
|
||||||
self.store.get_user_by_id = Mock(return_value={"is_guest": True})
|
self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
|
||||||
self.store.get_user_by_access_token = Mock(return_value=None)
|
self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def get_user(tok):
|
def get_user(tok):
|
||||||
if token != tok:
|
if token != tok:
|
||||||
return None
|
return defer.succeed(None)
|
||||||
return {
|
return defer.succeed(
|
||||||
"name": USER_ID,
|
{
|
||||||
"is_guest": False,
|
"name": USER_ID,
|
||||||
"token_id": 1234,
|
"is_guest": False,
|
||||||
"device_id": "DEVICE",
|
"token_id": 1234,
|
||||||
}
|
"device_id": "DEVICE",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.store.get_user_by_access_token = get_user
|
self.store.get_user_by_access_token = get_user
|
||||||
self.store.get_user_by_id = Mock(return_value={"is_guest": False})
|
self.store.get_user_by_id = Mock(
|
||||||
|
return_value=defer.succeed({"is_guest": False})
|
||||||
|
)
|
||||||
|
|
||||||
# check the token works
|
# check the token works
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
|
@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
event = MockEvent(sender="@foo:bar", type="m.profile")
|
event = MockEvent(sender="@foo:bar", type="m.profile")
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
self.filtering.get_user_filter(
|
||||||
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = user_filter.filter_presence(events=events)
|
||||||
@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
self.filtering.get_user_filter(
|
||||||
|
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_presence(events=events)
|
results = user_filter.filter_presence(events=events)
|
||||||
@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
self.filtering.get_user_filter(
|
||||||
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events=events)
|
results = user_filter.filter_room_state(events=events)
|
||||||
@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
events = [event]
|
events = [event]
|
||||||
|
|
||||||
user_filter = yield self.filtering.get_user_filter(
|
user_filter = yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
self.filtering.get_user_filter(
|
||||||
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
results = user_filter.filter_room_state(events)
|
results = user_filter.filter_room_state(events)
|
||||||
@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
user_filter_json,
|
user_filter_json,
|
||||||
(
|
(
|
||||||
yield self.datastore.get_user_filter(
|
yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart, filter_id=0
|
self.datastore.get_user_filter(
|
||||||
|
user_localpart=user_localpart, filter_id=0
|
||||||
|
)
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
|
|||||||
user_localpart=user_localpart, user_filter=user_filter_json
|
user_localpart=user_localpart, user_filter=user_filter_json
|
||||||
)
|
)
|
||||||
|
|
||||||
filter = yield self.filtering.get_user_filter(
|
filter = yield defer.ensureDeferred(
|
||||||
user_localpart=user_localpart, filter_id=filter_id
|
self.filtering.get_user_filter(
|
||||||
|
user_localpart=user_localpart, filter_id=filter_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(filter.get_filter_json(), user_filter_json)
|
self.assertEquals(filter.get_filter_json(), user_filter_json)
|
||||||
|
@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.room_members = []
|
self.room_members = []
|
||||||
|
|
||||||
def check_user_in_room(room_id, user_id):
|
async def check_user_in_room(room_id, user_id):
|
||||||
if user_id not in [u.to_string() for u in self.room_members]:
|
if user_id not in [u.to_string() for u in self.room_members]:
|
||||||
raise AuthError(401, "User is not in the room")
|
raise AuthError(401, "User is not in the room")
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
hs.get_auth().check_user_in_room = check_user_in_room
|
hs.get_auth().check_user_in_room = check_user_in_room
|
||||||
|
|
||||||
|
@ -20,6 +20,8 @@ import urllib.parse
|
|||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import HttpResponseException, ResourceLimitError
|
from synapse.api.errors import HttpResponseException, ResourceLimitError
|
||||||
@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
|
|||||||
store = self.hs.get_datastore()
|
store = self.hs.get_datastore()
|
||||||
|
|
||||||
# Set monthly active users to the limit
|
# Set monthly active users to the limit
|
||||||
store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
|
store.get_monthly_active_count = Mock(
|
||||||
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
|
)
|
||||||
# Check that the blocking of monthly active users is working as expected
|
# Check that the blocking of monthly active users is working as expected
|
||||||
# The registration of a new user fails due to the limit
|
# The registration of a new user fails due to the limit
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Set monthly active users to the limit
|
# Set monthly active users to the limit
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=self.hs.config.max_mau_value
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
# Check that the blocking of monthly active users is working as expected
|
# Check that the blocking of monthly active users is working as expected
|
||||||
# The registration of a new user fails due to the limit
|
# The registration of a new user fails due to the limit
|
||||||
@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Set monthly active users to the limit
|
# Set monthly active users to the limit
|
||||||
self.store.get_monthly_active_count = Mock(
|
self.store.get_monthly_active_count = Mock(
|
||||||
return_value=self.hs.config.max_mau_value
|
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||||
)
|
)
|
||||||
# Check that the blocking of monthly active users is working as expected
|
# Check that the blocking of monthly active users is working as expected
|
||||||
# The registration of a new user fails due to the limit
|
# The registration of a new user fails due to the limit
|
||||||
|
@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
|
|||||||
profile_handler=self.mock_handler,
|
profile_handler=self.mock_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
async def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return defer.succeed(synapse.types.create_requester(myid))
|
return synapse.types.create_requester(myid)
|
||||||
|
|
||||||
hs.get_auth().get_user_by_req = _get_user_by_req
|
hs.get_auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
@ -23,8 +23,6 @@ from urllib import parse as urlparse
|
|||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||||
from synapse.handlers.pagination import PurgeStatus
|
from synapse.handlers.pagination import PurgeStatus
|
||||||
@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.hs.get_federation_handler = Mock(return_value=Mock())
|
self.hs.get_federation_handler = Mock(return_value=Mock())
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
async def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
self.hs.get_datastore().insert_client_ip = _insert_client_ip
|
self.hs.get_datastore().insert_client_ip = _insert_client_ip
|
||||||
|
|
||||||
|
@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
||||||
|
|
||||||
def _insert_client_ip(*args, **kwargs):
|
async def _insert_client_ip(*args, **kwargs):
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
hs.get_datastore().insert_client_ip = _insert_client_ip
|
hs.get_datastore().insert_client_ip = _insert_client_ip
|
||||||
|
|
||||||
|
@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
|
|||||||
self.user_id = "@user_id:test"
|
self.user_id = "@user_id:test"
|
||||||
|
|
||||||
def test_server_notice_only_sent_once(self):
|
def test_server_notice_only_sent_once(self):
|
||||||
self.store.get_monthly_active_count = Mock(return_value=1000)
|
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
|
||||||
|
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(1000)
|
return_value=defer.succeed(1000)
|
||||||
|
@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
|
|||||||
if hasattr(self, "user_id"):
|
if hasattr(self, "user_id"):
|
||||||
if self.hijack_auth:
|
if self.hijack_auth:
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return succeed(
|
return {
|
||||||
{
|
"user": UserID.from_string(self.helper.auth_user_id),
|
||||||
"user": UserID.from_string(self.helper.auth_user_id),
|
"token_id": 1,
|
||||||
"token_id": 1,
|
"is_guest": False,
|
||||||
"is_guest": False,
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
async def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||||
return succeed(
|
return create_requester(
|
||||||
create_requester(
|
UserID.from_string(self.helper.auth_user_id), 1, False, None
|
||||||
UserID.from_string(self.helper.auth_user_id), 1, False, None
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_auth().get_user_by_req = get_user_by_req
|
self.hs.get_auth().get_user_by_req = get_user_by_req
|
||||||
|
Loading…
Reference in New Issue
Block a user