Add and use get_domian_from_id

This commit is contained in:
Erik Johnston 2016-05-09 10:36:03 +01:00
parent 96d9d5d388
commit 08dfa8eee2
6 changed files with 23 additions and 28 deletions

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, UserID, get_domian_from_id
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -91,8 +91,8 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,) "Room %r does not exist" % (event.room_id,)
) )
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domian_from_id(event.room_id)
originating_domain = UserID.from_string(event.sender).domain originating_domain = get_domian_from_id(event.sender)
if creating_domain != originating_domain: if creating_domain != originating_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -219,7 +219,7 @@ class Auth(object):
for event in curr_state.values(): for event in curr_state.values():
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
if UserID.from_string(event.state_key).domain != host: if get_domian_from_id(event.state_key) != host:
continue continue
except: except:
logger.warn("state_key not user_id: %s", event.state_key) logger.warn("state_key not user_id: %s", event.state_key)
@ -266,8 +266,8 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domian_from_id(event.room_id)
target_domain = UserID.from_string(target_user_id).domain target_domain = get_domian_from_id(target_user_id)
if creating_domain != target_domain: if creating_domain != target_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -889,8 +889,8 @@ class Auth(object):
if user_level >= redact_level: if user_level >= redact_level:
return False return False
redacter_domain = EventID.from_string(event.event_id).domain redacter_domain = get_domian_from_id(event.event_id)
redactee_domain = EventID.from_string(event.redacts).domain redactee_domain = get_domian_from_id(event.redacts)
if redacter_domain == redactee_domain: if redacter_domain == redactee_domain:
return True return True

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias, Requester from synapse.types import UserID, RoomAlias, Requester, get_domian_from_id
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
@ -296,7 +296,7 @@ class BaseHandler(object):
return True return True
for (state_key, membership) in room_members: for (state_key, membership) in room_members:
if ( if (
UserID.from_string(state_key).domain == self.hs.hostname self.hs.is_mine_id(state_key)
and membership == Membership.JOIN and membership == Membership.JOIN
): ):
return True return True
@ -421,9 +421,7 @@ class BaseHandler(object):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except SynapseError: except SynapseError:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id

View File

@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures, compute_event_signature, add_hashes_and_signatures,
) )
from synapse.types import UserID from synapse.types import UserID, get_domian_from_id
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
@ -453,7 +453,7 @@ class FederationHandler(BaseHandler):
joined_domains = {} joined_domains = {}
for u, d in joined_users: for u, d in joined_users:
try: try:
dom = UserID.from_string(u).domain dom = get_domian_from_id(u)
old_d = joined_domains.get(dom) old_d = joined_domains.get(dom)
if old_d: if old_d:
joined_domains[dom] = min(d, old_d) joined_domains[dom] = min(d, old_d)
@ -743,9 +743,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
@ -970,9 +968,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE: if s.content["membership"] == Membership.LEAVE:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id

View File

@ -33,7 +33,7 @@ from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID from synapse.types import UserID, get_domian_from_id
import synapse.metrics import synapse.metrics
from ._base import BaseHandler from ._base import BaseHandler
@ -440,7 +440,7 @@ class PresenceHandler(BaseHandler):
if not local_states: if not local_states:
continue continue
host = UserID.from_string(user_id).domain host = get_domian_from_id(user_id)
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
# TODO: de-dup hosts_to_states, as a single host might have multiple # TODO: de-dup hosts_to_states, as a single host might have multiple

View File

@ -21,7 +21,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import UserID from synapse.types import get_domian_from_id
import logging import logging
@ -273,10 +273,7 @@ class RoomMemberStore(SQLBaseStore):
room_id, membership=Membership.JOIN room_id, membership=Membership.JOIN
) )
joined_domains = set( joined_domains = set(get_domian_from_id(r["user_id"]) for r in rows)
UserID.from_string(r["user_id"]).domain
for r in rows
)
return joined_domains return joined_domains

View File

@ -21,6 +21,10 @@ from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
def get_domian_from_id(string):
return string.split(":", 1)[1]
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):