Merge pull request from GHSA-x345-32rc-8h85

* tests for push rule pattern matching

* tests for acl pattern matching

* factor out common `re.escape`

* Factor out common re.compile

* Factor out common anchoring code

* add word_boundary support to `glob_to_regex`

* Use `glob_to_regex` in push rule evaluator

NB that this drops support for character classes. I don't think anyone ever
used them.

* Improve efficiency of globs with multiple wildcards

The idea here is that we compress multiple `*` globs into a single `.*`. We
also need to consider `?`, since `*?*` is as hard to implement efficiently as
`**`.

* add assertion on regex pattern

* Fix mypy

* Simplify glob_to_regex

* Inline the glob_to_regex helper function

Signed-off-by: Dan Callahan <danc@element.io>

* Moar comments

Signed-off-by: Dan Callahan <danc@element.io>

Co-authored-by: Dan Callahan <danc@element.io>
This commit is contained in:
Richard van der Hoff 2021-05-11 10:47:23 +01:00 committed by GitHub
parent 4df26abf28
commit 03318a766c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 297 additions and 69 deletions

View File

@ -17,7 +17,7 @@ import os
import warnings import warnings
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import List, Optional from typing import List, Optional, Pattern
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -124,7 +124,7 @@ class TlsConfig(Config):
fed_whitelist_entries = [] fed_whitelist_entries = []
# Support globs (*) in whitelist values # Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist = [] # type: List[str] self.federation_certificate_verification_whitelist = [] # type: List[Pattern]
for entry in fed_whitelist_entries: for entry in fed_whitelist_entries:
try: try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))

View File

