mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
263 lines
8.0 KiB
Python
263 lines
8.0 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
import itertools
|
|
import re
|
|
import secrets
|
|
import string
|
|
from typing import Any, Iterable, Optional, Tuple
|
|
|
|
from netaddr import valid_ipv6
|
|
|
|
from synapse.api.errors import Codes, SynapseError
|
|
|
|
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"

# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
#
# Anchored with "\Z" rather than "$": "$" also matches just before a trailing
# "\n", which would otherwise let e.g. "abc\n" through as a valid client secret.
CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+\Z")

# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris,
# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically
# says "there is no grammar for media ids"
#
# The server_name part of this is purposely lax: use parse_and_validate_mxc_uri for
# additional validation. Note that neither capture group excludes control
# characters such as "\n".
#
MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$")
|
|
|
|
|
|
def random_string(length: int) -> str:
    """Return `length` random ASCII letters picked with a CSPRNG.

    Every character is chosen independently with `secrets.choice` from
    `a-z` and `A-Z`.
    """
    alphabet = string.ascii_letters
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
|
|
|
|
|
def random_string_with_symbols(length: int) -> str:
    """Return `length` random letters/digits/symbols picked with a CSPRNG.

    Every character is chosen independently with `secrets.choice` from
    `a-z`, `A-Z`, `0-9`, and `.,;:^&*-_+=#~@`.
    """
    pool = _string_with_symbols
    chosen = (secrets.choice(pool) for _ in range(length))
    return "".join(chosen)
|
|
|
|
|
|
def is_ascii(s: bytes) -> bool:
    """Report whether `s` decodes cleanly with the "ascii" codec.

    Args:
        s: the bytes to inspect.

    Returns:
        True if every byte is 7-bit ASCII (including for empty input),
        False otherwise.
    """
    try:
        s.decode("ascii")
    except UnicodeError:
        return False
    else:
        return True
|
|
|
|
|
|
def assert_valid_client_secret(client_secret: str) -> None:
    """Validate that a given string matches the client_secret defined by the spec.

    The secret must be 1-255 characters long and consist only of characters
    accepted by `CLIENT_SECRET_REGEX`.

    Args:
        client_secret: the client_secret parameter supplied by the client.

    Raises:
        SynapseError: a 400 error with errcode `Codes.INVALID_PARAM` if the
            secret is empty, longer than 255 characters, or contains a
            disallowed character.
    """
    # Use fullmatch so the *whole* string must match regardless of how the
    # pattern is anchored: a pattern ending in "$" also matches just before a
    # trailing "\n", which would let e.g. "abc\n" through with plain `.match`.
    if (
        not client_secret
        or len(client_secret) > 255
        or CLIENT_SECRET_REGEX.fullmatch(client_secret) is None
    ):
        raise SynapseError(
            400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
        )
