Validate group ids when parsing

May as well do it whenever we parse a Group ID. We check the sigil and basic
structure here so it makes sense to check the grammar in the same place.
This commit is contained in:
Richard van der Hoff 2017-10-20 23:51:07 +01:00
parent 29812c628b
commit 1135193dfd
3 changed files with 45 additions and 17 deletions

View File

@ -15,7 +15,6 @@
import logging import logging
from synapse import types
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
@ -696,9 +695,11 @@ class GroupsServerHandler(object):
def create_group(self, group_id, user_id, content): def create_group(self, group_id, user_id, content):
group = yield self.check_group_is_ours(group_id) group = yield self.check_group_is_ours(group_id)
_validate_group_id(group_id)
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
# parsing the id into a GroupID validates it.
group_id_obj = GroupID.from_string(group_id)
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
@ -708,7 +709,7 @@ class GroupsServerHandler(object):
raise SynapseError( raise SynapseError(
403, "Only server admin can create group on this server", 403, "Only server admin can create group on this server",
) )
localpart = GroupID.from_string(group_id).localpart localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix): if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError( raise SynapseError(
400, 400,
@ -784,15 +785,3 @@ def _parse_visibility_from_contents(content):
is_public = True is_public = True
return is_public return is_public
def _validate_group_id(group_id):
"""Validates the group ID is valid for creation on this home server
"""
localpart = GroupID.from_string(group_id).localpart
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
)

View File

@ -161,6 +161,23 @@ class GroupID(DomainSpecificString):
"""Structure representing a group ID.""" """Structure representing a group ID."""
SIGIL = "+" SIGIL = "+"
@classmethod
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
raise SynapseError(
400,
"Group ID cannot be empty",
)
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
)
return group_id
mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits) mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)

View File

@ -17,7 +17,7 @@ from tests import unittest
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, GroupID
mock_homeserver = HomeServer(hostname="my.domain") mock_homeserver = HomeServer(hostname="my.domain")
@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase):
room = RoomAlias("channel", "my.domain") room = RoomAlias("channel", "my.domain")
self.assertEquals(room.to_string(), "#channel:my.domain") self.assertEquals(room.to_string(), "#channel:my.domain")
class GroupIDTestCase(unittest.TestCase):
def test_parse(self):
group_id = GroupID.from_string("+group/=_-.123:my.domain")
self.assertEqual("group/=_-.123", group_id.localpart)
self.assertEqual("my.domain", group_id.domain)
def test_validate(self):
bad_ids = [
"$badsigil:domain",
"+:empty",
] + [
"+group" + c + ":domain" for c in "A%?æ£"
]
for id_string in bad_ids:
try:
GroupID.from_string(id_string)
self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)