Merge pull request #2561 from matrix-org/rav/id_checking

Updates to ID checking
This commit is contained in:
Richard van der Hoff 2017-10-23 14:39:20 +01:00 committed by GitHub
commit 3267b81b81
4 changed files with 70 additions and 34 deletions

View File

@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer import logging
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID, get_domain_from_id, RoomID, GroupID from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer
import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -698,9 +695,11 @@ class GroupsServerHandler(object):
def create_group(self, group_id, user_id, content): def create_group(self, group_id, user_id, content):
group = yield self.check_group_is_ours(group_id) group = yield self.check_group_is_ours(group_id)
_validate_group_id(group_id)
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
# parsing the id into a GroupID validates it.
group_id_obj = GroupID.from_string(group_id)
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
@ -710,7 +709,7 @@ class GroupsServerHandler(object):
raise SynapseError( raise SynapseError(
403, "Only server admin can create group on this server", 403, "Only server admin can create group on this server",
) )
localpart = GroupID.from_string(group_id).localpart localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix): if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError( raise SynapseError(
400, 400,
@ -786,18 +785,3 @@ def _parse_visibility_from_contents(content):
is_public = True is_public = True
return is_public return is_public
def _validate_group_id(group_id):
"""Validates the group ID is valid for creation on this home server
"""
localpart = GroupID.from_string(group_id).localpart
if localpart.lower() != localpart:
raise SynapseError(400, "Group ID must be lower case")
if urllib.quote(localpart.encode('utf-8')) != localpart:
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '_-./'",
)

View File

@ -15,7 +15,6 @@
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging import logging
import urllib
from twisted.internet import defer from twisted.internet import defer
@ -23,6 +22,7 @@ from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import BaseHandler from ._base import BaseHandler
@ -46,12 +46,10 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
yield run_on_reactor() if types.contains_invalid_mxid_characters(localpart):
if urllib.quote(localpart.encode('utf-8')) != localpart:
raise SynapseError( raise SynapseError(
400, 400,
"User ID can only contain characters a-z, 0-9, or '_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
Codes.INVALID_USERNAME Codes.INVALID_USERNAME
) )
@ -81,7 +79,7 @@ class RegistrationHandler(BaseHandler):
"A different user ID has already been registered for this session", "A different user ID has already been registered for this session",
) )
yield self.check_user_id_not_appservice_exclusive(user_id) self.check_user_id_not_appservice_exclusive(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users: if users:
@ -254,11 +252,10 @@ class RegistrationHandler(BaseHandler):
""" """
Registers email_id as SAML2 Based Auth. Registers email_id as SAML2 Based Auth.
""" """
if urllib.quote(localpart) != localpart: if types.contains_invalid_mxid_characters(localpart):
raise SynapseError( raise SynapseError(
400, 400,
"User ID must only contain characters which do not" "User ID can only contain characters a-z, 0-9, or '=_-./'",
" require URL encoding."
) )
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import string
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -156,6 +157,38 @@ class GroupID(DomainSpecificString):
"""Structure representing a group ID.""" """Structure representing a group ID."""
SIGIL = "+" SIGIL = "+"
@classmethod
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
raise SynapseError(
400,
"Group ID cannot be empty",
)
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
)
return group_id
mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
def contains_invalid_mxid_characters(localpart):
"""Check for characters not allowed in an mxid or groupid localpart
Args:
localpart (basestring): the localpart to be checked
Returns:
bool: True if there are any naughty characters
"""
return any(c not in mxid_localpart_allowed_characters for c in localpart)
class StreamToken( class StreamToken(
namedtuple("Token", ( namedtuple("Token", (

View File

@ -17,7 +17,7 @@ from tests import unittest
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias, GroupID
mock_homeserver = HomeServer(hostname="my.domain") mock_homeserver = HomeServer(hostname="my.domain")
@ -60,3 +60,25 @@ class RoomAliasTestCase(unittest.TestCase):
room = RoomAlias("channel", "my.domain") room = RoomAlias("channel", "my.domain")
self.assertEquals(room.to_string(), "#channel:my.domain") self.assertEquals(room.to_string(), "#channel:my.domain")
class GroupIDTestCase(unittest.TestCase):
def test_parse(self):
group_id = GroupID.from_string("+group/=_-.123:my.domain")
self.assertEqual("group/=_-.123", group_id.localpart)
self.assertEqual("my.domain", group_id.domain)
def test_validate(self):
bad_ids = [
"$badsigil:domain",
"+:empty",
] + [
"+group" + c + ":domain" for c in "A%?æ£"
]
for id_string in bad_ids:
try:
GroupID.from_string(id_string)
self.fail("Parsing '%s' should raise exception" % id_string)
except SynapseError as exc:
self.assertEqual(400, exc.code)
self.assertEqual("M_UNKNOWN", exc.errcode)