@ -19,6 +19,7 @@ from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import UserID from synapse.types import UserID
from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -183,7 +184,7 @@ class PushRuleEvaluatorForEvent:
r = regex_cache.get((display_name, False, True), None) r = regex_cache.get((display_name, False, True), None)
if not r: if not r:
r1 = re.escape(display_name) r1 = re.escape(display_name)
r1 = _re_word_boundary(r1) r1 = re_word_boundary(r1)
r = re.compile(r1, flags=re.IGNORECASE) r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r regex_cache[(display_name, False, True)] = r
@ -212,7 +213,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
try: try:
r = regex_cache.get((glob, True, word_boundary), None) r = regex_cache.get((glob, True, word_boundary), None)
if not r: if not r:
r = _glob_to_re(glob, word_boundary) r = glob_to_regex(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r regex_cache[(glob, True, word_boundary)] = r
return bool(r.search(value)) return bool(r.search(value))
except re.error: except re.error:
@ -220,56 +221,6 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
return False return False
def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
"""Generates regex for a given glob.
Args:
glob
word_boundary: Whether to match against word boundaries or entire string.
"""
if IS_GLOB.search(glob):
r = re.escape(glob)
r = r.replace(r"\*", ".*?")
r = r.replace(r"\?", ".")
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
"[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-"))
),
r,
)
if word_boundary:
r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
r = "^" + r + "$"
return re.compile(r, flags=re.IGNORECASE)
elif word_boundary:
r = re.escape(glob)
r = _re_word_boundary(r)
return re.compile(r, flags=re.IGNORECASE)
else:
r = "^" + re.escape(glob) + "$"
return re.compile(r, flags=re.IGNORECASE)
def _re_word_boundary(r: str) -> str:
"""
Adds word boundary characters to the start and end of an
expression to require that the match occur as a whole word,
but do so respecting the fact that strings starting or ending
with non-word characters will change word boundaries.
"""
# we can't use \b as it chokes on unicode. however \W seems to be okay
# as shorthand for [^0-9A-Za-z_].
return r"(^|\W)%s(\W|$)" % (r,)
def _flatten_dict( def _flatten_dict(
d: Union[EventBase, dict], d: Union[EventBase, dict],
prefix: Optional[List[str]] = None, prefix: Optional[List[str]] = None,

View File

@ -15,6 +15,7 @@
import json import json
import logging import logging
import re import re
from typing import Pattern
import attr import attr
from frozendict import frozendict from frozendict import frozendict
@ -26,6 +27,9 @@ from synapse.logging import context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_WILDCARD_RUN = re.compile(r"([\?\*]+)")
def _reject_invalid_json(val): def _reject_invalid_json(val):
"""Do not allow Infinity, -Infinity, or NaN values in JSON.""" """Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val) raise ValueError("Invalid JSON value: '%s'" % val)
@ -158,25 +162,54 @@ def log_failure(failure, msg, consumeErrors=True):
return failure return failure
def glob_to_regex(glob): def glob_to_regex(glob: str, word_boundary: bool = False) -> Pattern:
"""Converts a glob to a compiled regex object. """Converts a glob to a compiled regex object.
The regex is anchored at the beginning and end of the string.
Args: Args:
glob (str) glob: pattern to match
word_boundary: If True, the pattern will be allowed to match at word boundaries
anywhere in the string. Otherwise, the pattern is anchored at the start and
end of the string.
Returns: Returns:
re.RegexObject compiled regex pattern
""" """
res = ""
for c in glob:
if c == "*":
res = res + ".*"
elif c == "?":
res = res + "."
else:
res = res + re.escape(c)
# \A anchors at start of string, \Z at end of string # Patterns with wildcards must be simplified to avoid performance cliffs
return re.compile(r"\A" + res + r"\Z", re.IGNORECASE) # - The glob `?**?**?` is equivalent to the glob `???*`
# - The glob `???*` is equivalent to the regex `.{3,}`
chunks = []
for chunk in _WILDCARD_RUN.split(glob):
# No wildcards? re.escape()
if not _WILDCARD_RUN.match(chunk):
chunks.append(re.escape(chunk))
continue
# Wildcards? Simplify.
qmarks = chunk.count("?")
if "*" in chunk:
chunks.append(".{%d,}" % qmarks)
else:
chunks.append(".{%d}" % qmarks)
res = "".join(chunks)
if word_boundary:
res = re_word_boundary(res)
else:
# \A anchors at start of string, \Z at end of string
res = r"\A" + res + r"\Z"
return re.compile(res, re.IGNORECASE)
def re_word_boundary(r: str) -> str:
"""
Adds word boundary characters to the start and end of an
expression to require that the match occur as a whole word,
but do so respecting the fact that strings starting or ending
with non-word characters will change word boundaries.
"""
# we can't use \b as it chokes on unicode. however \W seems to be okay
# as shorthand for [^0-9A-Za-z_].
return r"(^|\W)%s(\W|$)" % (r,)

View File

@ -74,6 +74,25 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertFalse(server_matches_acl_event("[1:2::]", e)) self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e)) self.assertTrue(server_matches_acl_event("1:2:3:4", e))
def test_wildcard_matching(self):
e = _create_acl_event({"allow": ["good*.com"]})
self.assertTrue(
server_matches_acl_event("good.com", e),
"* matches 0 characters",
)
self.assertTrue(
server_matches_acl_event("GOOD.COM", e),
"pattern is case-insensitive",
)
self.assertTrue(
server_matches_acl_event("good.aa.com", e),
"* matches several characters, including '.'",
)
self.assertFalse(
server_matches_acl_event("ishgood.com", e),
"pattern does not allow prefixes",
)
class StateQueryTests(unittest.FederatingHomeserverTestCase): class StateQueryTests(unittest.FederatingHomeserverTestCase):

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.push import push_rule_evaluator from synapse.push import push_rule_evaluator
@ -66,6 +68,170 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# A display name with spaces should work fine. # A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
def _assert_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
) -> None:
evaluator = self._get_evaluator(content)
self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
def _assert_not_matches(
self, condition: Dict[str, Any], content: Dict[str, Any], msg=None
) -> None:
evaluator = self._get_evaluator(content)
self.assertFalse(
evaluator.matches(condition, "@user:test", "display_name"), msg
)
def test_event_match_body(self):
"""Check that event_match conditions on content.body work as expected"""
# if the key is `content.body`, the pattern matches substrings.
# non-wildcards should match
condition = {
"kind": "event_match",
"key": "content.body",
"pattern": "foobaz",
}
self._assert_matches(
condition,
{"body": "aaa FoobaZ zzz"},
"patterns should match and be case-insensitive",
)
self._assert_not_matches(
condition,
{"body": "aa xFoobaZ yy"},
"pattern should only match at word boundaries",
)
self._assert_not_matches(
condition,
{"body": "aa foobazx yy"},
"pattern should only match at word boundaries",
)
# wildcards should match
condition = {
"kind": "event_match",
"key": "content.body",
"pattern": "f?o*baz",
}
self._assert_matches(
condition,
{"body": "aaa FoobarbaZ zzz"},
"* should match string and pattern should be case-insensitive",
)
self._assert_matches(
condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
)
self._assert_not_matches(
condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
)
self._assert_not_matches(
condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
)
self._assert_not_matches(
condition,
{"body": "aa xfooxbaz yy"},
"pattern should only match at word boundaries",
)
self._assert_not_matches(
condition,
{"body": "aa fooxbazx yy"},
"pattern should only match at word boundaries",
)
# test backslashes
condition = {
"kind": "event_match",
"key": "content.body",
"pattern": r"f\oobaz",
}
self._assert_matches(
condition,
{"body": r"F\oobaz"},
"backslash should match itself",
)
condition = {
"kind": "event_match",
"key": "content.body",
"pattern": r"f\?obaz",
}
self._assert_matches(
condition,
{"body": r"F\oobaz"},
r"? after \ should match any character",
)
def test_event_match_non_body(self):
"""Check that event_match conditions on other keys work as expected"""
# if the key is anything other than 'content.body', the pattern must match the
# whole value.
# non-wildcards should match
condition = {
"kind": "event_match",
"key": "content.value",
"pattern": "foobaz",
}
self._assert_matches(
condition,
{"value": "FoobaZ"},
"patterns should match and be case-insensitive",
)
self._assert_not_matches(
condition,
{"value": "xFoobaZ"},
"pattern should only match at the start/end of the value",
)
self._assert_not_matches(
condition,
{"value": "FoobaZz"},
"pattern should only match at the start/end of the value",
)
# wildcards should match
condition = {
"kind": "event_match",
"key": "content.value",
"pattern": "f?o*baz",
}
self._assert_matches(
condition,
{"value": "FoobarbaZ"},
"* should match string and pattern should be case-insensitive",
)
self._assert_matches(
condition, {"value": "foobaz"}, "* should match 0 characters"
)
self._assert_not_matches(
condition, {"value": "fobbaz"}, "? should not match 0 characters"
)
self._assert_not_matches(
condition, {"value": "fiiobaz"}, "? should not match 2 characters"
)
self._assert_not_matches(
condition,
{"value": "xfooxbaz"},
"pattern should only match at the start/end of the value",
)
self._assert_not_matches(
condition,
{"value": "fooxbazx"},
"pattern should only match at the start/end of the value",
)
self._assert_not_matches(
condition,
{"value": "x\nfooxbaz"},
"pattern should not match after a newline",
)
self._assert_not_matches(
condition,
{"value": "fooxbaz\nx"},
"pattern should not match before a newline",
)
def test_no_body(self): def test_no_body(self):
"""Not having a body shouldn't break the evaluator.""" """Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({}) evaluator = self._get_evaluator({})

View File

@ -0,0 +1,59 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util import glob_to_regex
from tests.unittest import TestCase
class GlobToRegexTestCase(TestCase):
def test_literal_match(self):
"""patterns without wildcards should match"""
pat = glob_to_regex("foobaz")
self.assertTrue(
pat.match("FoobaZ"), "patterns should match and be case-insensitive"
)
self.assertFalse(
pat.match("x foobaz"), "pattern should not match at word boundaries"
)
def test_wildcard_match(self):
pat = glob_to_regex("f?o*baz")
self.assertTrue(
pat.match("FoobarbaZ"),
"* should match string and pattern should be case-insensitive",
)
self.assertTrue(pat.match("foobaz"), "* should match 0 characters")
self.assertFalse(pat.match("fooxaz"), "the character after * must match")
self.assertFalse(pat.match("fobbaz"), "? should not match 0 characters")
self.assertFalse(pat.match("fiiobaz"), "? should not match 2 characters")
def test_multi_wildcard(self):
"""patterns with multiple wildcards in a row should match"""
pat = glob_to_regex("**baz")
self.assertTrue(pat.match("agsgsbaz"), "** should match any string")
self.assertTrue(pat.match("baz"), "** should match the empty string")
self.assertEqual(pat.pattern, r"\A.{0,}baz\Z")
pat = glob_to_regex("*?baz")
self.assertTrue(pat.match("agsgsbaz"), "*? should match any string")
self.assertTrue(pat.match("abaz"), "*? should match a single char")
self.assertFalse(pat.match("baz"), "*? should not match the empty string")
self.assertEqual(pat.pattern, r"\A.{1,}baz\Z")
pat = glob_to_regex("a?*?*?baz")
self.assertTrue(pat.match("a g baz"), "?*?*? should match 3 chars")
self.assertFalse(pat.match("a..baz"), "?*?*? should not match 2 chars")
self.assertTrue(pat.match("a.gg.baz"), "?*?*? should match 4 chars")
self.assertEqual(pat.pattern, r"\Aa.{3,}baz\Z")