|
|
|
|
|
|
def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host/port parts.

    The host part is not validated here (see `parse_and_validate_server_name`);
    an IPv6 literal is returned with its surrounding brackets.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts. The port is None if no explicit port was given.

    Raises:
        ValueError if the server name could not be parsed, including when the
        port is not made up of 1-5 ASCII digits.
    """
    try:
        if server_name and server_name[-1] == "]":
            # ipv6 literal, hopefully
            return server_name, None

        domain_port = server_name.rsplit(":", 1)
        domain = domain_port[0]
        if len(domain_port) == 1:
            return domain, None

        port_str = domain_port[1]
        # The Matrix server-name grammar is `port = 1*5DIGIT`. int() on its own
        # is far more lenient: it accepts a sign, surrounding whitespace, "_"
        # digit separators and non-ASCII digits, so e.g. "example.com:-1" or
        # "example.com:8_0" would otherwise parse successfully.
        if not (port_str.isascii() and port_str.isdigit()) or len(port_str) > 5:
            raise ValueError("Invalid port %r" % (port_str,))
        return domain, int(port_str)
    except Exception as e:
        raise ValueError("Invalid server name '%s'" % server_name) from e
|
|
|
|
|
|
# Rough approximation of the domain-name syntax from RFC 1035, section 2.3.1:
# one or more "."-separated labels made of ASCII letters, digits and "-".
# The pattern is anchored with "\A" ... "\Z" on purpose: "$" is NOT equivalent
# to "\Z", because "$" also matches right before a "\n" at the end of a string.
VALID_HOST_REGEX = re.compile(r"\A[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*\Z")
|
|
|
|
|
|
def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]:
    """Split a server name into host and port, applying basic sanity checks.

    The host must either be a bracketed IPv6 literal accepted by
    `valid_ipv6`, or match `VALID_HOST_REGEX`.

    Args:
        server_name: server name to parse

    Returns:
        host/port parts.

    Raises:
        ValueError if the server name could not be parsed or fails validation.
    """
    host, port = parse_server_name(server_name)

    # These checks are not meant to be bulletproof -- bad data will surface
    # soon enough anyway. What matters is that nobody can smuggle in an IP
    # literal disguised as a hostname, or similar.

    if not host or host[0] != "[":
        # Not bracketed, so treat it as an ordinary hostname.
        if not VALID_HOST_REGEX.match(host):
            raise ValueError(
                "Server name '%s' has an invalid format" % (server_name,)
            )
        return host, port

    # Bracketed: this should be an IPv6 literal.
    if host[-1] != "]":
        raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))

    # Guard against the empty string before calling valid_ipv6, which raises
    # on empty input.
    literal = host[1:-1]
    if not literal or not valid_ipv6(literal):
        raise ValueError(
            "Server name '%s' is not a valid IPv6 address" % (server_name,)
        )
    return host, port
|
|
|
|
|
|
def valid_id_server_location(id_server: str) -> bool:
    """Decide whether an identity server location string is acceptable.

    Such a location is passed, for example, as the `id_server` parameter of
    `/_matrix/client/r0/account/3pid/bind`. It is accepted when the part before
    the first `/` is a valid server name (hostname plus optional port), and
    any remaining path contains neither a fragment (`#`) nor a query (`?`).

    Args:
        id_server: identity server location string to validate

    Returns:
        True if valid, False otherwise.
    """
    host, slash, path = id_server.partition("/")

    try:
        parse_and_validate_server_name(host)
    except ValueError:
        return False

    if not slash:
        # Bare host[:port] with no path component.
        return True

    return not any(forbidden in path for forbidden in "#?")
|
|
|
|
|
|
def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]:
    """Break an `mxc://` URI into its server and media-id parts.

    The server-name portion is validated with `parse_and_validate_server_name`;
    the media id is returned exactly as captured by `MXC_REGEX`.

    Args:
        mxc: the (alleged) MXC URI to be checked

    Returns:
        hostname, port, media id

    Raises:
        ValueError if the URI cannot be parsed
    """
    match = MXC_REGEX.match(mxc)
    if match is None:
        raise ValueError("mxc URI %r did not match expected format" % (mxc,))

    server_name, media_id = match.groups()
    host, port = parse_and_validate_server_name(server_name)
    return host, port, media_id
|
|
|
|
|
|
def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
    """If iterable has maxitems or fewer, return the stringification of a list
    containing those items.

    Otherwise, return the stringification of a list with the first maxitems items,
    followed by "...".

    At most `maxitems + 1` items are consumed from `iterable`, so this is safe
    to call on long or unbounded iterators.

    Args:
        iterable: iterable to truncate
        maxitems: number of items to return before truncating

    Returns:
        e.g. "[1, 2, 3]" or "[1, 2, 3, 4, 5, ...]"; "[...]" when maxitems is 0
        and the iterable is non-empty.
    """
    # Pull one extra item so we can tell whether truncation is needed.
    items = list(itertools.islice(iterable, maxitems + 1))
    if len(items) <= maxitems:
        return str(items)

    # Join the "..." marker alongside the shown items so that maxitems == 0
    # yields "[...]" rather than the malformed "[, ...]".
    shown = [repr(item) for item in items[:maxitems]]
    return "[" + ", ".join(shown + ["..."]) + "]"
|
|
|
|
|
|
def strtobool(val: str) -> bool:
    """Interpret a textual truth value, case-insensitively.

    Recognised true values: 'y', 'yes', 't', 'true', 'on', '1'.
    Recognised false values: 'n', 'no', 'f', 'false', 'off', '0'.
    Anything else raises ValueError.

    Adapted from distutils.util.strtobool, but returns a real bool instead
    of an int.
    """
    lowered = val.lower()
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise ValueError("invalid truth value %r" % (lowered,))
|
|
|
|
|
|
# Base62 digit alphabet: 0-9, then A-Z, then a-z.
_BASE62 = string.digits + string.ascii_uppercase + string.ascii_lowercase
|
|
|
|
|
|
def base62_encode(num: int, minwidth: int = 1) -> str:
    """Encode a non-negative integer using base62.

    Digits are taken from `_BASE62`, most significant digit first.

    Args:
        num: number to be encoded; must be non-negative
        minwidth: width to pad to (with leading "0"s), if the number is small

    Returns:
        the base62 representation of `num`, left-padded with "0" to at least
        `minwidth` characters.

    Raises:
        ValueError: if `num` is negative.
    """
    if num < 0:
        # divmod() floors towards negative infinity, so for a negative `num`
        # the quotient never reaches 0 (e.g. divmod(-1, 62) == (-1, 61)) and
        # the digit loop below would never terminate.
        raise ValueError("Cannot base62-encode negative number %r" % (num,))

    # Collect digits least-significant first, then reverse once at the end.
    digits = []
    while num:
        num, rem = divmod(num, 62)
        digits.append(_BASE62[rem])
    res = "".join(reversed(digits))

    # pad to minimum width
    return res.rjust(minwidth, "0")
|
|
|
|
|
|
def non_null_str_or_none(val: Any) -> Optional[str]:
    """Return `val` unchanged if it is a str without any U+0000 codepoints.

    Non-strings, and strings containing a null codepoint, yield None.
    """
    if not isinstance(val, str):
        return None
    if "\u0000" in val:
        return None
    return val
|