mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-15 22:23:53 -05:00
Fix IP URL previews on Python 3 (#4215)
This commit is contained in:
parent
c26f49a664
commit
ea6abf6724
8 changed files with 596 additions and 278 deletions
|
|
@ -15,21 +15,55 @@
|
|||
|
||||
import os
|
||||
|
||||
from mock import Mock
|
||||
import attr
|
||||
from netaddr import IPSet
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet._resolver import HostResolution
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from synapse.config.repository import MediaStorageProviderConfig
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
|
||||
|
||||
@attr.s
|
||||
class FakeResponse(object):
|
||||
version = attr.ib()
|
||||
code = attr.ib()
|
||||
phrase = attr.ib()
|
||||
headers = attr.ib()
|
||||
body = attr.ib()
|
||||
absoluteURI = attr.ib()
|
||||
|
||||
@property
|
||||
def request(self):
|
||||
@attr.s
|
||||
class FakeTransport(object):
|
||||
absoluteURI = self.absoluteURI
|
||||
|
||||
return FakeTransport()
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
protocol.dataReceived(self.body)
|
||||
protocol.connectionLost(Failure(ResponseDone()))
|
||||
|
||||
|
||||
class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
hijack_auth = True
|
||||
user_id = "@test:user"
|
||||
end_content = (
|
||||
b'<html><head>'
|
||||
b'<meta property="og:title" content="~matrix~" />'
|
||||
b'<meta property="og:description" content="hi" />'
|
||||
b'</head></html>'
|
||||
)
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
|
|
@ -39,6 +73,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
config = self.default_config()
|
||||
config.url_preview_enabled = True
|
||||
config.max_spider_size = 9999999
|
||||
config.url_preview_ip_range_blacklist = IPSet(
|
||||
(
|
||||
"192.168.1.1",
|
||||
"1.0.0.0/8",
|
||||
"3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
|
||||
"2001:800::/21",
|
||||
)
|
||||
)
|
||||
config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
|
||||
config.url_preview_url_blacklist = []
|
||||
config.media_store_path = self.storage_path
|
||||
|
||||
|
|
@ -62,63 +105,50 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
||||
self.fetches = []
|
||||
|
||||
def get_file(url, output_stream, max_size):
|
||||
"""
|
||||
Returns tuple[int,dict,str,int] of file length, response headers,
|
||||
absolute URI, and response code.
|
||||
"""
|
||||
|
||||
def write_to(r):
|
||||
data, response = r
|
||||
output_stream.write(data)
|
||||
return response
|
||||
|
||||
d = Deferred()
|
||||
d.addCallback(write_to)
|
||||
self.fetches.append((d, url))
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
client = Mock()
|
||||
client.get_file = get_file
|
||||
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
preview_url = self.media_repo.children[b'preview_url']
|
||||
preview_url.client = client
|
||||
self.preview_url = preview_url
|
||||
self.preview_url = self.media_repo.children[b'preview_url']
|
||||
|
||||
self.lookups = {}
|
||||
|
||||
class Resolver(object):
|
||||
def resolveHostName(
|
||||
_self,
|
||||
resolutionReceiver,
|
||||
hostName,
|
||||
portNumber=0,
|
||||
addressTypes=None,
|
||||
transportSemantics='TCP',
|
||||
):
|
||||
|
||||
resolution = HostResolution(hostName)
|
||||
resolutionReceiver.resolutionBegan(resolution)
|
||||
if hostName not in self.lookups:
|
||||
raise DNSLookupError("OH NO")
|
||||
|
||||
for i in self.lookups[hostName]:
|
||||
resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
|
||||
resolutionReceiver.resolutionComplete()
|
||||
return resolutionReceiver
|
||||
|
||||
self.reactor.nameResolver = Resolver()
|
||||
|
||||
def test_cache_returns_correct_type(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
"GET", "url_preview?url=http://matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
end_content = (
|
||||
b'<html><head>'
|
||||
b'<meta property="og:title" content="~matrix~" />'
|
||||
b'<meta property="og:description" content="hi" />'
|
||||
b'</head></html>'
|
||||
)
|
||||
|
||||
self.fetches[0][0].callback(
|
||||
(
|
||||
end_content,
|
||||
(
|
||||
len(end_content),
|
||||
{
|
||||
b"Content-Length": [b"%d" % (len(end_content))],
|
||||
b"Content-Type": [b'text/html; charset="utf8"'],
|
||||
},
|
||||
"https://example.com",
|
||||
200,
|
||||
),
|
||||
)
|
||||
client = self.reactor.tcpClients[0][2].buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
|
||||
% (len(self.end_content),)
|
||||
+ self.end_content
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
|
@ -129,14 +159,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
|
||||
# Check the cache returns the correct response
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
"GET", "url_preview?url=http://matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# Only one fetch, still, since we'll lean on the cache
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
# Check the cache response has the same content
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
|
|
@ -144,20 +171,17 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Clear the in-memory cache
|
||||
self.assertIn("matrix.org", self.preview_url._cache)
|
||||
self.preview_url._cache.pop("matrix.org")
|
||||
self.assertNotIn("matrix.org", self.preview_url._cache)
|
||||
self.assertIn("http://matrix.org", self.preview_url._cache)
|
||||
self.preview_url._cache.pop("http://matrix.org")
|
||||
self.assertNotIn("http://matrix.org", self.preview_url._cache)
|
||||
|
||||
# Check the database cache returns the correct response
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
"GET", "url_preview?url=http://matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# Only one fetch, still, since we'll lean on the cache
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
|
||||
# Check the cache response has the same content
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
|
|
@ -165,15 +189,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
def test_non_ascii_preview_httpequiv(self):
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
|
||||
|
||||
end_content = (
|
||||
b'<html><head>'
|
||||
|
|
@ -183,21 +199,23 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
b'</head></html>'
|
||||
)
|
||||
|
||||
self.fetches[0][0].callback(
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
client = self.reactor.tcpClients[0][2].buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
(
|
||||
end_content,
|
||||
(
|
||||
len(end_content),
|
||||
{
|
||||
b"Content-Length": [b"%d" % (len(end_content))],
|
||||
# This charset=utf-8 should be ignored, because the
|
||||
# document has a meta tag overriding it.
|
||||
b"Content-Type": [b'text/html; charset="utf8"'],
|
||||
},
|
||||
"https://example.com",
|
||||
200,
|
||||
),
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
|
||||
b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
|
||||
)
|
||||
% (len(end_content),)
|
||||
+ end_content
|
||||
)
|
||||
|
||||
self.pump()
|
||||
|
|
@ -205,15 +223,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
|
||||
|
||||
def test_non_ascii_preview_content_type(self):
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# We've made one fetch
|
||||
self.assertEqual(len(self.fetches), 1)
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
|
||||
|
||||
end_content = (
|
||||
b'<html><head>'
|
||||
|
|
@ -222,21 +232,239 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
|||
b'</head></html>'
|
||||
)
|
||||
|
||||
self.fetches[0][0].callback(
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://matrix.org", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
client = self.reactor.tcpClients[0][2].buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
(
|
||||
end_content,
|
||||
(
|
||||
len(end_content),
|
||||
{
|
||||
b"Content-Length": [b"%d" % (len(end_content))],
|
||||
b"Content-Type": [b'text/html; charset="windows-1251"'],
|
||||
},
|
||||
"https://example.com",
|
||||
200,
|
||||
),
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
|
||||
b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
|
||||
)
|
||||
% (len(end_content),)
|
||||
+ end_content
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
|
||||
|
||||
def test_ipaddr(self):
|
||||
"""
|
||||
IP addresses can be previewed directly.
|
||||
"""
|
||||
self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
client = self.reactor.tcpClients[0][2].buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
|
||||
% (len(self.end_content),)
|
||||
+ self.end_content
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_specific(self):
|
||||
"""
|
||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||
"""
|
||||
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# No requests made.
|
||||
self.assertEqual(len(self.reactor.tcpClients), 0)
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_range(self):
|
||||
"""
|
||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||
"""
|
||||
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_specific_direct(self):
|
||||
"""
|
||||
Blacklisted IP addresses, accessed directly, are not spidered.
|
||||
"""
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://192.168.1.1", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# No requests made.
|
||||
self.assertEqual(len(self.reactor.tcpClients), 0)
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_range_direct(self):
|
||||
"""
|
||||
Blacklisted IP ranges, accessed directly, are not spidered.
|
||||
"""
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://1.1.1.2", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_range_whitelisted_ip(self):
|
||||
"""
|
||||
Blacklisted but then subsequently whitelisted IP addresses can be
|
||||
spidered.
|
||||
"""
|
||||
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
client = self.reactor.tcpClients[0][2].buildProtocol(None)
|
||||
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
|
||||
client.dataReceived(
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
|
||||
% (len(self.end_content),)
|
||||
+ self.end_content
|
||||
)
|
||||
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_with_external_ip(self):
|
||||
"""
|
||||
If a hostname resolves a blacklisted IP, even if there's a
|
||||
non-blacklisted one, it will be rejected.
|
||||
"""
|
||||
# Hardcode the URL resolving to the IP we want.
|
||||
self.lookups[u"example.com"] = [
|
||||
(IPv4Address, "1.1.1.2"),
|
||||
(IPv4Address, "8.8.8.8"),
|
||||
]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ipv6_specific(self):
|
||||
"""
|
||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||
"""
|
||||
self.lookups["example.com"] = [
|
||||
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
|
||||
]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
# No requests made.
|
||||
self.assertEqual(len(self.reactor.tcpClients), 0)
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ipv6_range(self):
|
||||
"""
|
||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||
"""
|
||||
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "url_preview?url=http://example.com", shorthand=False
|
||||
)
|
||||
request.render(self.preview_url)
|
||||
self.pump()
|
||||
|
||||
self.assertEqual(channel.code, 403)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
'errcode': 'M_UNKNOWN',
|
||||
'error': 'IP address blocked by IP blacklist entry',
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -383,8 +383,16 @@ class FakeTransport(object):
|
|||
self.disconnecting = True
|
||||
|
||||
def pauseProducing(self):
|
||||
if not self.producer:
|
||||
return
|
||||
|
||||
self.producer.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
if not self.producer:
|
||||
return
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def unregisterProducer(self):
|
||||
if not self.producer:
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue