Check group_id belongs to this domain

This commit is contained in:
Luke Barnard 2017-11-16 17:54:27 +00:00
parent 97bd18af4e
commit b1edf26051
2 changed files with 12 additions and 3 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import GroupID, get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
@ -83,12 +84,13 @@ class ApplicationService(object):
GROUP_ID_REGEX = re.compile('\+.*:.+') GROUP_ID_REGEX = re.compile('\+.*:.+')
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True): sender=None, id=None, protocols=None, rate_limited=True):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
@ -132,12 +134,18 @@ class ApplicationService(object):
raise ValueError( raise ValueError(
"Expected string for 'group_id' in ns '%s'" % ns "Expected string for 'group_id' in ns '%s'" % ns
) )
if not ApplicationService.GROUP_ID_REGEX.match( try:
regex_obj.get("group_id")): GroupID.from_string(regex_obj.get("group_id"))
except Exception:
raise ValueError( raise ValueError(
"Expected valid group ID for 'group_id' in ns '%s'" % ns "Expected valid group ID for 'group_id' in ns '%s'" % ns
) )
if get_domain_from_id(regex_obj.get("group_id")) != self.server_name:
raise ValueError(
"Expected string for 'group_id' to be for this host in ns '%s'" % ns
)
regex = regex_obj.get("regex") regex = regex_obj.get("regex")
if isinstance(regex, basestring): if isinstance(regex, basestring):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex regex_obj["regex"] = re.compile(regex) # Pre-compile regex

View File

@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename):
) )
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
hostname=hostname,
url=as_info["url"], url=as_info["url"],
namespaces=as_info["namespaces"], namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],