More server_name validation

We need to do a bit more validation when we get a server name, but don't want
to be re-doing it all over the shop, so factor out a separate
parse_and_validate_server_name, and do the extra validation.

Also, use it to verify the server name in the config file.
This commit is contained in:
Richard van der Hoff 2018-07-04 18:15:03 +01:00
parent 13f7adf84b
commit 546bc9e28b
5 changed files with 68 additions and 13 deletions

1
changelog.d/3483.feature Normal file
View File

@ -0,0 +1 @@
Reject invalid server names in homeserver.yaml

View File

@ -16,6 +16,7 @@
import logging import logging
from synapse.http.endpoint import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError
logger = logging.Logger(__name__) logger = logging.Logger(__name__)
@ -25,6 +26,12 @@ class ServerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
try:
parse_and_validate_server_name(self.server_name)
except ValueError as e:
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None) self.web_client_location = config.get("web_client_location", None)
@ -162,8 +169,8 @@ class ServerConfig(Config):
}) })
def default_config(self, server_name, **kwargs): def default_config(self, server_name, **kwargs):
if ":" in server_name: _, bind_port = parse_and_validate_server_name(server_name)
bind_port = int(server_name.split(":")[1]) if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400
else: else:
bind_port = 8448 bind_port = 8448

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@ -170,8 +170,9 @@ def _parse_auth_header(header_bytes):
return value return value
origin = strip_quotes(param_dict["origin"]) origin = strip_quotes(param_dict["origin"])
# ensure that the origin is a valid server name # ensure that the origin is a valid server name
parse_server_name(origin) parse_and_validate_server_name(origin)
key = strip_quotes(param_dict["key"]) key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"]) sig = strip_quotes(param_dict["sig"])

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
@ -41,8 +43,6 @@ _Server = collections.namedtuple(
def parse_server_name(server_name): def parse_server_name(server_name):
"""Split a server name into host/port parts. """Split a server name into host/port parts.
Does some basic sanity checking of the
Args: Args:
server_name (str): server name to parse server_name (str): server name to parse
@ -55,9 +55,6 @@ def parse_server_name(server_name):
try: try:
if server_name[-1] == ']': if server_name[-1] == ']':
# ipv6 literal, hopefully # ipv6 literal, hopefully
if server_name[0] != '[':
raise Exception()
return server_name, None return server_name, None
domain_port = server_name.rsplit(":", 1) domain_port = server_name.rsplit(":", 1)
@ -68,6 +65,46 @@ def parse_server_name(server_name):
raise ValueError("Invalid server name '%s'" % server_name) raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile(
"\\A[0-9a-zA-Z.-]+\\Z",
)
def parse_and_validate_server_name(server_name):
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == '[':
if host[-1] != ']':
raise ValueError("Mismatched [...] in server name '%s'" % (
server_name,
))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError("Server name '%s' contains invalid characters" % (
server_name,
))
return host, port
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.

View File

@ -12,7 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.endpoint import parse_server_name from synapse.http.endpoint import (
parse_server_name,
parse_and_validate_server_name,
)
from tests import unittest from tests import unittest
@ -30,17 +33,23 @@ class ServerNameTestCase(unittest.TestCase):
for i, o in test_data.items(): for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o) self.assertEqual(parse_server_name(i), o)
def test_parse_bad_server_names(self): def test_validate_bad_server_names(self):
test_data = [ test_data = [
"", # empty "", # empty
"localhost:http", # non-numeric port "localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't "1234]", # smells like ipv6 literal but isn't
"[1234",
"underscore_.com",
"percent%65.com",
"1234:5678:80", # too many colons
] ]
for i in test_data: for i in test_data:
try: try:
parse_server_name(i) parse_and_validate_server_name(i)
self.fail( self.fail(
"Expected parse_server_name(\"%s\") to throw" % i, "Expected parse_and_validate_server_name('%s') to throw" % (
i,
),
) )
except ValueError: except ValueError:
pass pass