Merge pull request #3679 from matrix-org/hawkowl/blackify-tests

Blackify the tests
This commit is contained in:
Amber Brown 2018-08-11 00:12:56 +10:00 committed by GitHub
commit a001038b92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
76 changed files with 1628 additions and 2277 deletions

1
changelog.d/3679.misc Normal file
View File

@ -0,0 +1 @@
Synapse's tests are now formatted with the black autoformatter.

View File

@ -34,7 +34,6 @@ class TestHandlers(object):
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.state_handler = Mock() self.state_handler = Mock()
@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = { user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
"name": self.test_user,
"token_id": "ditto",
"device_id": "device",
}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase):
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = { user_info = {"name": self.test_user, "token_id": "ditto"}
"name": self.test_user,
"token_id": "ditto",
}
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_good_ip(self): def test_get_user_by_req_appservice_valid_token_good_ip(self):
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar",
url="a_url",
sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_bad_ip(self): def test_get_user_by_req_appservice_valid_token_bad_ip(self):
from netaddr import IPSet from netaddr import IPSet
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar",
url="a_url",
sender=self.test_user,
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals( self.assertEquals(
requester.user.to_string(), requester.user.to_string(), masquerading_user_id.decode('utf8')
masquerading_user_id.decode('utf8')
) )
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
masquerading_user_id = b"@doppelganger:matrix.org" masquerading_user_id = b"@doppelganger:matrix.org"
app_service = Mock( app_service = Mock(
token="foobar", url="a_url", sender=self.test_user, token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
ip_range_whitelist=None,
) )
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
@ -201,17 +195,15 @@ class AuthTestCase(unittest.TestCase):
# TODO(danielwh): Remove this mock when we remove the # TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback. # get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value={ return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
"name": "@baldrick:matrix.org",
"device_id": "device",
}
) )
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -225,15 +217,14 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = Mock(return_value={ self.store.get_user_by_id = Mock(return_value={"is_guest": True})
"is_guest": True,
})
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -257,7 +248,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -277,7 +269,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
@ -298,7 +291,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key + "wrong") key=self.hs.config.macaroon_secret_key + "wrong",
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -320,7 +314,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -347,7 +342,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,)) macaroon.add_first_party_caveat("user_id = %s" % (user,))
@ -380,7 +376,8 @@ class AuthTestCase(unittest.TestCase):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.hs.config.macaroon_secret_key) key=self.hs.config.macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@ -401,9 +398,7 @@ class AuthTestCase(unittest.TestCase):
token = yield self.hs.handlers.auth_handler.issue_access_token( token = yield self.hs.handlers.auth_handler.issue_access_token(
USER_ID, "DEVICE" USER_ID, "DEVICE"
) )
self.store.add_access_token_to_user.assert_called_with( self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
USER_ID, token, "DEVICE"
)
def get_user(tok): def get_user(tok):
if token != tok: if token != tok:
@ -414,10 +409,9 @@ class AuthTestCase(unittest.TestCase):
"token_id": 1234, "token_id": 1234,
"device_id": "DEVICE", "device_id": "DEVICE",
} }
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user
self.store.get_user_by_id = Mock(return_value={ self.store.get_user_by_id = Mock(return_value={"is_guest": False})
"is_guest": False,
})
# check the token works # check the token works
request = Mock(args={}) request = Mock(args={})

View File

@ -38,7 +38,6 @@ def MockEvent(**kwargs):
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
@ -47,9 +46,7 @@ class FilteringTestCase(unittest.TestCase):
self.mock_http_client.put_json = DeferredMockCallable() self.mock_http_client.put_json = DeferredMockCallable()
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
handlers=None, handlers=None, http_client=self.mock_http_client, keyring=Mock()
http_client=self.mock_http_client,
keyring=Mock(),
) )
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@ -64,7 +61,7 @@ class FilteringTestCase(unittest.TestCase):
{"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}}, {"room": {"timeline": {"limit": 0}, "state": {"not_bars": ["*"]}}},
{"event_format": "other"}, {"event_format": "other"},
{"room": {"not_rooms": ["#foo:pik-test"]}}, {"room": {"not_rooms": ["#foo:pik-test"]}},
{"presence": {"senders": ["@bar;pik.test.com"]}} {"presence": {"senders": ["@bar;pik.test.com"]}},
] ]
for filter in invalid_filters: for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error: with self.assertRaises(SynapseError) as check_filter_error:
@ -81,34 +78,34 @@ class FilteringTestCase(unittest.TestCase):
"include_leave": False, "include_leave": False,
"rooms": ["!dee:pik-test"], "rooms": ["!dee:pik-test"],
"not_rooms": ["!gee:pik-test"], "not_rooms": ["!gee:pik-test"],
"account_data": {"limit": 0, "types": ["*"]} "account_data": {"limit": 0, "types": ["*"]},
} }
}, },
{ {
"room": { "room": {
"state": { "state": {
"types": ["m.room.*"], "types": ["m.room.*"],
"not_rooms": ["!726s6s6q:example.com"] "not_rooms": ["!726s6s6q:example.com"],
}, },
"timeline": { "timeline": {
"limit": 10, "limit": 10,
"types": ["m.room.message"], "types": ["m.room.message"],
"not_rooms": ["!726s6s6q:example.com"], "not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"] "not_senders": ["@spam:example.com"],
}, },
"ephemeral": { "ephemeral": {
"types": ["m.receipt", "m.typing"], "types": ["m.receipt", "m.typing"],
"not_rooms": ["!726s6s6q:example.com"], "not_rooms": ["!726s6s6q:example.com"],
"not_senders": ["@spam:example.com"] "not_senders": ["@spam:example.com"],
} },
}, },
"presence": { "presence": {
"types": ["m.presence"], "types": ["m.presence"],
"not_senders": ["@alice:example.com"] "not_senders": ["@alice:example.com"],
}, },
"event_format": "client", "event_format": "client",
"event_fields": ["type", "content", "sender"] "event_fields": ["type", "content", "sender"],
} },
] ]
for filter in valid_filters: for filter in valid_filters:
try: try:
@ -121,229 +118,131 @@ class FilteringTestCase(unittest.TestCase):
pass pass
def test_definition_types_works_with_literals(self): def test_definition_types_works_with_literals(self):
definition = { definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
"types": ["m.room.message", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue( self.assertTrue(Filter(definition).check(event))
Filter(definition).check(event)
)
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
definition = { definition = {"types": ["m.*", "org.matrix.foo.bar"]}
"types": ["m.*", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
} self.assertTrue(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
)
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
definition = { definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="now.for.something.completely.different", type="now.for.something.completely.different",
room_id="!foo:bar" room_id="!foo:bar",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
definition = { definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
"not_types": ["m.room.message", "org.matrix.foo.bar"] event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
} self.assertFalse(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
)
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
definition = { definition = {"not_types": ["m.room.message", "org.matrix.*"]}
"not_types": ["m.room.message", "org.matrix.*"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
type="org.matrix.custom.event",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
definition = { definition = {"not_types": ["m.*", "org.*"]}
"not_types": ["m.*", "org.*"] event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
} self.assertTrue(Filter(definition).check(event))
event = MockEvent(
sender="@foo:bar",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
)
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
definition = { definition = {
"not_types": ["m.*", "org.*"], "not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"] "types": ["m.room.message", "m.room.topic"],
} }
event = MockEvent( event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
sender="@foo:bar", self.assertFalse(Filter(definition).check(event))
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
)
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
definition = { definition = {"senders": ["@flibble:wibble"]}
"senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
definition = { definition = {"senders": ["@flibble:wibble"]}
"senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
definition = { definition = {"not_senders": ["@flibble:wibble"]}
"not_senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
definition = { definition = {"not_senders": ["@flibble:wibble"]}
"not_senders": ["@flibble:wibble"]
}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
definition = { definition = {
"not_senders": ["@misspiggy:muppets"], "not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"] "senders": ["@kermit:muppets", "@misspiggy:muppets"],
} }
event = MockEvent( event = MockEvent(
sender="@misspiggy:muppets", sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
definition = { definition = {"rooms": ["!secretbase:unknown"]}
"rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
definition = { definition = {"rooms": ["!secretbase:unknown"]}
"rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
definition = { definition = {"not_rooms": ["!anothersecretbase:unknown"]}
"not_rooms": ["!anothersecretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
definition = { definition = {"not_rooms": ["!secretbase:unknown"]}
"not_rooms": ["!secretbase:unknown"]
}
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown",
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = { definition = {
"not_rooms": ["!secretbase:unknown"], "not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"] "rooms": ["!secretbase:unknown"],
} }
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event(self): def test_definition_combined_event(self):
definition = { definition = {
@ -352,16 +251,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertTrue(
Filter(definition).check(event)
) )
self.assertTrue(Filter(definition).check(event))
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
definition = { definition = {
@ -370,16 +267,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@misspiggy:muppets", # nope sender="@misspiggy:muppets", # nope
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
definition = { definition = {
@ -388,16 +283,14 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="m.room.message", # yup type="m.room.message", # yup
room_id="!piggyshouse:muppets" # nope room_id="!piggyshouse:muppets", # nope
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
definition = { definition = {
@ -406,37 +299,26 @@ class FilteringTestCase(unittest.TestCase):
"rooms": ["!stage:unknown"], "rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"], "not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"], "types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"] "not_types": ["muppets.misspiggy.*"],
} }
event = MockEvent( event = MockEvent(
sender="@kermit:muppets", # yup sender="@kermit:muppets", # yup
type="muppets.misspiggy.kisses", # nope type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown" # yup room_id="!stage:unknown", # yup
)
self.assertFalse(
Filter(definition).check(event)
) )
self.assertFalse(Filter(definition).check(event))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_presence_match(self): def test_filter_presence_match(self):
user_filter_json = { user_filter_json = {"presence": {"types": ["m.*"]}}
"presence": {
"types": ["m.*"]
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.profile",
) )
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -444,15 +326,10 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self):
user_filter_json = { user_filter_json = {"presence": {"types": ["m.*"]}}
"presence": {
"types": ["m.*"]
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart + "2", user_localpart=user_localpart + "2", user_filter=user_filter_json
user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
event_id="$asdasd:localhost", event_id="$asdasd:localhost",
@ -462,8 +339,7 @@ class FilteringTestCase(unittest.TestCase):
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart + "2", user_localpart=user_localpart + "2", filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_presence(events=events) results = user_filter.filter_presence(events=events)
@ -471,27 +347,15 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
) )
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_room_state(events=events) results = user_filter.filter_room_state(events=events)
@ -499,27 +363,17 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
type="org.matrix.custom.event",
room_id="!foo:bar"
) )
events = [event] events = [event]
user_filter = yield self.filtering.get_user_filter( user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
results = user_filter.filter_room_state(events) results = user_filter.filter_room_state(events)
@ -543,45 +397,32 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
self.assertEquals(user_filter_json, ( self.assertEquals(
user_filter_json,
(
yield self.datastore.get_user_filter( yield self.datastore.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=0
filter_id=0, )
),
) )
))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
user_filter_json = { user_filter_json = {"room": {"state": {"types": ["m.*"]}}}
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, user_filter=user_filter_json
user_filter=user_filter_json,
) )
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart, filter_id=filter_id
filter_id=filter_id,
) )
self.assertEquals(filter.get_filter_json(), user_filter_json) self.assertEquals(filter.get_filter_json(), user_filter_json)

View File

@ -4,17 +4,16 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase): class TestRatelimiter(unittest.TestCase):
def test_allowed(self): def test_allowed(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1, user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
) )
self.assertTrue(allowed) self.assertTrue(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1, user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
) )
self.assertFalse(allowed) self.assertFalse(allowed)
self.assertEquals(10., time_allowed) self.assertEquals(10., time_allowed)
@ -28,7 +27,7 @@ class TestRatelimiter(unittest.TestCase):
def test_pruning(self): def test_pruning(self):
limiter = Ratelimiter() limiter = Ratelimiter()
allowed, time_allowed = limiter.send_message( allowed, time_allowed = limiter.send_message(
user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1, user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
) )
self.assertIn("test_id_1", limiter.message_counts) self.assertIn("test_id_1", limiter.message_counts)

View File

@ -24,14 +24,10 @@ from tests import unittest
def _regex(regex, exclusive=True): def _regex(regex, exclusive=True):
return { return {"regex": re.compile(regex), "exclusive": exclusive}
"regex": re.compile(regex),
"exclusive": exclusive
}
class ApplicationServiceTestCase(unittest.TestCase): class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.service = ApplicationService( self.service = ApplicationService(
id="unique_identifier", id="unique_identifier",
@ -41,8 +37,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
namespaces={ namespaces={
ApplicationService.NS_USERS: [], ApplicationService.NS_USERS: [],
ApplicationService.NS_ROOMS: [], ApplicationService.NS_ROOMS: [],
ApplicationService.NS_ALIASES: [] ApplicationService.NS_ALIASES: [],
} },
) )
self.event = Mock( self.event = Mock(
type="m.something", room_id="!foo:bar", sender="@someone:somewhere" type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
@ -52,25 +48,19 @@ class ApplicationServiceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse((yield self.service.is_interested(self.event))) self.assertFalse((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.state_key = "@irc_foobar:matrix.org" self.event.state_key = "@irc_foobar:matrix.org"
@ -98,60 +88,47 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = [
"#irc_foobar:matrix.org", "#athing:matrix.org" "#irc_foobar:matrix.org",
"#athing:matrix.org",
] ]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested( self.assertTrue((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
def test_non_exclusive_alias(self): def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=False) _regex("#irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_alias( self.assertFalse(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
"#irc_foobar:matrix.org"
))
def test_non_exclusive_room(self): def test_non_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=False) _regex("!irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_room( self.assertFalse(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
"!irc_foobar:matrix.org"
))
def test_non_exclusive_user(self): def test_non_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=False) _regex("@irc_.*:matrix.org", exclusive=False)
) )
self.assertFalse(self.service.is_exclusive_user( self.assertFalse(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
"@irc_foobar:matrix.org"
))
def test_exclusive_alias(self): def test_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=True) _regex("#irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_alias( self.assertTrue(self.service.is_exclusive_alias("#irc_foobar:matrix.org"))
"#irc_foobar:matrix.org"
))
def test_exclusive_user(self): def test_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=True) _regex("@irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_user( self.assertTrue(self.service.is_exclusive_user("@irc_foobar:matrix.org"))
"@irc_foobar:matrix.org"
))
def test_exclusive_room(self): def test_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=True) _regex("!irc_.*:matrix.org", exclusive=True)
) )
self.assertTrue(self.service.is_exclusive_room( self.assertTrue(self.service.is_exclusive_room("!irc_foobar:matrix.org"))
"!irc_foobar:matrix.org"
))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
@ -159,47 +136,36 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room.return_value = [ self.store.get_aliases_for_room.return_value = [
"#xmpp_foobar:matrix.org", "#athing:matrix.org" "#xmpp_foobar:matrix.org",
"#athing:matrix.org",
] ]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertFalse((yield self.service.is_interested( self.assertFalse((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"]
self.store.get_users_in_room.return_value = [] self.store.get_users_in_room.return_value = []
self.assertTrue((yield self.service.is_interested( self.assertTrue((yield self.service.is_interested(self.event, self.store)))
self.event, self.store
)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_interested_in_self(self): def test_interested_in_self(self):
# make sure invites get through # make sure invites get through
self.service.sender = "@appservice:name" self.service.sender = "@appservice:name"
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.event.type = "m.room.member" self.event.type = "m.room.member"
self.event.content = { self.event.content = {"membership": "invite"}
"membership": "invite"
}
self.event.state_key = self.service.sender self.event.state_key = self.service.sender
self.assertTrue((yield self.service.is_interested(self.event))) self.assertTrue((yield self.service.is_interested(self.event)))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
_regex("@irc_.*")
)
self.store.get_users_in_room.return_value = [ self.store.get_users_in_room.return_value = [
"@alice:here", "@alice:here",
"@irc_fo:here", # AS user "@irc_fo:here", # AS user
@ -208,6 +174,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.store.get_aliases_for_room.return_value = [] self.store.get_aliases_for_room.return_value = []
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue((yield self.service.is_interested( self.assertTrue(
event=self.event, store=self.store (yield self.service.is_interested(event=self.event, store=self.store))
))) )

View File

@ -30,7 +30,6 @@ from ..utils import MockClock
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.store = Mock() self.store = Mock()
@ -38,8 +37,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.recoverer = Mock() self.recoverer = Mock()
self.recoverer_fn = Mock(return_value=self.recoverer) self.recoverer_fn = Mock(return_value=self.recoverer)
self.txnctrl = _TransactionController( self.txnctrl = _TransactionController(
clock=self.clock, store=self.store, as_api=self.as_api, clock=self.clock,
recoverer_fn=self.recoverer_fn store=self.store,
as_api=self.as_api,
recoverer_fn=self.recoverer_fn,
) )
def test_single_service_up_txn_sent(self): def test_single_service_up_txn_sent(self):
@ -54,9 +55,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
return_value=defer.succeed(ApplicationServiceState.UP) return_value=defer.succeed(ApplicationServiceState.UP)
) )
txn.send = Mock(return_value=defer.succeed(True)) txn.send = Mock(return_value=defer.succeed(True))
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -77,9 +76,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.store.get_appservice_state = Mock( self.store.get_appservice_state = Mock(
return_value=defer.succeed(ApplicationServiceState.DOWN) return_value=defer.succeed(ApplicationServiceState.DOWN)
) )
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -104,9 +101,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
) )
self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) self.store.set_appservice_state = Mock(return_value=defer.succeed(True))
txn.send = Mock(return_value=defer.succeed(False)) # fails to send txn.send = Mock(return_value=defer.succeed(False)) # fails to send
self.store.create_appservice_txn = Mock( self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn))
return_value=defer.succeed(txn)
)
# actual call # actual call
self.txnctrl.send(service, events) self.txnctrl.send(service, events)
@ -124,7 +119,6 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.as_api = Mock() self.as_api = Mock()
@ -146,6 +140,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
def take_txn(*args, **kwargs): def take_txn(*args, **kwargs):
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover() self.recoverer.recover()
@ -171,6 +166,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
return defer.succeed(txns.pop(0)) return defer.succeed(txns.pop(0))
else: else:
return defer.succeed(txn) return defer.succeed(txn)
self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn) self.store.get_oldest_unsent_txn = Mock(side_effect=take_txn)
self.recoverer.recover() self.recoverer.recover()
@ -197,7 +193,6 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock()) self.queuer = _ServiceQueuer(self.txn_ctrl, MockClock())
@ -211,9 +206,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def test_send_single_event_with_queue(self): def test_send_single_event_with_queue(self):
d = defer.Deferred() d = defer.Deferred()
self.txn_ctrl.send = Mock( self.txn_ctrl.send = Mock(side_effect=lambda x, y: make_deferred_yieldable(d))
side_effect=lambda x, y: make_deferred_yieldable(d),
)
service = Mock(id=4) service = Mock(id=4)
event = Mock(event_id="first") event = Mock(event_id="first")
event2 = Mock(event_id="second") event2 = Mock(event_id="second")
@ -247,6 +240,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
def do_send(x, y): def do_send(x, y):
return make_deferred_yieldable(send_return_list.pop(0)) return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send) self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent # send events for different ASes and make sure they are sent

View File

@ -24,7 +24,6 @@ from tests import unittest
class ConfigGenerationTestCase(unittest.TestCase): class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
self.file = os.path.join(self.dir, "homeserver.yaml") self.file = os.path.join(self.dir, "homeserver.yaml")
@ -33,23 +32,30 @@ class ConfigGenerationTestCase(unittest.TestCase):
shutil.rmtree(self.dir) shutil.rmtree(self.dir)
def test_generate_config_generates_files(self): def test_generate_config_generates_files(self):
HomeServerConfig.load_or_generate_config("", [ HomeServerConfig.load_or_generate_config(
"",
[
"--generate-config", "--generate-config",
"-c", self.file, "-c",
self.file,
"--report-stats=yes", "--report-stats=yes",
"-H", "lemurs.win" "-H",
]) "lemurs.win",
],
)
self.assertSetEqual( self.assertSetEqual(
set([ set(
[
"homeserver.yaml", "homeserver.yaml",
"lemurs.win.log.config", "lemurs.win.log.config",
"lemurs.win.signing.key", "lemurs.win.signing.key",
"lemurs.win.tls.crt", "lemurs.win.tls.crt",
"lemurs.win.tls.dh", "lemurs.win.tls.dh",
"lemurs.win.tls.key", "lemurs.win.tls.key",
]), ]
set(os.listdir(self.dir)) ),
set(os.listdir(self.dir)),
) )
self.assert_log_filename_is( self.assert_log_filename_is(

View File

@ -24,7 +24,6 @@ from tests import unittest
class ConfigLoadingTestCase(unittest.TestCase): class ConfigLoadingTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dir = tempfile.mkdtemp() self.dir = tempfile.mkdtemp()
print(self.dir) print(self.dir)
@ -43,15 +42,14 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_generates_and_loads_macaroon_secret_key(self): def test_generates_and_loads_macaroon_secret_key(self):
self.generate_config() self.generate_config()
with open(self.file, with open(self.file, "r") as f:
"r") as f:
raw = yaml.load(f) raw = yaml.load(f)
self.assertIn("macaroon_secret_key", raw) self.assertIn("macaroon_secret_key", raw)
config = HomeServerConfig.load_config("", ["-c", self.file]) config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertTrue( self.assertTrue(
hasattr(config, "macaroon_secret_key"), hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key" "Want config to have attr macaroon_secret_key",
) )
if len(config.macaroon_secret_key) < 5: if len(config.macaroon_secret_key) < 5:
self.fail( self.fail(
@ -62,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase):
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
self.assertTrue( self.assertTrue(
hasattr(config, "macaroon_secret_key"), hasattr(config, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key" "Want config to have attr macaroon_secret_key",
) )
if len(config.macaroon_secret_key) < 5: if len(config.macaroon_secret_key) < 5:
self.fail( self.fail(
@ -80,10 +78,9 @@ class ConfigLoadingTestCase(unittest.TestCase):
def test_disable_registration(self): def test_disable_registration(self):
self.generate_config() self.generate_config()
self.add_lines_to_config([ self.add_lines_to_config(
"enable_registration: true", ["enable_registration: true", "disable_registration: true"]
"disable_registration: true", )
])
# Check that disable_registration clobbers enable_registration. # Check that disable_registration clobbers enable_registration.
config = HomeServerConfig.load_config("", ["-c", self.file]) config = HomeServerConfig.load_config("", ["-c", self.file])
self.assertFalse(config.enable_registration) self.assertFalse(config.enable_registration)
@ -92,18 +89,23 @@ class ConfigLoadingTestCase(unittest.TestCase):
self.assertFalse(config.enable_registration) self.assertFalse(config.enable_registration)
# Check that either config value is clobbered by the command line. # Check that either config value is clobbered by the command line.
config = HomeServerConfig.load_or_generate_config("", [ config = HomeServerConfig.load_or_generate_config(
"-c", self.file, "--enable-registration" "", ["-c", self.file, "--enable-registration"]
]) )
self.assertTrue(config.enable_registration) self.assertTrue(config.enable_registration)
def generate_config(self): def generate_config(self):
HomeServerConfig.load_or_generate_config("", [ HomeServerConfig.load_or_generate_config(
"",
[
"--generate-config", "--generate-config",
"-c", self.file, "-c",
self.file,
"--report-stats=yes", "--report-stats=yes",
"-H", "lemurs.win" "-H",
]) "lemurs.win",
],
)
def generate_config_and_remove_lines_containing(self, needle): def generate_config_and_remove_lines_containing(self, needle):
self.generate_config() self.generate_config()

View File

@ -24,9 +24,7 @@ from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic # Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against. # signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64( SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
)
KEY_ALG = "ed25519" KEY_ALG = "ed25519"
KEY_VER = 1 KEY_VER = 1
@ -36,7 +34,6 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase): class EventSigningTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED) self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
self.signing_key.alg = KEY_ALG self.signing_key.alg = KEY_ALG
@ -51,7 +48,7 @@ class EventSigningTestCase(unittest.TestCase):
'signatures': {}, 'signatures': {},
'type': "X", 'type': "X",
'unsigned': {'age_ts': 1000000}, 'unsigned': {'age_ts': 1000000},
}, }
) )
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key) add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
@ -61,8 +58,7 @@ class EventSigningTestCase(unittest.TestCase):
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)
self.assertEquals( self.assertEquals(
event.hashes['sha256'], event.hashes['sha256'], "6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI"
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
) )
self.assertTrue(hasattr(event, 'signatures')) self.assertTrue(hasattr(event, 'signatures'))
@ -77,9 +73,7 @@ class EventSigningTestCase(unittest.TestCase):
def test_sign_message(self): def test_sign_message(self):
builder = EventBuilder( builder = EventBuilder(
{ {
'content': { 'content': {'body': "Here is the message content"},
'body': "Here is the message content",
},
'event_id': "$0:domain", 'event_id': "$0:domain",
'origin': "domain", 'origin': "domain",
'origin_server_ts': 1000000, 'origin_server_ts': 1000000,
@ -98,8 +92,7 @@ class EventSigningTestCase(unittest.TestCase):
self.assertTrue(hasattr(event, 'hashes')) self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes) self.assertIn('sha256', event.hashes)
self.assertEquals( self.assertEquals(
event.hashes['sha256'], event.hashes['sha256'], "onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g"
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
) )
self.assertTrue(hasattr(event, 'signatures')) self.assertTrue(hasattr(event, 'signatures'))
@ -108,5 +101,5 @@ class EventSigningTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME], event.signatures[HOSTNAME][KEY_NAME],
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw" "Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA" "u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA",
) )

View File

@ -36,9 +36,7 @@ class MockPerspectiveServer(object):
def get_verify_keys(self): def get_verify_keys(self):
vk = signedjson.key.get_verify_key(self.key) vk = signedjson.key.get_verify_key(self.key)
return { return {"%s:%s" % (vk.alg, vk.version): vk}
"%s:%s" % (vk.alg, vk.version): vk,
}
def get_signed_key(self, server_name, verify_key): def get_signed_key(self, server_name, verify_key):
key_id = "%s:%s" % (verify_key.alg, verify_key.version) key_id = "%s:%s" % (verify_key.alg, verify_key.version)
@ -47,10 +45,8 @@ class MockPerspectiveServer(object):
"old_verify_keys": {}, "old_verify_keys": {},
"valid_until_ts": time.time() * 1000 + 3600, "valid_until_ts": time.time() * 1000 + 3600,
"verify_keys": { "verify_keys": {
key_id: { key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
"key": signedjson.key.encode_verify_key_base64(verify_key) },
}
}
} }
signedjson.sign.sign_json(res, self.server_name, self.key) signedjson.sign.sign_json(res, self.server_name, self.key)
return res return res
@ -62,18 +58,16 @@ class KeyringTestCase(unittest.TestCase):
self.mock_perspective_server = MockPerspectiveServer() self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock() self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, handlers=None, http_client=self.http_client
http_client=self.http_client,
) )
keys = self.mock_perspective_server.get_verify_keys()
self.hs.config.perspectives = { self.hs.config.perspectives = {
self.mock_perspective_server.server_name: self.mock_perspective_server.server_name: keys
self.mock_perspective_server.get_verify_keys()
} }
def check_context(self, _, expected): def check_context(self, _, expected):
self.assertEquals( self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), getattr(LoggingContext.current_context(), "request", None), expected
expected
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -89,8 +83,7 @@ class KeyringTestCase(unittest.TestCase):
context_one.request = "one" context_one.request = "one"
wait_1_deferred = kr.wait_for_previous_lookups( wait_1_deferred = kr.wait_for_previous_lookups(
["server1"], ["server1"], {"server1": lookup_1_deferred}
{"server1": lookup_1_deferred},
) )
# there were no previous lookups, so the deferred should be ready # there were no previous lookups, so the deferred should be ready
@ -105,8 +98,7 @@ class KeyringTestCase(unittest.TestCase):
# set off another wait. It should block because the first lookup # set off another wait. It should block because the first lookup
# hasn't yet completed. # hasn't yet completed.
wait_2_deferred = kr.wait_for_previous_lookups( wait_2_deferred = kr.wait_for_previous_lookups(
["server1"], ["server1"], {"server1": lookup_2_deferred}
{"server1": lookup_2_deferred},
) )
self.assertFalse(wait_2_deferred.called) self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext. # ... so we should have reset the LoggingContext.
@ -132,21 +124,19 @@ class KeyringTestCase(unittest.TestCase):
persp_resp = { persp_resp = {
"server_keys": [ "server_keys": [
self.mock_perspective_server.get_signed_key( self.mock_perspective_server.get_signed_key(
"server10", "server10", signedjson.key.get_verify_key(key1)
signedjson.key.get_verify_key(key1) )
),
] ]
} }
persp_deferred = defer.Deferred() persp_deferred = defer.Deferred()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_perspectives(**kwargs): def get_perspectives(**kwargs):
self.assertEquals( self.assertEquals(LoggingContext.current_context().request, "11")
LoggingContext.current_context().request, "11",
)
with logcontext.PreserveLoggingContext(): with logcontext.PreserveLoggingContext():
yield persp_deferred yield persp_deferred
defer.returnValue(persp_resp) defer.returnValue(persp_resp)
self.http_client.post_json.side_effect = get_perspectives self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11: with LoggingContext("11") as context_11:
@ -154,9 +144,7 @@ class KeyringTestCase(unittest.TestCase):
# start off a first set of lookups # start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), [("server10", json1), ("server11", {})]
("server11", {})
]
) )
# the unsigned json should be rejected pretty quickly # the unsigned json should be rejected pretty quickly
@ -186,7 +174,7 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.return_value = defer.Deferred() self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)], [("server10", json1)]
) )
yield clock.sleep(1) yield clock.sleep(1)
self.http_client.post_json.assert_not_called() self.http_client.post_json.assert_not_called()
@ -207,8 +195,7 @@ class KeyringTestCase(unittest.TestCase):
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key( yield self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
signedjson.key.get_verify_key(key1),
) )
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server9", key1) signedjson.sign.sign_json(json1, "server9", key1)

View File

@ -31,25 +31,20 @@ def MockEvent(**kwargs):
class PruneEventTestCase(unittest.TestCase): class PruneEventTestCase(unittest.TestCase):
""" Asserts that a new event constructed with `evdict` will look like """ Asserts that a new event constructed with `evdict` will look like
`matchdict` when it is redacted. """ `matchdict` when it is redacted. """
def run_test(self, evdict, matchdict): def run_test(self, evdict, matchdict):
self.assertEquals( self.assertEquals(prune_event(FrozenEvent(evdict)).get_dict(), matchdict)
prune_event(FrozenEvent(evdict)).get_dict(),
matchdict
)
def test_minimal(self): def test_minimal(self):
self.run_test( self.run_test(
{ {'type': 'A', 'event_id': '$test:domain'},
'type': 'A',
'event_id': '$test:domain',
},
{ {
'type': 'A', 'type': 'A',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_basic_keys(self): def test_basic_keys(self):
@ -70,23 +65,19 @@ class PruneEventTestCase(unittest.TestCase):
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_unsigned_age_ts(self): def test_unsigned_age_ts(self):
self.run_test( self.run_test(
{ {'type': 'B', 'event_id': '$test:domain', 'unsigned': {'age_ts': 20}},
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{ {
'type': 'B', 'type': 'B',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {'age_ts': 20}, 'unsigned': {'age_ts': 20},
} },
) )
self.run_test( self.run_test(
@ -101,23 +92,19 @@ class PruneEventTestCase(unittest.TestCase):
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
def test_content(self): def test_content(self):
self.run_test( self.run_test(
{ {'type': 'C', 'event_id': '$test:domain', 'content': {'things': 'here'}},
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{ {
'type': 'C', 'type': 'C',
'event_id': '$test:domain', 'event_id': '$test:domain',
'content': {}, 'content': {},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
self.run_test( self.run_test(
@ -132,27 +119,20 @@ class PruneEventTestCase(unittest.TestCase):
'content': {'creator': '@2:domain'}, 'content': {'creator': '@2:domain'},
'signatures': {}, 'signatures': {},
'unsigned': {}, 'unsigned': {},
} },
) )
class SerializeEventTestCase(unittest.TestCase): class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields): def serialize(self, ev, fields):
return serialize_event(ev, 1479807801915, only_event_fields=fields) return serialize_event(ev, 1479807801915, only_event_fields=fields)
def test_event_fields_works_with_keys(self): def test_event_fields_works_with_keys(self):
self.assertEquals( self.assertEquals(
self.serialize( self.serialize(
MockEvent( MockEvent(sender="@alice:localhost", room_id="!foo:bar"), ["room_id"]
sender="@alice:localhost",
room_id="!foo:bar"
), ),
["room_id"] {"room_id": "!foo:bar"},
),
{
"room_id": "!foo:bar",
}
) )
def test_event_fields_works_with_nested_keys(self): def test_event_fields_works_with_nested_keys(self):
@ -161,17 +141,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"body": "A message"},
"body": "A message",
},
), ),
["content.body"] ["content.body"],
), ),
{ {"content": {"body": "A message"}},
"content": {
"body": "A message",
}
}
) )
def test_event_fields_works_with_dot_keys(self): def test_event_fields_works_with_dot_keys(self):
@ -180,17 +154,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"key.with.dots": {}},
"key.with.dots": {},
},
), ),
["content.key\.with\.dots"] ["content.key\.with\.dots"],
), ),
{ {"content": {"key.with.dots": {}}},
"content": {
"key.with.dots": {},
}
}
) )
def test_event_fields_works_with_nested_dot_keys(self): def test_event_fields_works_with_nested_dot_keys(self):
@ -201,21 +169,12 @@ class SerializeEventTestCase(unittest.TestCase):
room_id="!foo:bar", room_id="!foo:bar",
content={ content={
"not_me": 1, "not_me": 1,
"nested.dot.key": { "nested.dot.key": {"leaf.key": 42, "not_me_either": 1},
"leaf.key": 42,
"not_me_either": 1,
},
}, },
), ),
["content.nested\.dot\.key.leaf\.key"] ["content.nested\.dot\.key.leaf\.key"],
), ),
{ {"content": {"nested.dot.key": {"leaf.key": 42}}},
"content": {
"nested.dot.key": {
"leaf.key": 42,
},
}
}
) )
def test_event_fields_nops_with_unknown_keys(self): def test_event_fields_nops_with_unknown_keys(self):
@ -224,17 +183,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": "bar"},
"foo": "bar",
},
), ),
["content.foo", "content.notexists"] ["content.foo", "content.notexists"],
), ),
{ {"content": {"foo": "bar"}},
"content": {
"foo": "bar",
}
}
) )
def test_event_fields_nops_with_non_dict_keys(self): def test_event_fields_nops_with_non_dict_keys(self):
@ -243,13 +196,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": ["I", "am", "an", "array"]},
"foo": ["I", "am", "an", "array"],
},
), ),
["content.foo.am"] ["content.foo.am"],
), ),
{} {},
) )
def test_event_fields_nops_with_array_keys(self): def test_event_fields_nops_with_array_keys(self):
@ -258,13 +209,11 @@ class SerializeEventTestCase(unittest.TestCase):
MockEvent( MockEvent(
sender="@alice:localhost", sender="@alice:localhost",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": ["I", "am", "an", "array"]},
"foo": ["I", "am", "an", "array"],
},
), ),
["content.foo.1"] ["content.foo.1"],
), ),
{} {},
) )
def test_event_fields_all_fields_if_empty(self): def test_event_fields_all_fields_if_empty(self):
@ -274,31 +223,21 @@ class SerializeEventTestCase(unittest.TestCase):
type="foo", type="foo",
event_id="test", event_id="test",
room_id="!foo:bar", room_id="!foo:bar",
content={ content={"foo": "bar"},
"foo": "bar",
},
), ),
[] [],
), ),
{ {
"type": "foo", "type": "foo",
"event_id": "test", "event_id": "test",
"room_id": "!foo:bar", "room_id": "!foo:bar",
"content": { "content": {"foo": "bar"},
"foo": "bar", "unsigned": {},
}, },
"unsigned": {}
}
) )
def test_event_fields_fail_if_fields_not_str(self): def test_event_fields_fail_if_fields_not_str(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.serialize( self.serialize(
MockEvent( MockEvent(room_id="!foo:bar", content={"foo": "bar"}), ["room_id", 4]
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["room_id", 4]
) )

View File

@ -23,10 +23,7 @@ from tests import unittest
@unittest.DEBUG @unittest.DEBUG
class ServerACLsTestCase(unittest.TestCase): class ServerACLsTestCase(unittest.TestCase):
def test_blacklisted_server(self): def test_blacklisted_server(self):
e = _create_acl_event({ e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
"allow": ["*"],
"deny": ["evil.com"],
})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("evil.com", e)) self.assertFalse(server_matches_acl_event("evil.com", e))
@ -36,10 +33,7 @@ class ServerACLsTestCase(unittest.TestCase):
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
def test_block_ip_literals(self): def test_block_ip_literals(self):
e = _create_acl_event({ e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
"allow_ip_literals": False,
"allow": ["*"],
})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("1.2.3.4", e)) self.assertFalse(server_matches_acl_event("1.2.3.4", e))
@ -49,10 +43,12 @@ class ServerACLsTestCase(unittest.TestCase):
def _create_acl_event(content): def _create_acl_event(content):
return FrozenEvent({ return FrozenEvent(
{
"room_id": "!a:b", "room_id": "!a:b",
"event_id": "$a:b", "event_id": "$a:b",
"type": "m.room.server_acls", "type": "m.room.server_acls",
"sender": "@a:b", "sender": "@a:b",
"content": content "content": content,
}) }
)

View File

@ -45,20 +45,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [ services = [
self._mkservice(is_interested=False), self._mkservice(is_interested=False),
interested_service, interested_service,
self._mkservice(is_interested=False) self._mkservice(is_interested=False),
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=[]) self.mock_store.get_user_by_id = Mock(return_value=[])
event = Mock( event = Mock(
sender="@someone:anywhere", sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
type="m.room.message",
room_id="!foo:bar"
) )
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
@ -74,21 +72,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=None) self.mock_store.get_user_by_id = Mock(return_value=None)
event = Mock( event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
self.mock_as_api.query_user.assert_called_once_with( self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
services[0], user_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_query_user_exists_known_user(self): def test_query_user_exists_known_user(self):
@ -96,25 +88,19 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested=True)] services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True) services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value={ self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
"name": user_id
})
event = Mock( event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock() self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock() self.mock_as_api.query_user = Mock()
self.mock_store.get_new_events_for_appservice.side_effect = [ self.mock_store.get_new_events_for_appservice.side_effect = [
(0, [event]), (0, [event]),
(0, []) (0, []),
] ]
yield self.handler.notify_interested_services(0) yield self.handler.notify_interested_services(0)
self.assertFalse( self.assertFalse(
self.mock_as_api.query_user.called, self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been." "query_user called when it shouldn't have been.",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -129,7 +115,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [ services = [
self._mkservice_alias(is_interested_in_alias=False), self._mkservice_alias(is_interested_in_alias=False),
interested_service, interested_service,
self._mkservice_alias(is_interested_in_alias=False) self._mkservice_alias(is_interested_in_alias=False),
] ]
self.mock_store.get_app_services = Mock(return_value=services) self.mock_store.get_app_services = Mock(return_value=services)
@ -140,8 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
result = yield self.handler.query_room_alias_exists(room_alias) result = yield self.handler.query_room_alias_exists(room_alias)
self.mock_as_api.query_alias.assert_called_once_with( self.mock_as_api.query_alias.assert_called_once_with(
interested_service, interested_service, room_alias_str
room_alias_str
) )
self.assertEquals(result.room_id, room_id) self.assertEquals(result.room_id, room_id)
self.assertEquals(result.servers, servers) self.assertEquals(result.servers, servers)

View File

@ -81,9 +81,7 @@ class AuthTestCase(unittest.TestCase):
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.clock.now = 1000 self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
"a_user", 5000
)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
@ -98,17 +96,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
"a_user", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
self.assertEqual( self.assertEqual("a_user", user_id)
"a_user", user_id
)
# add another "user_id" caveat, which might allow us to override the # add another "user_id" caveat, which might allow us to override the
# user_id. # user_id.
@ -165,7 +159,5 @@ class AuthTestCase(unittest.TestCase):
) )
def _get_macaroon(self): def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
"user_a", 5000
)
return pymacaroons.Macaroon.deserialize(token) return pymacaroons.Macaroon.deserialize(token)

View File

@ -44,7 +44,7 @@ class DeviceTestCase(unittest.TestCase):
res = yield self.handler.check_device_registered( res = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name",
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
@ -56,14 +56,14 @@ class DeviceTestCase(unittest.TestCase):
res1 = yield self.handler.check_device_registered( res1 = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name",
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = yield self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="new display name" initial_device_display_name="new display name",
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
@ -75,7 +75,7 @@ class DeviceTestCase(unittest.TestCase):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id="@theresa:foo", user_id="@theresa:foo",
device_id=None, device_id=None,
initial_device_display_name="display" initial_device_display_name="display",
) )
dev = yield self.handler.store.get_device("@theresa:foo", device_id) dev = yield self.handler.store.get_device("@theresa:foo", device_id)
@ -87,43 +87,53 @@ class DeviceTestCase(unittest.TestCase):
res = yield self.handler.get_devices_by_user(user1) res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
device_map = { device_map = {d["device_id"]: d for d in res}
d["device_id"]: d for d in res self.assertDictContainsSubset(
} {
self.assertDictContainsSubset({
"user_id": user1, "user_id": user1,
"device_id": "xyz", "device_id": "xyz",
"display_name": "display 0", "display_name": "display 0",
"last_seen_ip": None, "last_seen_ip": None,
"last_seen_ts": None, "last_seen_ts": None,
}, device_map["xyz"]) },
self.assertDictContainsSubset({ device_map["xyz"],
)
self.assertDictContainsSubset(
{
"user_id": user1, "user_id": user1,
"device_id": "fco", "device_id": "fco",
"display_name": "display 1", "display_name": "display 1",
"last_seen_ip": "ip1", "last_seen_ip": "ip1",
"last_seen_ts": 1000000, "last_seen_ts": 1000000,
}, device_map["fco"]) },
self.assertDictContainsSubset({ device_map["fco"],
)
self.assertDictContainsSubset(
{
"user_id": user1, "user_id": user1,
"device_id": "abc", "device_id": "abc",
"display_name": "display 2", "display_name": "display 2",
"last_seen_ip": "ip3", "last_seen_ip": "ip3",
"last_seen_ts": 3000000, "last_seen_ts": 3000000,
}, device_map["abc"]) },
device_map["abc"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_device(self): def test_get_device(self):
yield self._record_users() yield self._record_users()
res = yield self.handler.get_device(user1, "abc") res = yield self.handler.get_device(user1, "abc")
self.assertDictContainsSubset({ self.assertDictContainsSubset(
{
"user_id": user1, "user_id": user1,
"device_id": "abc", "device_id": "abc",
"display_name": "display 2", "display_name": "display 2",
"last_seen_ip": "ip3", "last_seen_ip": "ip3",
"last_seen_ts": 3000000, "last_seen_ts": 3000000,
}, res) },
res,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_device(self): def test_delete_device(self):
@ -153,8 +163,7 @@ class DeviceTestCase(unittest.TestCase):
def test_update_unknown_device(self): def test_update_unknown_device(self):
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError): with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id", yield self.handler.update_device("user_id", "unknown_device_id", update)
update)
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_users(self): def _record_users(self):
@ -168,16 +177,17 @@ class DeviceTestCase(unittest.TestCase):
yield self._record_user(user2, "def", "dispkay", "token4", "ip4") yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_user(self, user_id, device_id, display_name, def _record_user(
access_token=None, ip=None): self, user_id, device_id, display_name, access_token=None, ip=None
):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=display_name initial_device_display_name=display_name,
) )
if ip is not None: if ip is not None:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, user_id, access_token, ip, "user_agent", device_id
access_token, ip, "user_agent", device_id) )
self.clock.advance_time(1000) self.clock.advance_time(1000)

View File

@ -42,6 +42,7 @@ class DirectoryTestCase(unittest.TestCase):
def register_query_handler(query_type, handler): def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
@ -68,10 +69,7 @@ class DirectoryTestCase(unittest.TestCase):
result = yield self.handler.get_association(self.my_room) result = yield self.handler.get_association(self.my_room)
self.assertEquals({ self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
"room_id": "!8765qwer:test",
"servers": ["test"],
}, result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_remote_association(self): def test_get_remote_association(self):
@ -81,16 +79,13 @@ class DirectoryTestCase(unittest.TestCase):
result = yield self.handler.get_association(self.remote_room) result = yield self.handler.get_association(self.remote_room)
self.assertEquals({ self.assertEquals(
"room_id": "!8765qwer:test", {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}, result
"servers": ["test", "remote"], )
}, result)
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
destination="remote", destination="remote",
query_type="directory", query_type="directory",
args={ args={"room_alias": "#another:remote"},
"room_alias": "#another:remote",
},
retry_on_dns_fail=False, retry_on_dns_fail=False,
ignore_backoff=True, ignore_backoff=True,
) )
@ -105,7 +100,4 @@ class DirectoryTestCase(unittest.TestCase):
{"room_alias": "#your-room:test"} {"room_alias": "#your-room:test"}
) )
self.assertEquals({ self.assertEquals({"room_id": "!8765asdf:test", "servers": ["test"]}, response)
"room_id": "!8765asdf:test",
"servers": ["test"],
}, response)

View File

@ -34,8 +34,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.hs = yield utils.setup_test_homeserver( self.hs = yield utils.setup_test_homeserver(
handlers=None, handlers=None, federation_client=mock.Mock()
federation_client=mock.Mock(),
) )
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
@ -54,30 +53,21 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz" device_id = "xyz"
keys = { keys = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": { "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"key": "key2", "alg2:k3": {"key": "key3"},
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
} }
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
# we should be able to change the signature without a problem # we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2" keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_change_one_time_keys(self): def test_change_one_time_keys(self):
@ -87,25 +77,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz" device_id = "xyz"
keys = { keys = {
"alg1:k1": "key1", "alg1:k1": "key1",
"alg2:k2": { "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"key": "key2", "alg2:k3": {"key": "key3"},
"signatures": {"k1": "sig1"}
},
"alg2:k3": {
"key": "key3",
},
} }
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
"one_time_key_counts": {"alg1": 1, "alg2": 2}
})
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}, local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
) )
self.fail("No error when changing string key") self.fail("No error when changing string key")
except errors.SynapseError: except errors.SynapseError:
@ -113,7 +96,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}, local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
) )
self.fail("No error when replacing dict key with string") self.fail("No error when replacing dict key with string")
except errors.SynapseError: except errors.SynapseError:
@ -121,9 +104,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, { local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
"one_time_keys": {"alg1:k1": {"key": "key"}}
},
) )
self.fail("No error when replacing string key with dict") self.fail("No error when replacing string key with dict")
except errors.SynapseError: except errors.SynapseError:
@ -131,14 +112,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
try: try:
yield self.handler.upload_keys_for_user( yield self.handler.upload_keys_for_user(
local_user, device_id, { local_user,
device_id,
{
"one_time_keys": { "one_time_keys": {
"alg2:k2": { "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
"key": "key3",
"signatures": {"k1": "sig1"},
} }
}, },
},
) )
self.fail("No error when replacing dict key") self.fail("No error when replacing dict key")
except errors.SynapseError: except errors.SynapseError:
@ -148,31 +128,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
def test_claim_one_time_key(self): def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = { keys = {"alg1:k1": "key1"}
"alg1:k1": "key1",
}
res = yield self.handler.upload_keys_for_user( res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}, local_user, device_id, {"one_time_keys": keys}
) )
self.assertDictEqual(res, { self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
"one_time_key_counts": {"alg1": 1}
})
res2 = yield self.handler.claim_one_time_keys({ res2 = yield self.handler.claim_one_time_keys(
"one_time_keys": { {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
local_user: { )
device_id: "alg1" self.assertEqual(
} res2,
} {
}, timeout=None)
self.assertEqual(res2, {
"failures": {}, "failures": {},
"one_time_keys": { "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
local_user: { },
device_id: { )
"alg1:k1": "key1"
}
}
}
})

View File

@ -39,8 +39,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -54,23 +53,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
), ),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
), ),
call( ],
now=now, any_order=True,
obj=user_id, )
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online(self): def test_online_to_online(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -79,14 +77,11 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -101,23 +96,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
), ),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
), ),
call( ],
now=now, any_order=True,
obj=user_id, )
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online_last_active_noop(self): def test_online_to_online_last_active_noop(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -132,8 +126,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
@ -148,23 +141,22 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 3) self.assertEquals(wheel_timer.insert.call_count, 3)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
), ),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
), ),
call( ],
now=now, any_order=True,
obj=user_id, )
then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
),
], any_order=True)
def test_online_to_online_last_active(self): def test_online_to_online_last_active(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -178,9 +170,7 @@ class PresenceUpdateTestCase(unittest.TestCase):
currently_active=True, currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state=PresenceState.ONLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -193,18 +183,17 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(state.last_federation_update_ts, now) self.assertEquals(state.last_federation_update_ts, now)
self.assertEquals(wheel_timer.insert.call_count, 2) self.assertEquals(wheel_timer.insert.call_count, 2)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call(now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER),
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_active_ts + IDLE_TIMER then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
), ),
call( ],
now=now, any_order=True,
obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
) )
], any_order=True)
def test_remote_ping_timer(self): def test_remote_ping_timer(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -213,13 +202,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now
last_active_ts=now,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.ONLINE)
state=PresenceState.ONLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=False, wheel_timer=wheel_timer, now=now
@ -232,13 +218,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(wheel_timer.insert.call_count, 1) self.assertEquals(wheel_timer.insert.call_count, 1)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
), )
], any_order=True) ],
any_order=True,
)
def test_online_to_offline(self): def test_online_to_offline(self):
wheel_timer = Mock() wheel_timer = Mock()
@ -247,14 +236,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.OFFLINE)
state=PresenceState.OFFLINE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -273,14 +258,10 @@ class PresenceUpdateTestCase(unittest.TestCase):
prev_state = UserPresenceState.default(user_id) prev_state = UserPresenceState.default(user_id)
prev_state = prev_state.copy_and_replace( prev_state = prev_state.copy_and_replace(
state=PresenceState.ONLINE, state=PresenceState.ONLINE, last_active_ts=now, currently_active=True
last_active_ts=now,
currently_active=True,
) )
new_state = prev_state.copy_and_replace( new_state = prev_state.copy_and_replace(state=PresenceState.UNAVAILABLE)
state=PresenceState.UNAVAILABLE,
)
state, persist_and_notify, federation_ping = handle_update( state, persist_and_notify, federation_ping = handle_update(
prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now prev_state, new_state, is_mine=True, wheel_timer=wheel_timer, now=now
@ -293,13 +274,16 @@ class PresenceUpdateTestCase(unittest.TestCase):
self.assertEquals(new_state.status_msg, state.status_msg) self.assertEquals(new_state.status_msg, state.status_msg)
self.assertEquals(wheel_timer.insert.call_count, 1) self.assertEquals(wheel_timer.insert.call_count, 1)
wheel_timer.insert.assert_has_calls([ wheel_timer.insert.assert_has_calls(
[
call( call(
now=now, now=now,
obj=user_id, obj=user_id,
then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
],
any_order=True,
) )
], any_order=True)
class PresenceTimeoutTestCase(unittest.TestCase): class PresenceTimeoutTestCase(unittest.TestCase):
@ -314,9 +298,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now, last_user_sync_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE)
@ -332,9 +314,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state.state, PresenceState.OFFLINE) self.assertEquals(new_state.state, PresenceState.OFFLINE)
@ -369,9 +349,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(new_state, new_state) self.assertEquals(new_state, new_state)
@ -388,9 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now, last_federation_update_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNone(new_state) self.assertIsNone(new_state)
@ -425,9 +401,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
last_federation_update_ts=now, last_federation_update_ts=now,
) )
new_state = handle_timeout( new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now)
state, is_mine=True, syncing_user_ids=set(), now=now
)
self.assertIsNotNone(new_state) self.assertIsNotNone(new_state)
self.assertEquals(state, new_state) self.assertEquals(state, new_state)

View File

@ -54,9 +54,7 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=["send_message"]),
"send_message",
])
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -74,9 +72,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
self.frank.localpart, "Frank"
)
displayname = yield self.handler.get_displayname(self.frank) displayname = yield self.handler.get_displayname(self.frank)
@ -85,22 +81,18 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname( yield self.handler.set_displayname(
self.frank, self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
synapse.types.create_requester(self.frank),
"Frank Jr."
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), (yield self.store.get_profile_displayname(self.frank.localpart)),
"Frank Jr." "Frank Jr.",
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname( d = self.handler.set_displayname(
self.frank, self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
synapse.types.create_requester(self.bob),
"Frank Jr."
) )
yield self.assertFailure(d, AuthError) yield self.assertFailure(d, AuthError)
@ -145,11 +137,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield self.handler.set_avatar_url(
self.frank, synapse.types.create_requester(self.frank), self.frank,
"http://my.server/pic.gif" synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)), (yield self.store.get_profile_avatar_url(self.frank.localpart)),
"http://my.server/pic.gif" "http://my.server/pic.gif",
) )

View File

@ -46,7 +46,8 @@ class RegistrationTestCase(unittest.TestCase):
profile_handler=Mock(), profile_handler=Mock(),
) )
self.macaroon_generator = Mock( self.macaroon_generator = Mock(
generate_access_token=Mock(return_value='secret')) generate_access_token=Mock(return_value='secret')
)
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator) self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs) self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler self.handler = self.hs.get_handlers().registration_handler
@ -62,7 +63,8 @@ class RegistrationTestCase(unittest.TestCase):
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name) requester, local_part, display_name
)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -73,13 +75,15 @@ class RegistrationTestCase(unittest.TestCase):
yield store.register( yield store.register(
user_id=frank.to_string(), user_id=frank.to_string(),
token="jkv;g498752-43gj['eamb!-5", token="jkv;g498752-43gj['eamb!-5",
password_hash=None) password_hash=None,
)
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test") requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
requester, local_part, display_name) requester, local_part, display_name
)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')

View File

@ -38,23 +38,19 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"origin": origin, "origin": origin,
"origin_server_ts": 1000000, "origin_server_ts": 1000000,
"pdus": [], "pdus": [],
"edus": [ "edus": [{"edu_type": edu_type, "content": content}],
{
"edu_type": edu_type,
"content": content,
}
],
} }
def _make_edu_json(origin, edu_type, content): def _make_edu_json(origin, edu_type, content):
return json.dumps( return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode(
_expect_edu("test", edu_type, content, origin=origin) 'utf8'
).encode('utf8') )
class TypingNotificationsTestCase(unittest.TestCase): class TypingNotificationsTestCase(unittest.TestCase):
"""Tests typing notifications to rooms.""" """Tests typing notifications to rooms."""
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
@ -74,7 +70,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
"test", "test",
auth=self.auth, auth=self.auth,
clock=self.clock, clock=self.clock,
datastore=Mock(spec=[ datastore=Mock(
spec=[
# Bits that Federation needs # Bits that Federation needs
"prep_send_transaction", "prep_send_transaction",
"delivered_txn", "delivered_txn",
@ -85,7 +82,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
# Bits that user_directory needs # Bits that user_directory needs
"get_user_directory_stream_pos", "get_user_directory_stream_pos",
"get_current_state_deltas", "get_current_state_deltas",
]), ]
),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=Mock(), handlers=Mock(),
notifier=mock_notifier, notifier=mock_notifier,
@ -100,19 +98,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.event_source = hs.get_event_sources().sources["typing"] self.event_source = hs.get_event_sources().sources["typing"]
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
retry_timings_res = { retry_timings_res = {"destination": "", "retry_last_ts": 0, "retry_interval": 0}
"destination": "", self.datastore.get_destination_retry_timings.return_value = defer.succeed(
"retry_last_ts": 0, retry_timings_res
"retry_interval": 0,
}
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(retry_timings_res)
) )
self.datastore.get_devices_by_remote.return_value = (0, []) self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response self.datastore.get_received_txn_response = get_received_txn_response
self.room_id = "a-room" self.room_id = "a-room"
@ -125,10 +120,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
def get_joined_hosts_for_room(room_id): def get_joined_hosts_for_room(room_id):
return set(member.domain for member in self.room_members) return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_user_in_room(room_id): def get_current_user_in_room(room_id):
return set(str(u) for u in self.room_members) return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room self.state_handler.get_current_user_in_room = get_current_user_in_room
self.datastore.get_user_directory_stream_pos.return_value = ( self.datastore.get_user_directory_stream_pos.return_value = (
@ -136,19 +133,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(1) defer.succeed(1)
) )
self.datastore.get_current_state_deltas.return_value = ( self.datastore.get_current_state_deltas.return_value = None
None
)
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = ( self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
)
self.datastore.delete_device_msgs_for_remote = (
lambda *args, **kargs: None
)
# Some local users to test with # Some local users to test with
self.u_apple = UserID.from_string("@apple:test") self.u_apple = UserID.from_string("@apple:test")
@ -170,24 +161,23 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=20000, timeout=20000,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 1, rooms=[self.room_id]), [call('typing_key', 1, rooms=[self.room_id])]
]) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=0
from_key=0,
) )
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
{"type": "m.typing", {
"type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": [self.u_apple.to_string()]},
"user_ids": [self.u_apple.to_string()], }
}}, ],
]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -206,13 +196,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
"room_id": self.room_id, "room_id": self.room_id,
"user_id": self.u_apple.to_string(), "user_id": self.u_apple.to_string(),
"typing": True, "typing": True,
} },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True, backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK")),
) )
yield self.handler.started_typing( yield self.handler.started_typing(
@ -240,27 +230,29 @@ class TypingNotificationsTestCase(unittest.TestCase):
"room_id": self.room_id, "room_id": self.room_id,
"user_id": self.u_onion.to_string(), "user_id": self.u_onion.to_string(),
"typing": True, "typing": True,
} },
), ),
federation_auth=True, federation_auth=True,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 1, rooms=[self.room_id]), [call('typing_key', 1, rooms=[self.room_id])]
]) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=0
from_key=0
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": [self.u_onion.to_string()]},
"user_ids": [self.u_onion.to_string()], }
}, ],
}]) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stopped_typing(self): def test_stopped_typing(self):
@ -278,17 +270,18 @@ class TypingNotificationsTestCase(unittest.TestCase):
"room_id": self.room_id, "room_id": self.room_id,
"user_id": self.u_apple.to_string(), "user_id": self.u_apple.to_string(),
"typing": False, "typing": False,
} },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True, long_retries=True,
backoff_on_404=True, backoff_on_404=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK")),
) )
# Gut-wrenching # Gut-wrenching
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string()) member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000 self.handler._member_typing_until[member] = 1002000
self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()]) self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
@ -296,29 +289,29 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.assertEquals(self.event_source.get_current_key(), 0) self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.stopped_typing( yield self.handler.stopped_typing(
target_user=self.u_apple, target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id
auth_user=self.u_apple,
room_id=self.room_id,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 1, rooms=[self.room_id]), [call('typing_key', 1, rooms=[self.room_id])]
]) )
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=0
from_key=0,
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": []},
"user_ids": [], }
}, ],
}]) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_typing_timeout(self): def test_typing_timeout(self):
@ -333,42 +326,46 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 1, rooms=[self.room_id]), [call('typing_key', 1, rooms=[self.room_id])]
]) )
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=0
from_key=0,
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": [self.u_apple.to_string()]},
"user_ids": [self.u_apple.to_string()], }
}, ],
}]) )
self.clock.advance_time(16) self.clock.advance_time(16)
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 2, rooms=[self.room_id]), [call('typing_key', 2, rooms=[self.room_id])]
]) )
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=1
from_key=1,
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": []},
"user_ids": [], }
}, ],
}]) )
# SYN-230 - see if we can still set after timeout # SYN-230 - see if we can still set after timeout
@ -379,20 +376,22 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls(
call('typing_key', 3, rooms=[self.room_id]), [call('typing_key', 3, rooms=[self.room_id])]
]) )
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
room_ids=[self.room_id], room_ids=[self.room_id], from_key=0
from_key=0,
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": [self.u_apple.to_string()]},
"user_ids": [self.u_apple.to_string()], }
}, ],
}]) )

View File

@ -45,9 +45,7 @@ class ServerNameTestCase(unittest.TestCase):
try: try:
parse_and_validate_server_name(i) parse_and_validate_server_name(i)
self.fail( self.fail(
"Expected parse_and_validate_server_name('%s') to throw" % ( "Expected parse_and_validate_server_name('%s') to throw" % (i,)
i,
),
) )
except ValueError: except ValueError:
pass pass

View File

@ -31,6 +31,7 @@ from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler): class TestReplicationClientHandler(ReplicationClientHandler):
"""Overrides on_rdata so that we can wait for it to happen""" """Overrides on_rdata so that we can wait for it to happen"""
def __init__(self, store): def __init__(self, store):
super(TestReplicationClientHandler, self).__init__(store) super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = [] self._rdata_awaiters = []
@ -56,9 +57,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
"blue", "blue",
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=["send_message"]),
"send_message",
]),
) )
self.hs.get_ratelimiter().send_message.return_value = (True, 0) self.hs.get_ratelimiter().send_message.return_value = (True, 0)

View File

@ -29,20 +29,14 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_account_data(self): def test_user_account_data(self):
yield self.master_store.add_account_data_for_user( yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
USER_ID, TYPE, {"a": 1}
)
yield self.replicate() yield self.replicate()
yield self.check( yield self.check(
"get_global_account_data_by_type_for_user", "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
[TYPE, USER_ID], {"a": 1}
) )
yield self.master_store.add_account_data_for_user( yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
USER_ID, TYPE, {"a": 2}
)
yield self.replicate() yield self.replicate()
yield self.check( yield self.check(
"get_global_account_data_by_type_for_user", "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
[TYPE, USER_ID], {"a": 2}
) )

View File

@ -38,6 +38,7 @@ def patch__eq__(cls):
def unpatch(): def unpatch():
if eq is not None: if eq is not None:
cls.__eq__ = eq cls.__eq__ = eq
return unpatch return unpatch
@ -48,10 +49,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def setUp(self): def setUp(self):
# Patch up the equality operator for events so that we can check # Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals # whether lists of events match using assertEquals
self.unpatches = [ self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(FrozenEvent)]
patch__eq__(_EventInternalMetadata),
patch__eq__(FrozenEvent),
]
return super(SlavedEventStoreTestCase, self).setUp() return super(SlavedEventStoreTestCase, self).setUp()
def tearDown(self): def tearDown(self):
@ -61,33 +59,27 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate() yield self.replicate()
yield self.check( yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
"get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]
)
join = yield self.persist( join = yield self.persist(
type="m.room.member", key=USER_ID, membership="join", type="m.room.member",
key=USER_ID,
membership="join",
prev_events=[(create.event_id, {})], prev_events=[(create.event_id, {})],
) )
yield self.replicate() yield self.replicate()
yield self.check( yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redactions(self): def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join") yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist( msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate() yield self.replicate()
yield self.check("get_event", [msg.event_id], msg) yield self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist( redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
type="m.room.redaction", redacts=msg.event_id
)
yield self.replicate() yield self.replicate()
msg_dict = msg.get_dict() msg_dict = msg.get_dict()
@ -102,9 +94,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join") yield self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist( msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
type="m.room.message", msgtype="m.text", body="Hello"
)
yield self.replicate() yield self.replicate()
yield self.check("get_event", [msg.event_id], msg) yield self.check("get_event", [msg.event_id], msg)
@ -127,10 +117,19 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.member", key=USER_ID_2, membership="invite" type="m.room.member", key=USER_ID_2, membership="invite"
) )
yield self.replicate() yield self.replicate()
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser( yield self.check(
ROOM_ID, USER_ID, "invite", event.event_id, "get_invited_rooms_for_user",
event.internal_metadata.stream_ordering [USER_ID_2],
)]) [
RoomsForUser(
ROOM_ID,
USER_ID,
"invite",
event.event_id,
event.internal_metadata.stream_ordering,
)
],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_actions_for_user(self): def test_push_actions_for_user(self):
@ -146,40 +145,55 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
yield self.check( yield self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0} {"highlight_count": 0, "notify_count": 0},
) )
yield self.persist( yield self.persist(
type="m.room.message", msgtype="m.text", body="world", type="m.room.message",
msgtype="m.text",
body="world",
push_actions=[(USER_ID_2, ["notify"])], push_actions=[(USER_ID_2, ["notify"])],
) )
yield self.replicate() yield self.replicate()
yield self.check( yield self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1} {"highlight_count": 0, "notify_count": 1},
) )
yield self.persist( yield self.persist(
type="m.room.message", msgtype="m.text", body="world", type="m.room.message",
push_actions=[(USER_ID_2, [ msgtype="m.text",
"notify", {"set_tweak": "highlight", "value": True} body="world",
])], push_actions=[
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
],
) )
yield self.replicate() yield self.replicate()
yield self.check( yield self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2} {"highlight_count": 1, "notify_count": 2},
) )
event_id = 0 event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks
def persist( def persist(
self, sender=USER_ID, room_id=ROOM_ID, type={}, key=None, internal={}, self,
state=None, reset_state=False, backfill=False, sender=USER_ID,
depth=None, prev_events=[], auth_events=[], prev_state=[], redacts=None, room_id=ROOM_ID,
type={},
key=None,
internal={},
state=None,
reset_state=False,
backfill=False,
depth=None,
prev_events=[],
auth_events=[],
prev_state=[],
redacts=None,
push_actions=[], push_actions=[],
**content **content
): ):
@ -219,34 +233,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.event_id += 1 self.event_id += 1
if state is not None: if state is not None:
state_ids = { state_ids = {key: e.event_id for key, e in state.items()}
key: e.event_id for key, e in state.items()
}
context = EventContext.with_state( context = EventContext.with_state(
state_group=None, state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
current_state_ids=state_ids,
prev_state_ids=state_ids
) )
else: else:
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event) context = yield state_handler.compute_event_context(event)
yield self.master_store.add_push_actions_to_staging( yield self.master_store.add_push_actions_to_staging(
event.event_id, { event.event_id, {user_id: actions for user_id, actions in push_actions}
user_id: actions
for user_id, actions in push_actions
},
) )
ordering = None ordering = None
if backfill: if backfill:
yield self.master_store.persist_events( yield self.master_store.persist_events([(event, context)], backfilled=True)
[(event, context)], backfilled=True
)
else: else:
ordering, _ = yield self.master_store.persist_event( ordering, _ = yield self.master_store.persist_event(event, context)
event, context,
)
if ordering: if ordering:
event.internal_metadata.stream_ordering = ordering event.internal_metadata.stream_ordering = ordering

View File

@ -34,6 +34,6 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
) )
yield self.replicate() yield self.replicate()
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], { yield self.check(
ROOM_ID: EVENT_ID "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
}) )

View File

@ -11,7 +11,6 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase): class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
self.hs = Mock() self.hs = Mock()
@ -24,9 +23,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_executes_given_function(self): def test_executes_given_function(self):
cb = Mock( cb = Mock(return_value=defer.succeed(self.mock_http_response))
return_value=defer.succeed(self.mock_http_response)
)
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg" self.mock_key, cb, "some_arg", keyword="arg"
) )
@ -35,9 +32,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_deduplicates_based_on_key(self): def test_deduplicates_based_on_key(self):
cb = Mock( cb = Mock(return_value=defer.succeed(self.mock_http_response))
return_value=defer.succeed(self.mock_http_response)
)
for i in range(3): # invoke multiple times for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute( res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@ -120,29 +115,18 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cleans_up(self): def test_cleans_up(self):
cb = Mock( cb = Mock(return_value=defer.succeed(self.mock_http_response))
return_value=defer.succeed(self.mock_http_response) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# should NOT have cleaned up yet # should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
yield self.cache.fetch_or_execute( yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
self.mock_key, cb, "an arg"
)
# still using cache # still using cache
cb.assert_called_once_with("an arg") cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS) self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
yield self.cache.fetch_or_execute( yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
self.mock_key, cb, "an arg"
)
# no longer using cache # no longer using cache
self.assertEqual(cb.call_count, 2) self.assertEqual(cb.call_count, 2)
self.assertEqual( self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])
cb.call_args_list,
[call("an arg",), call("an arg",)]
)

View File

@ -215,6 +215,7 @@ class UserRegisterTestCase(unittest.TestCase):
mac. Admin is optional. Additional checks are done for length and mac. Admin is optional. Additional checks are done for length and
type. type.
""" """
def nonce(): def nonce():
request, channel = make_request("GET", self.url) request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock) render(request, self.resource, self.clock)
@ -289,7 +290,9 @@ class UserRegisterTestCase(unittest.TestCase):
self.assertEqual('Invalid password', channel.json_body["error"]) self.assertEqual('Invalid password', channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}) body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) render(request, self.resource, self.clock)

View File

@ -43,9 +43,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=["send_message"]),
"send_message",
]),
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -98,18 +96,12 @@ class EventStreamPermissionsTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_room_permissions(self): def test_stream_room_permissions(self):
room_id = yield self.create_room_as( room_id = yield self.create_room_as(self.other_user, tok=self.other_token)
self.other_user,
tok=self.other_token
)
yield self.send(room_id, tok=self.other_token) yield self.send(room_id, tok=self.other_token)
# invited to room (expect no content for room) # invited to room (expect no content for room)
yield self.invite( yield self.invite(
room_id, room_id, src=self.other_user, targ=self.user_id, tok=self.other_token
src=self.other_user,
targ=self.user_id,
tok=self.other_token
) )
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
@ -120,13 +112,16 @@ class EventStreamPermissionsTestCase(RestTestCase):
# We may get a presence event for ourselves down # We may get a presence event for ourselves down
self.assertEquals( self.assertEquals(
0, 0,
len([ len(
c for c in response["chunk"] [
c
for c in response["chunk"]
if not ( if not (
c.get("type") == "m.presence" c.get("type") == "m.presence"
and c["content"].get("user_id") == self.user_id and c["content"].get("user_id") == self.user_id
) )
]) ]
),
) )
# joined room (expect all content for room) # joined room (expect all content for room)

View File

@ -36,12 +36,14 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.mock_handler = Mock(spec=[ self.mock_handler = Mock(
spec=[
"get_displayname", "get_displayname",
"set_displayname", "set_displayname",
"get_avatar_url", "get_avatar_url",
"set_avatar_url", "set_avatar_url",
]) ]
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
"test", "test",
@ -49,7 +51,7 @@ class ProfileTestCase(unittest.TestCase):
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
federation=Mock(), federation=Mock(),
federation_client=Mock(), federation_client=Mock(),
profile_handler=self.mock_handler profile_handler=self.mock_handler,
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
@ -78,9 +80,7 @@ class ProfileTestCase(unittest.TestCase):
mocked_set.return_value = defer.succeed(()) mocked_set.return_value = defer.succeed(())
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "PUT", "/profile/%s/displayname" % (myid), b'{"displayname": "Frank Jr."}'
"/profile/%s/displayname" % (myid),
b'{"displayname": "Frank Jr."}'
) )
self.assertEquals(200, code) self.assertEquals(200, code)
@ -94,14 +94,12 @@ class ProfileTestCase(unittest.TestCase):
mocked_set.side_effect = AuthError(400, "message") mocked_set.side_effect = AuthError(400, "message")
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@4567:test"), "PUT",
b'{"displayname": "Frank Jr."}' "/profile/%s/displayname" % ("@4567:test"),
b'{"displayname": "Frank Jr."}',
) )
self.assertTrue( self.assertTrue(400 <= code < 499, msg="code %d is in the 4xx range" % (code))
400 <= code < 499,
msg="code %d is in the 4xx range" % (code)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_other_name(self): def test_get_other_name(self):
@ -121,14 +119,12 @@ class ProfileTestCase(unittest.TestCase):
mocked_set.side_effect = SynapseError(400, "message") mocked_set.side_effect = SynapseError(400, "message")
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "/profile/%s/displayname" % ("@opaque:elsewhere"), "PUT",
b'{"displayname":"bob"}' "/profile/%s/displayname" % ("@opaque:elsewhere"),
b'{"displayname":"bob"}',
) )
self.assertTrue( self.assertTrue(400 <= code <= 499, msg="code %d is in the 4xx range" % (code))
400 <= code <= 499,
msg="code %d is in the 4xx range" % (code)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_avatar(self): def test_get_my_avatar(self):
@ -151,7 +147,7 @@ class ProfileTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", "PUT",
"/profile/%s/avatar_url" % (myid), "/profile/%s/avatar_url" % (myid),
b'{"avatar_url": "http://my.server/pic.gif"}' b'{"avatar_url": "http://my.server/pic.gif"}',
) )
self.assertEquals(200, code) self.assertEquals(200, code)

View File

@ -32,6 +32,7 @@ class CreateUserServletTestCase(unittest.TestCase):
""" """
Tests for CreateUserRestServlet. Tests for CreateUserRestServlet.
""" """
if PY3: if PY3:
skip = "Not ported to Python 3." skip = "Not ported to Python 3."

View File

@ -31,6 +31,7 @@ PATH_PREFIX = "/_matrix/client/api/v1"
class RoomTypingTestCase(RestTestCase): class RoomTypingTestCase(RestTestCase):
""" Tests /rooms/$room_id/typing/$user_id REST API. """ """ Tests /rooms/$room_id/typing/$user_id REST API. """
user_id = "@sid:red" user_id = "@sid:red"
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -47,9 +48,7 @@ class RoomTypingTestCase(RestTestCase):
clock=self.clock, clock=self.clock,
http_client=None, http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=["send_message"]),
"send_message",
]),
) )
self.hs = hs self.hs = hs
@ -71,6 +70,7 @@ class RoomTypingTestCase(RestTestCase):
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip hs.get_datastore().insert_client_ip = _insert_client_ip
def get_room_members(room_id): def get_room_members(room_id):
@ -94,6 +94,7 @@ class RoomTypingTestCase(RestTestCase):
else: else:
if remotedomains is not None: if remotedomains is not None:
remotedomains.add(member.domain) remotedomains.add(member.domain)
hs.get_room_member_handler().fetch_room_distributions_into = ( hs.get_room_member_handler().fetch_room_distributions_into = (
fetch_room_distributions_into fetch_room_distributions_into
) )
@ -107,37 +108,42 @@ class RoomTypingTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_typing(self): def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
"PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), "PUT",
'{"typing": true, "timeout": 30000}' "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": true, "timeout": 30000}',
) )
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events( events = yield self.event_source.get_new_events(
from_key=0, from_key=0, room_ids=[self.room_id]
room_ids=[self.room_id],
) )
self.assertEquals(events[0], [{ self.assertEquals(
events[0],
[
{
"type": "m.typing", "type": "m.typing",
"room_id": self.room_id, "room_id": self.room_id,
"content": { "content": {"user_ids": [self.user_id]},
"user_ids": [self.user_id],
} }
}]) ],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_not_typing(self): def test_set_not_typing(self):
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
"PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), "PUT",
'{"typing": false}' "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": false}',
) )
self.assertEquals(200, code) self.assertEquals(200, code)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_typing_timeout(self): def test_typing_timeout(self):
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
"PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), "PUT",
'{"typing": true, "timeout": 30000}' "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": true, "timeout": 30000}',
) )
self.assertEquals(200, code) self.assertEquals(200, code)
@ -148,8 +154,9 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
(code, _) = yield self.mock_resource.trigger( (code, _) = yield self.mock_resource.trigger(
"PUT", "/rooms/%s/typing/%s" % (self.room_id, self.user_id), "PUT",
'{"typing": true, "timeout": 30000}' "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": true, "timeout": 30000}',
) )
self.assertEquals(200, code) self.assertEquals(200, code)

View File

@ -55,25 +55,39 @@ class RestTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
yield self.change_membership(room=room, src=src, targ=targ, tok=tok, yield self.change_membership(
room=room,
src=src,
targ=targ,
tok=tok,
membership=Membership.INVITE, membership=Membership.INVITE,
expect_code=expect_code) expect_code=expect_code,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def join(self, room=None, user=None, expect_code=200, tok=None): def join(self, room=None, user=None, expect_code=200, tok=None):
yield self.change_membership(room=room, src=user, targ=user, tok=tok, yield self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.JOIN, membership=Membership.JOIN,
expect_code=expect_code) expect_code=expect_code,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def leave(self, room=None, user=None, expect_code=200, tok=None): def leave(self, room=None, user=None, expect_code=200, tok=None):
yield self.change_membership(room=room, src=user, targ=user, tok=tok, yield self.change_membership(
room=room,
src=user,
targ=user,
tok=tok,
membership=Membership.LEAVE, membership=Membership.LEAVE,
expect_code=expect_code) expect_code=expect_code,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, room, src, targ, membership, tok=None, def change_membership(self, room, src, targ, membership, tok=None, expect_code=200):
expect_code=200):
temp_id = self.auth_user_id temp_id = self.auth_user_id
self.auth_user_id = src self.auth_user_id = src
@ -81,16 +95,15 @@ class RestTestCase(unittest.TestCase):
if tok: if tok:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
data = { data = {"membership": membership}
"membership": membership
}
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"PUT", path, json.dumps(data) "PUT", path, json.dumps(data)
) )
self.assertEquals( self.assertEquals(
expect_code, code, expect_code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response) code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response),
) )
self.auth_user_id = temp_id self.auth_user_id = temp_id
@ -100,17 +113,15 @@ class RestTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"POST", "POST",
"/register", "/register",
json.dumps({ json.dumps(
"user": user_id, {"user": user_id, "password": "test", "type": "m.login.password"}
"password": "test", ),
"type": "m.login.password" )
}))
self.assertEquals(200, code, msg=response) self.assertEquals(200, code, msg=response)
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
expect_code=200):
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
if body is None: if body is None:
@ -132,8 +143,9 @@ class RestTestCase(unittest.TestCase):
actual (dict): The test result. Extra keys will not be checked. actual (dict): The test result. Extra keys will not be checked.
""" """
for key in required: for key in required:
self.assertEquals(required[key], actual[key], self.assertEquals(
msg="%s mismatch. %s" % (key, actual)) required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)
@attr.s @attr.s
@ -156,7 +168,9 @@ class RestHelper(object):
if tok: if tok:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request("POST", path, json.dumps(content).encode('utf8')) request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
)
request.render(self.resource) request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel) wait_until_result(self.hs.get_reactor(), channel)
@ -204,9 +218,7 @@ class RestHelper(object):
data = {"membership": membership} data = {"membership": membership}
request, channel = make_request( request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
"PUT", path, json.dumps(data).encode('utf8')
)
request.render(self.resource) request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel) wait_until_result(self.hs.get_reactor(), channel)

View File

@ -101,9 +101,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals( self.assertEquals(channel.json_body["error"], "Invalid password")
channel.json_body["error"], "Invalid password"
)
def test_POST_bad_username(self): def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"}) request_data = json.dumps({"username": 777, "password": "monkey"})
@ -112,9 +110,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals( self.assertEquals(channel.json_body["error"], "Invalid username")
channel.json_body["error"], "Invalid username"
)
def test_POST_user_valid(self): def test_POST_user_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
@ -157,10 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals( self.assertEquals(channel.json_body["error"], "Registration has been disabled")
channel.json_body["error"],
"Registration has been disabled",
)
def test_POST_guest_registration(self): def test_POST_guest_registration(self):
user_id = "a@b" user_id = "a@b"
@ -188,6 +181,4 @@ class RegisterRestServletTestCase(unittest.TestCase):
wait_until_result(self.clock, channel) wait_until_result(self.clock, channel)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals( self.assertEquals(channel.json_body["error"], "Guest access is disabled")
channel.json_body["error"], "Guest access is disabled"
)

View File

@ -41,13 +41,11 @@ class MediaStorageTests(unittest.TestCase):
hs.get_reactor = Mock(return_value=reactor) hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend( storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
hs, self.secondary_base_path
)]
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage( self.media_storage = MediaStorage(
hs, self.primary_base_path, self.filepaths, storage_providers, hs, self.primary_base_path, self.filepaths, storage_providers
) )
def tearDown(self): def tearDown(self):

View File

@ -136,6 +136,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
A MemoryReactorClock that supports callFromThread. A MemoryReactorClock that supports callFromThread.
""" """
def callFromThread(self, callback, *args, **kwargs): def callFromThread(self, callback, *args, **kwargs):
""" """
Make the callback fire in the next reactor iteration. Make the callback fire in the next reactor iteration.
@ -184,6 +185,7 @@ def setup_test_homeserver(*args, **kwargs):
""" """
Threadless thread pool. Threadless thread pool.
""" """
def start(self): def start(self):
pass pass

View File

@ -25,7 +25,6 @@ from tests import unittest
class CacheTestCase(unittest.TestCase): class CacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.cache = Cache("test") self.cache = Cache("test")
@ -97,7 +96,6 @@ class CacheTestCase(unittest.TestCase):
class CacheDecoratorTestCase(unittest.TestCase): class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self):
class A(object): class A(object):
@ -180,8 +178,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
yield a.func(k) yield a.func(k)
self.assertTrue( self.assertTrue(
callcount[0] >= 14, callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
msg="Expected callcount >= 14, got %d" % (callcount[0])
) )
def test_prefill(self): def test_prefill(self):

View File

@ -34,7 +34,6 @@ from tests.utils import setup_test_homeserver
class ApplicationServiceStoreTestCase(unittest.TestCase): class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = [] self.as_yaml_files = []
@ -44,20 +43,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config, federation_sender=Mock(), federation_client=Mock()
federation_sender=Mock(),
federation_client=Mock(),
) )
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
self.as_id = "as1" self.as_id = "as1"
self._add_appservice( self._add_appservice(
self.as_token, self.as_token, self.as_id, self.as_url, "some_hs_token", "bob"
self.as_id,
self.as_url,
"some_hs_token",
"bob"
) )
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
@ -73,8 +66,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
pass pass
def _add_appservice(self, as_token, id, url, hs_token, sender): def _add_appservice(self, as_token, id, url, hs_token, sender):
as_yaml = dict(url=url, as_token=as_token, hs_token=hs_token, as_yaml = dict(
id=id, sender_localpart=sender, namespaces={}) url=url,
as_token=as_token,
hs_token=hs_token,
id=id,
sender_localpart=sender,
namespaces={},
)
# use the token as the filename # use the token as the filename
with open(as_token, 'w') as outfile: with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
@ -85,24 +84,13 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self.assertEquals(service, None) self.assertEquals(service, None)
def test_retrieval_of_service(self): def test_retrieval_of_service(self):
stored_service = self.store.get_app_service_by_token( stored_service = self.store.get_app_service_by_token(self.as_token)
self.as_token
)
self.assertEquals(stored_service.token, self.as_token) self.assertEquals(stored_service.token, self.as_token)
self.assertEquals(stored_service.id, self.as_id) self.assertEquals(stored_service.id, self.as_id)
self.assertEquals(stored_service.url, self.as_url) self.assertEquals(stored_service.url, self.as_url)
self.assertEquals( self.assertEquals(stored_service.namespaces[ApplicationService.NS_ALIASES], [])
stored_service.namespaces[ApplicationService.NS_ALIASES], self.assertEquals(stored_service.namespaces[ApplicationService.NS_ROOMS], [])
[] self.assertEquals(stored_service.namespaces[ApplicationService.NS_USERS], [])
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_ROOMS],
[]
)
self.assertEquals(
stored_service.namespaces[ApplicationService.NS_USERS],
[]
)
def test_retrieval_of_all_services(self): def test_retrieval_of_all_services(self):
services = self.store.get_app_services() services = self.store.get_app_services()
@ -110,7 +98,6 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = [] self.as_yaml_files = []
@ -121,33 +108,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config, federation_sender=Mock(), federation_client=Mock()
federation_sender=Mock(),
federation_client=Mock(),
) )
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.as_list = [ self.as_list = [
{ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
"token": "token1", {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"},
"url": "https://matrix-as.org", {"token": "beta_tok", "url": "https://beta.com", "id": "id_beta"},
"id": "id_1" {"token": "gamma_tok", "url": "https://gamma.com", "id": "id_gamma"},
},
{
"token": "alpha_tok",
"url": "https://alpha.com",
"id": "id_alpha"
},
{
"token": "beta_tok",
"url": "https://beta.com",
"id": "id_beta"
},
{
"token": "gamma_tok",
"url": "https://gamma.com",
"id": "id_gamma"
},
] ]
for s in self.as_list: for s in self.as_list:
yield self._add_service(s["url"], s["token"], s["id"]) yield self._add_service(s["url"], s["token"], s["id"])
@ -157,8 +126,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.store = TestTransactionStore(None, hs) self.store = TestTransactionStore(None, hs)
def _add_service(self, url, as_token, id): def _add_service(self, url, as_token, id):
as_yaml = dict(url=url, as_token=as_token, hs_token="something", as_yaml = dict(
id=id, sender_localpart="a_sender", namespaces={}) url=url,
as_token=as_token,
hs_token="something",
id=id,
sender_localpart="a_sender",
namespaces={},
)
# use the token as the filename # use the token as the filename
with open(as_token, 'w') as outfile: with open(as_token, 'w') as outfile:
outfile.write(yaml.dump(as_yaml)) outfile.write(yaml.dump(as_yaml))
@ -168,21 +143,21 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
return self.db_pool.runQuery( return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, state, last_txn) " "INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)", "VALUES(?,?,?)",
(id, state, txn) (id, state, txn),
) )
def _insert_txn(self, as_id, txn_id, events): def _insert_txn(self, as_id, txn_id, events):
return self.db_pool.runQuery( return self.db_pool.runQuery(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)", "VALUES(?,?,?)",
(as_id, txn_id, json.dumps([e.event_id for e in events])) (as_id, txn_id, json.dumps([e.event_id for e in events])),
) )
def _set_last_txn(self, as_id, txn_id): def _set_last_txn(self, as_id, txn_id):
return self.db_pool.runQuery( return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, last_txn, state) " "INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)", "VALUES(?,?,?)",
(as_id, txn_id, ApplicationServiceState.UP) (as_id, txn_id, ApplicationServiceState.UP),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -193,24 +168,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservice_state_up(self): def test_get_appservice_state_up(self):
yield self._set_state( yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
self.as_list[0]["id"], ApplicationServiceState.UP
)
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
state = yield self.store.get_appservice_state(service) state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.UP, state) self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservice_state_down(self): def test_get_appservice_state_down(self):
yield self._set_state( yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
self.as_list[0]["id"], ApplicationServiceState.UP yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(
self.as_list[1]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[2]["id"], ApplicationServiceState.DOWN
)
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
state = yield self.store.get_appservice_state(service) state = yield self.store.get_appservice_state(service)
self.assertEquals(ApplicationServiceState.DOWN, state) self.assertEquals(ApplicationServiceState.DOWN, state)
@ -225,34 +192,22 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_appservices_state_down(self): def test_set_appservices_state_down(self):
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state( yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
service,
ApplicationServiceState.DOWN
)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?", "SELECT as_id FROM application_services_state WHERE state=?",
(ApplicationServiceState.DOWN,) (ApplicationServiceState.DOWN,),
) )
self.assertEquals(service.id, rows[0][0]) self.assertEquals(service.id, rows[0][0])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_appservices_state_multiple_up(self): def test_set_appservices_state_multiple_up(self):
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state( yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
service, yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
ApplicationServiceState.UP yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
)
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
yield self.store.set_appservice_state(
service,
ApplicationServiceState.UP
)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?", "SELECT as_id FROM application_services_state WHERE state=?",
(ApplicationServiceState.UP,) (ApplicationServiceState.UP,),
) )
self.assertEquals(service.id, rows[0][0]) self.assertEquals(service.id, rows[0][0])
@ -319,14 +274,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT last_txn FROM application_services_state WHERE as_id=?", "SELECT last_txn FROM application_services_state WHERE as_id=?",
(service.id,) (service.id,),
) )
self.assertEquals(1, len(res)) self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0]) self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
(txn_id,)
) )
self.assertEquals(0, len(res)) self.assertEquals(0, len(res))
@ -340,17 +294,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT last_txn, state FROM application_services_state WHERE " "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
"as_id=?", (service.id,),
(service.id,)
) )
self.assertEquals(1, len(res)) self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0]) self.assertEquals(txn_id, res[0][0])
self.assertEquals(ApplicationServiceState.UP, res[0][1]) self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
(txn_id,)
) )
self.assertEquals(0, len(res)) self.assertEquals(0, len(res))
@ -382,12 +334,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservices_by_state_single(self): def test_get_appservices_by_state_single(self):
yield self._set_state( yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
self.as_list[0]["id"], ApplicationServiceState.DOWN yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
)
yield self._set_state(
self.as_list[1]["id"], ApplicationServiceState.UP
)
services = yield self.store.get_appservices_by_state( services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN ApplicationServiceState.DOWN
@ -397,18 +345,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservices_by_state_multiple(self): def test_get_appservices_by_state_multiple(self):
yield self._set_state( yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
self.as_list[0]["id"], ApplicationServiceState.DOWN yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state( yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
self.as_list[1]["id"], ApplicationServiceState.UP
)
yield self._set_state(
self.as_list[2]["id"], ApplicationServiceState.DOWN
)
yield self._set_state(
self.as_list[3]["id"], ApplicationServiceState.UP
)
services = yield self.store.get_appservices_by_state( services = yield self.store.get_appservices_by_state(
ApplicationServiceState.DOWN ApplicationServiceState.DOWN
@ -416,20 +356,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(2, len(services)) self.assertEquals(2, len(services))
self.assertEquals( self.assertEquals(
set([self.as_list[2]["id"], self.as_list[0]["id"]]), set([self.as_list[2]["id"], self.as_list[0]["id"]]),
set([services[0].id, services[1].id]) set([services[0].id, services[1].id]),
) )
# required for ApplicationServiceTransactionStoreTestCase tests # required for ApplicationServiceTransactionStoreTestCase tests
class TestTransactionStore(ApplicationServiceTransactionStore, class TestTransactionStore(ApplicationServiceTransactionStore, ApplicationServiceStore):
ApplicationServiceStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(TestTransactionStore, self).__init__(db_conn, hs) super(TestTransactionStore, self).__init__(db_conn, hs)
class ApplicationServiceStoreConfigTestCase(unittest.TestCase): class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
def _write_config(self, suffix, **kwargs): def _write_config(self, suffix, **kwargs):
vals = { vals = {
"id": "id" + suffix, "id": "id" + suffix,
@ -452,8 +389,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(suffix="2") f2 = self._write_config(suffix="2")
config = Mock( config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
password_providers=[]
) )
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
@ -470,8 +406,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(id="id", suffix="2") f2 = self._write_config(id="id", suffix="2")
config = Mock( config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
password_providers=[]
) )
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
@ -494,8 +429,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f2 = self._write_config(as_token="as_token", suffix="2") f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock( config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
password_providers=[]
) )
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,

View File

@ -7,7 +7,6 @@ from tests.utils import setup_test_homeserver
class BackgroundUpdateTestCase(unittest.TestCase): class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() # type: synapse.server.HomeServer hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
@ -51,9 +50,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1}) yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock() self.update_handler.reset_mock()
result = yield self.store.do_next_background_update( result = yield self.store.do_next_background_update(duration_ms * desired_count)
duration_ms * desired_count
)
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.update_handler.assert_called_once_with( self.update_handler.assert_called_once_with(
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
@ -67,18 +64,12 @@ class BackgroundUpdateTestCase(unittest.TestCase):
self.update_handler.side_effect = update self.update_handler.side_effect = update
self.update_handler.reset_mock() self.update_handler.reset_mock()
result = yield self.store.do_next_background_update( result = yield self.store.do_next_background_update(duration_ms * desired_count)
duration_ms * desired_count
)
self.assertIsNotNone(result) self.assertIsNotNone(result)
self.update_handler.assert_called_once_with( self.update_handler.assert_called_once_with({"my_key": 2}, desired_count)
{"my_key": 2}, desired_count
)
# third step: we don't expect to be called any more # third step: we don't expect to be called any more
self.update_handler.reset_mock() self.update_handler.reset_mock()
result = yield self.store.do_next_background_update( result = yield self.store.do_next_background_update(duration_ms * desired_count)
duration_ms * desired_count
)
self.assertIsNone(result) self.assertIsNone(result)
self.assertFalse(self.update_handler.called) self.assertFalse(self.update_handler.called)

View File

@ -40,10 +40,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def runInteraction(func, *args, **kwargs): def runInteraction(func, *args, **kwargs):
return defer.succeed(func(self.mock_txn, *args, **kwargs)) return defer.succeed(func(self.mock_txn, *args, **kwargs))
self.db_pool.runInteraction = runInteraction self.db_pool.runInteraction = runInteraction
def runWithConnection(func, *args, **kwargs): def runWithConnection(func, *args, **kwargs):
return defer.succeed(func(self.mock_conn, *args, **kwargs)) return defer.succeed(func(self.mock_conn, *args, **kwargs))
self.db_pool.runWithConnection = runWithConnection self.db_pool.runWithConnection = runWithConnection
config = Mock() config = Mock()
@ -63,8 +65,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore._simple_insert( yield self.datastore._simple_insert(
table="tablename", table="tablename", values={"columname": "Value"}
values={"columname": "Value"}
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
@ -78,12 +79,11 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_insert( yield self.datastore._simple_insert(
table="tablename", table="tablename",
# Use OrderedDict() so we can assert on the SQL generated # Use OrderedDict() so we can assert on the SQL generated
values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]) values=OrderedDict([("colA", 1), ("colB", 2), ("colC", 3)]),
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", "INSERT INTO tablename (colA, colB, colC) VALUES(?, ?, ?)", (1, 2, 3)
(1, 2, 3,)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -92,9 +92,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)])) self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore._simple_select_one_onecol( value = yield self.datastore._simple_select_one_onecol(
table="tablename", table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
keyvalues={"keycol": "TheKey"},
retcol="retcol"
) )
self.assertEquals("Value", value) self.assertEquals("Value", value)
@ -110,13 +108,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
ret = yield self.datastore._simple_select_one( ret = yield self.datastore._simple_select_one(
table="tablename", table="tablename",
keyvalues={"keycol": "TheKey"}, keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"] retcols=["colA", "colB", "colC"],
) )
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret) self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT colA, colB, colC FROM tablename WHERE keycol = ?", "SELECT colA, colB, colC FROM tablename WHERE keycol = ?", ["TheKey"]
["TheKey"]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -128,7 +125,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
table="tablename", table="tablename",
keyvalues={"keycol": "Not here"}, keyvalues={"keycol": "Not here"},
retcols=["colA"], retcols=["colA"],
allow_none=True allow_none=True,
) )
self.assertFalse(ret) self.assertFalse(ret)
@ -137,20 +134,15 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_select_list(self): def test_select_list(self):
self.mock_txn.rowcount = 3 self.mock_txn.rowcount = 3
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = ( self.mock_txn.description = (("colA", None, None, None, None, None, None),)
("colA", None, None, None, None, None, None),
)
ret = yield self.datastore._simple_select_list( ret = yield self.datastore._simple_select_list(
table="tablename", table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
keyvalues={"keycol": "A set"},
retcols=["colA"],
) )
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"SELECT colA FROM tablename WHERE keycol = ?", "SELECT colA FROM tablename WHERE keycol = ?", ["A set"]
["A set"]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -160,12 +152,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_update_one( yield self.datastore._simple_update_one(
table="tablename", table="tablename",
keyvalues={"keycol": "TheKey"}, keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"} updatevalues={"columnname": "New Value"},
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"UPDATE tablename SET columnname = ? WHERE keycol = ?", "UPDATE tablename SET columnname = ? WHERE keycol = ?",
["New Value", "TheKey"] ["New Value", "TheKey"],
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -175,13 +167,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
yield self.datastore._simple_update_one( yield self.datastore._simple_update_one(
table="tablename", table="tablename",
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]) updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
"UPDATE tablename SET colC = ?, colD = ? WHERE" "UPDATE tablename SET colC = ?, colD = ? WHERE" " colA = ? AND colB = ?",
" colA = ? AND colB = ?", [3, 4, 1, 2],
[3, 4, 1, 2]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -189,8 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore._simple_delete_one( yield self.datastore._simple_delete_one(
table="tablename", table="tablename", keyvalues={"keycol": "Go away"}
keyvalues={"keycol": "Go away"},
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(

View File

@ -37,8 +37,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.clock.now = 12345678 self.clock.now = 12345678
user_id = "@user:id" user_id = "@user:id"
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, user_id, "access_token", "ip", "user_agent", "device_id"
"access_token", "ip", "user_agent", "device_id",
) )
result = yield self.store.get_last_client_ip_by_device(user_id, "device_id") result = yield self.store.get_last_client_ip_by_device(user_id, "device_id")
@ -53,7 +52,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
"user_agent": "user_agent", "user_agent": "user_agent",
"last_seen": 12345678000, "last_seen": 12345678000,
}, },
r r,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -62,7 +61,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.hs.config.max_mau_value = 50 self.hs.config.max_mau_value = 50
user_id = "@user:server" user_id = "@user:server"
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id", user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
@ -78,7 +77,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
return_value=defer.succeed(lots_of_users) return_value=defer.succeed(lots_of_users)
) )
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id", user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertFalse(active) self.assertFalse(active)
@ -92,7 +91,7 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id", user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active) self.assertTrue(active)
@ -107,10 +106,10 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
self.assertFalse(active) self.assertFalse(active)
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id", user_id, "access_token", "ip", "user_agent", "device_id"
) )
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id", user_id, "access_token", "ip", "user_agent", "device_id"
) )
active = yield self.store._user_last_seen_monthly_active(user_id) active = yield self.store._user_last_seen_monthly_active(user_id)
self.assertTrue(active) self.assertTrue(active)

View File

@ -34,62 +34,58 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_store_new_device(self): def test_store_new_device(self):
yield self.store.store_device( yield self.store.store_device("user_id", "device_id", "display_name")
"user_id", "device_id", "display_name"
)
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertDictContainsSubset({ self.assertDictContainsSubset(
{
"user_id": "user_id", "user_id": "user_id",
"device_id": "device_id", "device_id": "device_id",
"display_name": "display_name", "display_name": "display_name",
}, res) },
res,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_devices_by_user(self): def test_get_devices_by_user(self):
yield self.store.store_device( yield self.store.store_device("user_id", "device1", "display_name 1")
"user_id", "device1", "display_name 1" yield self.store.store_device("user_id", "device2", "display_name 2")
) yield self.store.store_device("user_id2", "device3", "display_name 3")
yield self.store.store_device(
"user_id", "device2", "display_name 2"
)
yield self.store.store_device(
"user_id2", "device3", "display_name 3"
)
res = yield self.store.get_devices_by_user("user_id") res = yield self.store.get_devices_by_user("user_id")
self.assertEqual(2, len(res.keys())) self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset({ self.assertDictContainsSubset(
{
"user_id": "user_id", "user_id": "user_id",
"device_id": "device1", "device_id": "device1",
"display_name": "display_name 1", "display_name": "display_name 1",
}, res["device1"]) },
self.assertDictContainsSubset({ res["device1"],
)
self.assertDictContainsSubset(
{
"user_id": "user_id", "user_id": "user_id",
"device_id": "device2", "device_id": "device2",
"display_name": "display_name 2", "display_name": "display_name 2",
}, res["device2"]) },
res["device2"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield self.store.store_device( yield self.store.store_device("user_id", "device_id", "display_name 1")
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
yield self.store.update_device( yield self.store.update_device("user_id", "device_id")
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
yield self.store.update_device( yield self.store.update_device(
"user_id", "device_id", "user_id", "device_id", new_display_name="display_name 2"
new_display_name="display_name 2",
) )
# check it worked # check it worked
@ -100,7 +96,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
def test_update_unknown_device(self): def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm: with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device( yield self.store.update_device(
"user_id", "unknown_device_id", "user_id", "unknown_device_id", new_display_name="display_name 2"
new_display_name="display_name 2",
) )
self.assertEqual(404, cm.exception.code) self.assertEqual(404, cm.exception.code)

View File

@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver
class DirectoryStoreTestCase(unittest.TestCase): class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
@ -37,38 +36,29 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_room_to_alias(self): def test_room_to_alias(self):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias=self.alias, room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
room_id=self.room.to_string(),
servers=["test"],
) )
self.assertEquals( self.assertEquals(
["#my-room:test"], ["#my-room:test"],
(yield self.store.get_aliases_for_room(self.room.to_string())) (yield self.store.get_aliases_for_room(self.room.to_string())),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_alias_to_room(self): def test_alias_to_room(self):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias=self.alias, room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
room_id=self.room.to_string(),
servers=["test"],
) )
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {"room_id": self.room.to_string(), "servers": ["test"]},
"room_id": self.room.to_string(), (yield self.store.get_association_from_room_alias(self.alias)),
"servers": ["test"],
},
(yield self.store.get_association_from_room_alias(self.alias))
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_alias(self): def test_delete_alias(self):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias=self.alias, room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
room_id=self.room.to_string(),
servers=["test"],
) )
room_id = yield self.store.delete_room_alias(self.alias) room_id = yield self.store.delete_room_alias(self.alias)

View File

@ -35,70 +35,49 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield self.store.store_device( yield self.store.store_device("user", "device", None)
"user", "device", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys("user", "device", now, json)
"user", "device", now, json)
res = yield self.store.get_e2e_device_keys((("user", "device"),)) res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res) self.assertIn("user", res)
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset({"keys": json, "device_display_name": None}, dev)
"keys": json,
"device_display_name": None,
}, dev)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys("user", "device", now, json)
"user", "device", now, json) yield self.store.store_device("user", "device", "display_name")
yield self.store.store_device(
"user", "device", "display_name"
)
res = yield self.store.get_e2e_device_keys((("user", "device"),)) res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res) self.assertIn("user", res)
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset(
"keys": json, {"keys": json, "device_display_name": "display_name"}, dev
"device_display_name": "display_name", )
}, dev)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield self.store.store_device( yield self.store.store_device("user1", "device1", None)
"user1", "device1", None yield self.store.store_device("user1", "device2", None)
) yield self.store.store_device("user2", "device1", None)
yield self.store.store_device( yield self.store.store_device("user2", "device2", None)
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys("user1", "device1", now, 'json11')
"user1", "device1", now, 'json11') yield self.store.set_e2e_device_keys("user1", "device2", now, 'json12')
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys("user2", "device1", now, 'json21')
"user1", "device2", now, 'json12') yield self.store.set_e2e_device_keys("user2", "device2", now, 'json22')
yield self.store.set_e2e_device_keys(
"user2", "device1", now, 'json21')
yield self.store.set_e2e_device_keys(
"user2", "device2", now, 'json22')
res = yield self.store.get_e2e_device_keys((("user1", "device1"), res = yield self.store.get_e2e_device_keys(
("user2", "device2"))) (("user1", "device1"), ("user2", "device2"))
)
self.assertIn("user1", res) self.assertIn("user1", res)
self.assertIn("device1", res["user1"]) self.assertIn("device1", res["user1"])
self.assertNotIn("device2", res["user1"]) self.assertNotIn("device2", res["user1"])

View File

@ -33,23 +33,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
def insert_event(txn, i): def insert_event(txn, i):
event_id = '$event_%i:local' % i event_id = '$event_%i:local' % i
txn.execute(( txn.execute(
(
"INSERT INTO events (" "INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering," " room_id, event_id, type, depth, topological_ordering,"
" content, processed, outlier) " " content, processed, outlier) "
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
), (room_id, event_id, i, i, True, False)) ),
(room_id, event_id, i, i, True, False),
)
txn.execute(( txn.execute(
(
'INSERT INTO event_forward_extremities (room_id, event_id) ' 'INSERT INTO event_forward_extremities (room_id, event_id) '
'VALUES (?, ?)' 'VALUES (?, ?)'
), (room_id, event_id)) ),
(room_id, event_id),
)
txn.execute(( txn.execute(
(
'INSERT INTO event_reference_hashes ' 'INSERT INTO event_reference_hashes '
'(event_id, algorithm, hash) ' '(event_id, algorithm, hash) '
"VALUES (?, 'sha256', ?)" "VALUES (?, 'sha256', ?)"
), (event_id, b'ffff')) ),
(event_id, b'ffff'),
)
for i in range(0, 11): for i in range(0, 11):
yield self.store.runInteraction("insert", insert_event, i) yield self.store.runInteraction("insert", insert_event, i)

View File

@ -24,12 +24,13 @@ USER_ID = "@user:example.com"
PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}] PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
HIGHLIGHT = [ HIGHLIGHT = [
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} "notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
] ]
class EventPushActionsStoreTestCase(tests.unittest.TestCase): class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield tests.utils.setup_test_homeserver() hs = yield tests.utils.setup_test_homeserver()
@ -55,12 +56,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count): def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.runInteraction( counts = yield self.store.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, "", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0
room_id, user_id, 0
) )
self.assertEquals( self.assertEquals(
counts, counts,
{"notify_count": noitf_count, "highlight_count": highlight_count} {"notify_count": noitf_count, "highlight_count": highlight_count},
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -72,11 +72,13 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.depth = stream event.depth = stream
yield self.store.add_push_actions_to_staging( yield self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}, event.event_id, {user_id: action}
) )
yield self.store.runInteraction( yield self.store.runInteraction(
"", self.store._set_push_actions_for_event_and_users_txn, "",
[(event, None)], [(event, None)], self.store._set_push_actions_for_event_and_users_txn,
[(event, None)],
[(event, None)],
) )
def _rotate(stream): def _rotate(stream):
@ -86,8 +88,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
def _mark_read(stream, depth): def _mark_read(stream, depth):
return self.store.runInteraction( return self.store.runInteraction(
"", self.store._remove_old_push_actions_before_txn, "",
room_id, user_id, stream self.store._remove_old_push_actions_before_txn,
room_id,
user_id,
stream,
) )
yield _assert_counts(0, 0) yield _assert_counts(0, 0)
@ -112,9 +117,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield _rotate(7) yield _rotate(7)
yield self.store._simple_delete( yield self.store._simple_delete(
table="event_push_actions", table="event_push_actions", keyvalues={"1": 1}, desc=""
keyvalues={"1": 1},
desc="",
) )
yield _assert_counts(1, 0) yield _assert_counts(1, 0)
@ -132,7 +135,9 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_find_first_stream_ordering_after_ts(self): def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts): def add_event(so, ts):
return self.store._simple_insert("events", { return self.store._simple_insert(
"events",
{
"stream_ordering": so, "stream_ordering": so,
"received_ts": ts, "received_ts": ts,
"event_id": "event%i" % so, "event_id": "event%i" % so,
@ -143,7 +148,8 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
"outlier": False, "outlier": False,
"topological_ordering": 0, "topological_ordering": 0,
"depth": 0, "depth": 0,
}) },
)
# start with the base case where there are no events in the table # start with the base case where there are no events in the table
r = yield self.store.find_first_stream_ordering_after_ts(11) r = yield self.store.find_first_stream_ordering_after_ts(11)
@ -169,22 +175,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield add_event(stream_ordering, ts) yield add_event(stream_ordering, ts)
r = yield self.store.find_first_stream_ordering_after_ts(110) r = yield self.store.find_first_stream_ordering_after_ts(110)
self.assertEqual(r, 3, self.assertEqual(r, 3, "First event after 110ms should be 3, was %i" % r)
"First event after 110ms should be 3, was %i" % r)
# 4 and 5 are both after 120: we want 4 rather than 5 # 4 and 5 are both after 120: we want 4 rather than 5
r = yield self.store.find_first_stream_ordering_after_ts(120) r = yield self.store.find_first_stream_ordering_after_ts(120)
self.assertEqual(r, 4, self.assertEqual(r, 4, "First event after 120ms should be 4, was %i" % r)
"First event after 120ms should be 4, was %i" % r)
r = yield self.store.find_first_stream_ordering_after_ts(129) r = yield self.store.find_first_stream_ordering_after_ts(129)
self.assertEqual(r, 10, self.assertEqual(r, 10, "First event after 129ms should be 10, was %i" % r)
"First event after 129ms should be 10, was %i" % r)
# check we can get the last event # check we can get the last event
r = yield self.store.find_first_stream_ordering_after_ts(140) r = yield self.store.find_first_stream_ordering_after_ts(140)
self.assertEqual(r, 20, self.assertEqual(r, 20, "First event after 14ms should be 20, was %i" % r)
"First event after 14ms should be 20, was %i" % r)
# off the end # off the end
r = yield self.store.find_first_stream_ordering_after_ts(160) r = yield self.store.find_first_stream_ordering_after_ts(160)

View File

@ -39,15 +39,12 @@ class KeyStoreTestCase(tests.unittest.TestCase):
key2 = signedjson.key.decode_verify_key_base64( key2 = signedjson.key.decode_verify_key_base64(
"ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
) )
yield self.store.store_server_verify_key( yield self.store.store_server_verify_key("server1", "from_server", 0, key1)
"server1", "from_server", 0, key1 yield self.store.store_server_verify_key("server1", "from_server", 0, key2)
)
yield self.store.store_server_verify_key(
"server1", "from_server", 0, key2
)
res = yield self.store.get_server_verify_keys( res = yield self.store.get_server_verify_keys(
"server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]) "server1", ["ed25519:key1", "ed25519:key2", "ed25519:key3"]
)
self.assertEqual(len(res.keys()), 2) self.assertEqual(len(res.keys()), 2)
self.assertEqual(res["ed25519:key1"].version, "key1") self.assertEqual(res["ed25519:key1"].version, "key1")

View File

@ -40,19 +40,13 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
user2_email = "user2@matrix.org" user2_email = "user2@matrix.org"
threepids = [ threepids = [
{'medium': 'email', 'address': user1_email}, {'medium': 'email', 'address': user1_email},
{'medium': 'email', 'address': user2_email} {'medium': 'email', 'address': user2_email},
] ]
user_num = len(threepids) user_num = len(threepids)
yield self.store.register( yield self.store.register(user_id=user1, token="123", password_hash=None)
user_id=user1,
token="123",
password_hash=None)
yield self.store.register( yield self.store.register(user_id=user2, token="456", password_hash=None)
user_id=user2,
token="456",
password_hash=None)
now = int(self.hs.get_clock().time_msec()) now = int(self.hs.get_clock().time_msec())
yield self.store.user_add_threepid(user1, "email", user1_email, now, now) yield self.store.user_add_threepid(user1, "email", user1_email, now, now)

View File

@ -24,7 +24,6 @@ from tests.utils import MockClock, setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase): class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver(clock=MockClock()) hs = yield setup_test_homeserver(clock=MockClock())
@ -38,16 +37,19 @@ class PresenceStoreTestCase(unittest.TestCase):
def test_presence_list(self): def test_presence_list(self):
self.assertEquals( self.assertEquals(
[], [],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
)) observer_localpart=self.u_apple.localpart
)
),
) )
self.assertEquals( self.assertEquals(
[], [],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
accepted=True, observer_localpart=self.u_apple.localpart, accepted=True
)) )
),
) )
yield self.store.add_presence_list_pending( yield self.store.add_presence_list_pending(
@ -57,16 +59,19 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 0}], [{"observed_user_id": "@banana:test", "accepted": 0}],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
)) observer_localpart=self.u_apple.localpart
)
),
) )
self.assertEquals( self.assertEquals(
[], [],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
accepted=True, observer_localpart=self.u_apple.localpart, accepted=True
)) )
),
) )
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
@ -76,16 +81,19 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}], [{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
)) observer_localpart=self.u_apple.localpart
)
),
) )
self.assertEquals( self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}], [{"observed_user_id": "@banana:test", "accepted": 1}],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
accepted=True, observer_localpart=self.u_apple.localpart, accepted=True
)) )
),
) )
yield self.store.del_presence_list( yield self.store.del_presence_list(
@ -95,14 +103,17 @@ class PresenceStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
[], [],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
)) observer_localpart=self.u_apple.localpart
)
),
) )
self.assertEquals( self.assertEquals(
[], [],
(yield self.store.get_presence_list( (
observer_localpart=self.u_apple.localpart, yield self.store.get_presence_list(
accepted=True, observer_localpart=self.u_apple.localpart, accepted=True
)) )
),
) )

View File

@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver
class ProfileStoreTestCase(unittest.TestCase): class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
@ -35,24 +34,17 @@ class ProfileStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_displayname(self): def test_displayname(self):
yield self.store.create_profile( yield self.store.create_profile(self.u_frank.localpart)
self.u_frank.localpart
)
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
self.u_frank.localpart, "Frank"
)
self.assertEquals( self.assertEquals(
"Frank", "Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
(yield self.store.get_profile_displayname(self.u_frank.localpart))
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_avatar_url(self): def test_avatar_url(self):
yield self.store.create_profile( yield self.store.create_profile(self.u_frank.localpart)
self.u_frank.localpart
)
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here" self.u_frank.localpart, "http://my.site/here"
@ -60,5 +52,5 @@ class ProfileStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
"http://my.site/here", "http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)) (yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
) )

View File

@ -26,12 +26,10 @@ from tests.utils import setup_test_homeserver
class RedactionTestCase(unittest.TestCase): class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
resource_for_federation=Mock(), resource_for_federation=Mock(), http_client=None
http_client=None,
) )
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -46,17 +44,20 @@ class RedactionTestCase(unittest.TestCase):
self.depth = 1 self.depth = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None, def inject_room_member(
extra_content={}): self, room, user, membership, replaces_state=None, extra_content={}
):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content)
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": EventTypes.Member, "type": EventTypes.Member,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
"room_id": room.to_string(), "room_id": room.to_string(),
"content": content, "content": content,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -70,13 +71,15 @@ class RedactionTestCase(unittest.TestCase):
def inject_message(self, room, user, body): def inject_message(self, room, user, body):
self.depth += 1 self.depth += 1
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": EventTypes.Message, "type": EventTypes.Message,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
"room_id": room.to_string(), "room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"}, "content": {"body": body, "msgtype": u"message"},
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -88,14 +91,16 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason): def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
"room_id": room.to_string(), "room_id": room.to_string(),
"content": {"reason": reason}, "content": {"reason": reason},
"redacts": event_id, "redacts": event_id,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -105,9 +110,7 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redact(self): def test_redact(self):
yield self.inject_room_member( yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.room1, self.u_alice, Membership.JOIN
)
msg_event = yield self.inject_message(self.room1, self.u_alice, u"t") msg_event = yield self.inject_message(self.room1, self.u_alice, u"t")
@ -157,13 +160,10 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redact_join(self): def test_redact_join(self):
yield self.inject_room_member( yield self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.room1, self.u_alice, Membership.JOIN
)
msg_event = yield self.inject_room_member( msg_event = yield self.inject_room_member(
self.room1, self.u_bob, Membership.JOIN, self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
extra_content={"blue": "red"},
) )
event = yield self.store.get_event(msg_event.event_id) event = yield self.store.get_event(msg_event.event_id)

View File

@ -21,7 +21,6 @@ from tests.utils import setup_test_homeserver
class RegistrationStoreTestCase(unittest.TestCase): class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
@ -30,10 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_id = "@my-user:test" self.user_id = "@my-user:test"
self.tokens = [ self.tokens = ["AbCdEfGhIjKlMnOpQrStUvWxYz", "BcDeFgHiJkLmNoPqRsTuVwXyZa"]
"AbCdEfGhIjKlMnOpQrStUvWxYz",
"BcDeFgHiJkLmNoPqRsTuVwXyZa"
]
self.pwhash = "{xx1}123456789" self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg" self.device_id = "akgjhdjklgshg"
@ -51,34 +47,26 @@ class RegistrationStoreTestCase(unittest.TestCase):
"consent_server_notice_sent": None, "consent_server_notice_sent": None,
"appservice_id": None, "appservice_id": None,
}, },
(yield self.store.get_user_by_id(self.user_id)) (yield self.store.get_user_by_id(self.user_id)),
) )
result = yield self.store.get_user_by_access_token(self.tokens[0]) result = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertDictContainsSubset( self.assertDictContainsSubset({"name": self.user_id}, result)
{
"name": self.user_id,
},
result
)
self.assertTrue("token_id" in result) self.assertTrue("token_id" in result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], yield self.store.add_access_token_to_user(
self.device_id) self.user_id, self.tokens[1], self.device_id
)
result = yield self.store.get_user_by_access_token(self.tokens[1]) result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {"name": self.user_id, "device_id": self.device_id}, result
"name": self.user_id,
"device_id": self.device_id,
},
result
) )
self.assertTrue("token_id" in result) self.assertTrue("token_id" in result)
@ -87,12 +75,13 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], yield self.store.add_access_token_to_user(
self.device_id) self.user_id, self.tokens[1], self.device_id
)
# now delete some # now delete some
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, self.user_id, device_id=self.device_id
) )
# check they were deleted # check they were deleted
@ -107,8 +96,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.user_delete_access_tokens(self.user_id) yield self.store.user_delete_access_tokens(self.user_id)
user = yield self.store.get_user_by_access_token(self.tokens[0]) user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user, self.assertIsNone(user, "access token was not deleted without device_id")
"access token was not deleted without device_id")
class TokenGenerator: class TokenGenerator:
@ -117,4 +105,4 @@ class TokenGenerator:
def generate(self, user_id): def generate(self, user_id):
self._last_issued_token += 1 self._last_issued_token += 1
return u"%s-%d" % (user_id, self._last_issued_token,) return u"%s-%d" % (user_id, self._last_issued_token)

View File

@ -24,7 +24,6 @@ from tests.utils import setup_test_homeserver
class RoomStoreTestCase(unittest.TestCase): class RoomStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
@ -40,7 +39,7 @@ class RoomStoreTestCase(unittest.TestCase):
yield self.store.store_room( yield self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(), room_creator_user_id=self.u_creator.to_string(),
is_public=True is_public=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -49,14 +48,13 @@ class RoomStoreTestCase(unittest.TestCase):
{ {
"room_id": self.room.to_string(), "room_id": self.room.to_string(),
"creator": self.u_creator.to_string(), "creator": self.u_creator.to_string(),
"is_public": True "is_public": True,
}, },
(yield self.store.get_room(self.room.to_string())) (yield self.store.get_room(self.room.to_string())),
) )
class RoomEventsStoreTestCase(unittest.TestCase): class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = setup_test_homeserver() hs = setup_test_homeserver()
@ -69,18 +67,13 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
yield self.store.store_room( yield self.store.store_room(
self.room.to_string(), self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
room_creator_user_id="@creator:text",
is_public=True
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
yield self.store.persist_event( yield self.store.persist_event(
self.event_factory.create_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
room_id=self.room.to_string(),
**kwargs
)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -88,22 +81,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
name = u"A-Room-Name" name = u"A-Room-Name"
yield self.inject_room_event( yield self.inject_room_event(
etype=EventTypes.Name, etype=EventTypes.Name, name=name, content={"name": name}, depth=1
name=name,
content={"name": name},
depth=1,
) )
state = yield self.store.get_current_state( state = yield self.store.get_current_state(room_id=self.room.to_string())
room_id=self.room.to_string()
)
self.assertEquals(1, len(state)) self.assertEquals(1, len(state))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{"type": "m.room.name", {"type": "m.room.name", "room_id": self.room.to_string(), "name": name},
"room_id": self.room.to_string(), state[0],
"name": name},
state[0]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -111,22 +97,15 @@ class RoomEventsStoreTestCase(unittest.TestCase):
topic = u"A place for things" topic = u"A place for things"
yield self.inject_room_event( yield self.inject_room_event(
etype=EventTypes.Topic, etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
topic=topic,
content={"topic": topic},
depth=1,
) )
state = yield self.store.get_current_state( state = yield self.store.get_current_state(room_id=self.room.to_string())
room_id=self.room.to_string()
)
self.assertEquals(1, len(state)) self.assertEquals(1, len(state))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{"type": "m.room.topic", {"type": "m.room.topic", "room_id": self.room.to_string(), "topic": topic},
"room_id": self.room.to_string(), state[0],
"topic": topic},
state[0]
) )
# Not testing the various 'level' methods for now because there's lots # Not testing the various 'level' methods for now because there's lots

View File

@ -26,12 +26,10 @@ from tests.utils import setup_test_homeserver
class RoomMemberStoreTestCase(unittest.TestCase): class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
resource_for_federation=Mock(), resource_for_federation=Mock(), http_client=None
http_client=None,
) )
# We can't test the RoomMemberStore on its own without the other event # We can't test the RoomMemberStore on its own without the other event
# storage logic # storage logic
@ -49,13 +47,15 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None): def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": EventTypes.Member, "type": EventTypes.Member,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
"room_id": room.to_string(), "room_id": room.to_string(),
"content": {"membership": membership}, "content": {"membership": membership},
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -71,9 +71,12 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
[self.room.to_string()], [self.room.to_string()],
[m.room_id for m in ( [
m.room_id
for m in (
yield self.store.get_rooms_for_user_where_membership_is( yield self.store.get_rooms_for_user_where_membership_is(
self.u_alice.to_string(), [Membership.JOIN] self.u_alice.to_string(), [Membership.JOIN]
) )
)] )
],
) )

View File

@ -45,20 +45,20 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test") self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room( yield self.store.store_room(
self.room.to_string(), self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
room_creator_user_id="@creator:text",
is_public=True
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": typ, "type": typ,
"sender": sender.to_string(), "sender": sender.to_string(),
"state_key": state_key, "state_key": state_key,
"room_id": room.to_string(), "room_id": room.to_string(),
"content": content, "content": content,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -80,27 +80,31 @@ class StateStoreTestCase(tests.unittest.TestCase):
# this defaults to a linear DAG as each new injection defaults to whatever # this defaults to a linear DAG as each new injection defaults to whatever
# forward extremities are currently in the DB for this room. # forward extremities are currently in the DB for this room.
e1 = yield self.inject_state_event( e1 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Create, '', {}, self.room, self.u_alice, EventTypes.Create, '', {}
) )
e2 = yield self.inject_state_event( e2 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Name, '', { self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
"name": "test room"
},
) )
e3 = yield self.inject_state_event( e3 = yield self.inject_state_event(
self.room, self.u_alice, EventTypes.Member, self.u_alice.to_string(), { self.room,
"membership": Membership.JOIN self.u_alice,
}, EventTypes.Member,
self.u_alice.to_string(),
{"membership": Membership.JOIN},
) )
e4 = yield self.inject_state_event( e4 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { self.room,
"membership": Membership.JOIN self.u_bob,
}, EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.JOIN},
) )
e5 = yield self.inject_state_event( e5 = yield self.inject_state_event(
self.room, self.u_bob, EventTypes.Member, self.u_bob.to_string(), { self.room,
"membership": Membership.LEAVE self.u_bob,
}, EventTypes.Member,
self.u_bob.to_string(),
{"membership": Membership.LEAVE},
) )
# check we get the full state as of the final event # check we get the full state as of the final event
@ -110,65 +114,66 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
self.assertStateMapEqual({ self.assertStateMapEqual(
{
(e1.type, e1.state_key): e1, (e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2, (e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3, (e3.type, e3.state_key): e3,
# e4 is overwritten by e5 # e4 is overwritten by e5
(e5.type, e5.state_key): e5, (e5.type, e5.state_key): e5,
}, state) },
state,
)
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, '')], filtered_types=None e5.event_id, [(EventTypes.Name, '')], filtered_types=None
) )
self.assertStateMapEqual({ self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
(e2.type, e2.state_key): e2,
}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Name, None)], filtered_types=None e5.event_id, [(EventTypes.Name, None)], filtered_types=None
) )
self.assertStateMapEqual({ self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
(e2.type, e2.state_key): e2,
}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, None)], filtered_types=None e5.event_id, [(EventTypes.Member, None)], filtered_types=None
) )
self.assertStateMapEqual({ self.assertStateMapEqual(
(e3.type, e3.state_key): e3, {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
(e5.type, e5.state_key): e5, )
}, state)
# check we can use filter_types to grab a specific room member # check we can use filter_types to grab a specific room member
# without filtering out the other event types # without filtering out the other event types
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [(EventTypes.Member, self.u_alice.to_string())], e5.event_id,
[(EventTypes.Member, self.u_alice.to_string())],
filtered_types=[EventTypes.Member], filtered_types=[EventTypes.Member],
) )
self.assertStateMapEqual({ self.assertStateMapEqual(
{
(e1.type, e1.state_key): e1, (e1.type, e1.state_key): e1,
(e2.type, e2.state_key): e2, (e2.type, e2.state_key): e2,
(e3.type, e3.state_key): e3, (e3.type, e3.state_key): e3,
}, state) },
state,
)
# check that types=[], filtered_types=[EventTypes.Member] # check that types=[], filtered_types=[EventTypes.Member]
# doesn't return all members # doesn't return all members
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
e5.event_id, [], filtered_types=[EventTypes.Member], e5.event_id, [], filtered_types=[EventTypes.Member]
) )
self.assertStateMapEqual({ self.assertStateMapEqual(
(e1.type, e1.state_key): e1, {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
(e2.type, e2.state_key): e2, )
}, state)
####################################################### #######################################################
# _get_some_state_from_cache tests against a full cache # _get_some_state_from_cache tests against a full cache
@ -184,10 +189,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({ self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id, (e2.type, e2.state_key): e2.event_id,
}, state_dict) },
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with wildcard types # test _get_some_state_from_cache correctly filters in members with wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
@ -195,25 +203,33 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({ self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id, (e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id, (e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5 # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}, state_dict) },
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({ self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id, (e2.type, e2.state_key): e2.event_id,
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}, state_dict) },
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
@ -222,24 +238,27 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
(e5.type, e5.state_key): e5.event_id,
}, state_dict)
####################################################### #######################################################
# deliberately remove e2 (room name) from the _state_group_cache # deliberately remove e2 (room name) from the _state_group_cache
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
group
)
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertEqual(known_absent, set()) self.assertEqual(known_absent, set())
self.assertDictEqual(state_dict_ids, { self.assertDictEqual(
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e2.type, e2.state_key): e2.event_id, (e2.type, e2.state_key): e2.event_id,
(e3.type, e3.state_key): e3.event_id, (e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5 # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}) },
)
state_dict_ids.pop((e2.type, e2.state_key)) state_dict_ids.pop((e2.type, e2.state_key))
self.store._state_group_cache.invalidate(group) self.store._state_group_cache.invalidate(group)
@ -252,22 +271,32 @@ class StateStoreTestCase(tests.unittest.TestCase):
(e1.type, e1.state_key), (e1.type, e1.state_key),
(e3.type, e3.state_key), (e3.type, e3.state_key),
(e5.type, e5.state_key), (e5.type, e5.state_key),
) ),
) )
(is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(group) (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
group
)
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertEqual(known_absent, set([ self.assertEqual(
known_absent,
set(
[
(e1.type, e1.state_key), (e1.type, e1.state_key),
(e3.type, e3.state_key), (e3.type, e3.state_key),
(e5.type, e5.state_key), (e5.type, e5.state_key),
])) ]
self.assertDictEqual(state_dict_ids, { ),
)
self.assertDictEqual(
state_dict_ids,
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id, (e3.type, e3.state_key): e3.event_id,
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}) },
)
############################################ ############################################
# test that things work with a partial cache # test that things work with a partial cache
@ -279,9 +308,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({ self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
(e1.type, e1.state_key): e1.event_id,
}, state_dict)
# test _get_some_state_from_cache correctly filters in members wildcard types # test _get_some_state_from_cache correctly filters in members wildcard types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
@ -289,23 +316,31 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({ self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e3.type, e3.state_key): e3.event_id, (e3.type, e3.state_key): e3.event_id,
# e4 is overwritten by e5 # e4 is overwritten by e5
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}, state_dict) },
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
(state_dict, is_all) = yield self.store._get_some_state_from_cache( (state_dict, is_all) = yield self.store._get_some_state_from_cache(
group, [(EventTypes.Member, e5.state_key)], filtered_types=[EventTypes.Member] group,
[(EventTypes.Member, e5.state_key)],
filtered_types=[EventTypes.Member],
) )
self.assertEqual(is_all, False) self.assertEqual(is_all, False)
self.assertDictEqual({ self.assertDictEqual(
{
(e1.type, e1.state_key): e1.event_id, (e1.type, e1.state_key): e1.event_id,
(e5.type, e5.state_key): e5.event_id, (e5.type, e5.state_key): e5.event_id,
}, state_dict) },
state_dict,
)
# test _get_some_state_from_cache correctly filters in members with specific types # test _get_some_state_from_cache correctly filters in members with specific types
# and no filtered_types # and no filtered_types
@ -314,6 +349,4 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
self.assertEqual(is_all, True) self.assertEqual(is_all, True)
self.assertDictEqual({ self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
(e5.type, e5.state_key): e5.event_id,
}, state_dict)

View File

@ -39,20 +39,12 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
{ {
ALICE: ProfileInfo(None, "alice"), ALICE: ProfileInfo(None, "alice"),
BOB: ProfileInfo(None, "bob"), BOB: ProfileInfo(None, "bob"),
BOBBY: ProfileInfo(None, "bobby") BOBBY: ProfileInfo(None, "bobby"),
}, },
) )
yield self.store.add_users_to_public_room( yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB])
"!room:id",
[ALICE, BOB],
)
yield self.store.add_users_who_share_room( yield self.store.add_users_who_share_room(
"!room:id", "!room:id", False, ((ALICE, BOB), (BOB, ALICE))
False,
(
(ALICE, BOB),
(BOB, ALICE),
),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -62,11 +54,9 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r = yield self.store.search_user_dir(ALICE, "bob", 10) r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"])) self.assertEqual(1, len(r["results"]))
self.assertDictEqual(r["results"][0], { self.assertDictEqual(
"user_id": BOB, r["results"][0], {"user_id": BOB, "display_name": "bob", "avatar_url": None}
"display_name": "bob", )
"avatar_url": None,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def test_search_user_dir_all_users(self): def test_search_user_dir_all_users(self):
@ -75,15 +65,13 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
r = yield self.store.search_user_dir(ALICE, "bob", 10) r = yield self.store.search_user_dir(ALICE, "bob", 10)
self.assertFalse(r["limited"]) self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"])) self.assertEqual(2, len(r["results"]))
self.assertDictEqual(r["results"][0], { self.assertDictEqual(
"user_id": BOB, r["results"][0],
"display_name": "bob", {"user_id": BOB, "display_name": "bob", "avatar_url": None},
"avatar_url": None, )
}) self.assertDictEqual(
self.assertDictEqual(r["results"][1], { r["results"][1],
"user_id": BOBBY, {"user_id": BOBBY, "display_name": "bobby", "avatar_url": None},
"display_name": "bobby", )
"avatar_url": None,
})
finally: finally:
self.hs.config.user_directory_search_all_users = False self.hs.config.user_directory_search_all_users = False

View File

@ -22,7 +22,6 @@ from . import unittest
class DistributorTestCase(unittest.TestCase): class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@ -44,18 +43,14 @@ class DistributorTestCase(unittest.TestCase):
observers[0].side_effect = Exception("Awoogah!") observers[0].side_effect = Exception("Awoogah!")
with patch( with patch("synapse.util.distributor.logger", spec=["warning"]) as mock_logger:
"synapse.util.distributor.logger", spec=["warning"]
) as mock_logger:
self.dist.fire("alarm", "Go") self.dist.fire("alarm", "Go")
observers[0].assert_called_once_with("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go") observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1) self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance( self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
mock_logger.warning.call_args[0][0], str
)
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
@ -69,4 +64,5 @@ class DistributorTestCase(unittest.TestCase):
def test_signal_undeclared(self): def test_signal_undeclared(self):
def code(): def code():
self.dist.fire("notification") self.dist.fire("notification")
self.assertRaises(KeyError, code) self.assertRaises(KeyError, code)

View File

@ -27,7 +27,6 @@ from . import unittest
@unittest.DEBUG @unittest.DEBUG
class DnsTestCase(unittest.TestCase): class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve(self): def test_resolve(self):
dns_client_mock = Mock() dns_client_mock = Mock()
@ -36,14 +35,11 @@ class DnsTestCase(unittest.TestCase):
host_name = "example.com" host_name = "example.com"
answer_srv = dns.RRHeader( answer_srv = dns.RRHeader(
type=dns.SRV, type=dns.SRV, payload=dns.Record_SRV(target=host_name)
payload=dns.Record_SRV(
target=host_name,
)
) )
dns_client_mock.lookupService.return_value = defer.succeed( dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None), ([answer_srv], None, None)
) )
cache = {} cache = {}
@ -68,9 +64,7 @@ class DnsTestCase(unittest.TestCase):
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 0 entry.expires = 0
cache = { cache = {service_name: [entry]}
service_name: [entry]
}
servers = yield resolve_service( servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache service_name, dns_client=dns_client_mock, cache=cache
@ -93,12 +87,10 @@ class DnsTestCase(unittest.TestCase):
entry = Mock(spec_set=["expires"]) entry = Mock(spec_set=["expires"])
entry.expires = 999999999 entry.expires = 999999999
cache = { cache = {service_name: [entry]}
service_name: [entry]
}
servers = yield resolve_service( servers = yield resolve_service(
service_name, dns_client=dns_client_mock, cache=cache, clock=clock, service_name, dns_client=dns_client_mock, cache=cache, clock=clock
) )
self.assertFalse(dns_client_mock.lookupService.called) self.assertFalse(dns_client_mock.lookupService.called)
@ -117,9 +109,7 @@ class DnsTestCase(unittest.TestCase):
cache = {} cache = {}
with self.assertRaises(error.DNSServerError): with self.assertRaises(error.DNSServerError):
yield resolve_service( yield resolve_service(service_name, dns_client=dns_client_mock, cache=cache)
service_name, dns_client=dns_client_mock, cache=cache
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_name_error(self): def test_name_error(self):

View File

@ -35,10 +35,7 @@ class EventAuthTestCase(unittest.TestCase):
} }
# creator should be able to send state # creator should be able to send state
event_auth.check( event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False)
_random_state_event(creator), auth_events,
do_sig_check=False,
)
# joiner should not be able to send state # joiner should not be able to send state
self.assertRaises( self.assertRaises(
@ -61,13 +58,9 @@ class EventAuthTestCase(unittest.TestCase):
auth_events = { auth_events = {
("m.room.create", ""): _create_event(creator), ("m.room.create", ""): _create_event(creator),
("m.room.member", creator): _join_event(creator), ("m.room.member", creator): _join_event(creator),
("m.room.power_levels", ""): _power_levels_event(creator, { ("m.room.power_levels", ""): _power_levels_event(
"state_default": "30", creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
"users": { ),
pleb: "29",
king: "30",
},
}),
("m.room.member", pleb): _join_event(pleb), ("m.room.member", pleb): _join_event(pleb),
("m.room.member", king): _join_event(king), ("m.room.member", king): _join_event(king),
} }
@ -82,10 +75,7 @@ class EventAuthTestCase(unittest.TestCase):
), ),
# king should be able to send state # king should be able to send state
event_auth.check( event_auth.check(_random_state_event(king), auth_events, do_sig_check=False)
_random_state_event(king), auth_events,
do_sig_check=False,
)
# helpers for making events # helpers for making events
@ -94,52 +84,54 @@ TEST_ROOM_ID = "!test:room"
def _create_event(user_id): def _create_event(user_id):
return FrozenEvent({ return FrozenEvent(
{
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"event_id": _get_event_id(), "event_id": _get_event_id(),
"type": "m.room.create", "type": "m.room.create",
"sender": user_id, "sender": user_id,
"content": { "content": {"creator": user_id},
"creator": user_id, }
}, )
})
def _join_event(user_id): def _join_event(user_id):
return FrozenEvent({ return FrozenEvent(
{
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"event_id": _get_event_id(), "event_id": _get_event_id(),
"type": "m.room.member", "type": "m.room.member",
"sender": user_id, "sender": user_id,
"state_key": user_id, "state_key": user_id,
"content": { "content": {"membership": "join"},
"membership": "join", }
}, )
})
def _power_levels_event(sender, content): def _power_levels_event(sender, content):
return FrozenEvent({ return FrozenEvent(
{
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"event_id": _get_event_id(), "event_id": _get_event_id(),
"type": "m.room.power_levels", "type": "m.room.power_levels",
"sender": sender, "sender": sender,
"state_key": "", "state_key": "",
"content": content, "content": content,
}) }
)
def _random_state_event(sender): def _random_state_event(sender):
return FrozenEvent({ return FrozenEvent(
{
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"event_id": _get_event_id(), "event_id": _get_event_id(),
"type": "test.state", "type": "test.state",
"sender": sender, "sender": sender,
"state_key": "", "state_key": "",
"content": { "content": {"membership": "join"},
"membership": "join", }
}, )
})
event_count = 0 event_count = 0

View File

@ -22,7 +22,6 @@ from . import unittest
class PreviewTestCase(unittest.TestCase): class PreviewTestCase(unittest.TestCase):
def test_long_summarize(self): def test_long_summarize(self):
example_paras = [ example_paras = [
u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: u"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
@ -32,7 +31,6 @@ class PreviewTestCase(unittest.TestCase):
alternative spellings of the city.Tromsø is considered the northernmost alternative spellings of the city.Tromsø is considered the northernmost
city in the world with a population above 50,000. The most populous town city in the world with a population above 50,000. The most populous town
north of it is Alta, Norway, with a population of 14,272 (2013).""", north of it is Alta, Norway, with a population of 14,272 (2013).""",
u"""Tromsø lies in Northern Norway. The municipality has a population of u"""Tromsø lies in Northern Norway. The municipality has a population of
(2015) 72,066, but with an annual influx of students it has over 75,000 (2015) 72,066, but with an annual influx of students it has over 75,000
most of the year. It is the largest urban area in Northern Norway and the most of the year. It is the largest urban area in Northern Norway and the
@ -46,7 +44,6 @@ class PreviewTestCase(unittest.TestCase):
Sandnessund Bridge. Tromsø Airport connects the city to many destinations Sandnessund Bridge. Tromsø Airport connects the city to many destinations
in Europe. The city is warmer than most other places located on the same in Europe. The city is warmer than most other places located on the same
latitude, due to the warming effect of the Gulf Stream.""", latitude, due to the warming effect of the Gulf Stream.""",
u"""The city centre of Tromsø contains the highest number of old wooden u"""The city centre of Tromsø contains the highest number of old wooden
houses in Northern Norway, the oldest house dating from 1789. The Arctic houses in Northern Norway, the oldest house dating from 1789. The Arctic
Cathedral, a modern church from 1965, is probably the most famous landmark Cathedral, a modern church from 1965, is probably the most famous landmark
@ -67,7 +64,7 @@ class PreviewTestCase(unittest.TestCase):
u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are" u" the city of Tromsø. Outside of Norway, Tromso and Tromsö are"
u" alternative spellings of the city.Tromsø is considered the northernmost" u" alternative spellings of the city.Tromsø is considered the northernmost"
u" city in the world with a population above 50,000. The most populous town" u" city in the world with a population above 50,000. The most populous town"
u" north of it is Alta, Norway, with a population of 14,272 (2013)." u" north of it is Alta, Norway, with a population of 14,272 (2013).",
) )
desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500)
@ -80,7 +77,7 @@ class PreviewTestCase(unittest.TestCase):
u" third largest north of the Arctic Circle (following Murmansk and Norilsk)." u" third largest north of the Arctic Circle (following Murmansk and Norilsk)."
u" Most of Tromsø, including the city centre, is located on the island of" u" Most of Tromsø, including the city centre, is located on the island of"
u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," u" Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012,"
u" Tromsøya had a population of 36,088. Substantial parts of the urban…" u" Tromsøya had a population of 36,088. Substantial parts of the urban…",
) )
def test_short_summarize(self): def test_short_summarize(self):
@ -88,11 +85,9 @@ class PreviewTestCase(unittest.TestCase):
u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
u" Troms county, Norway.", u" Troms county, Norway.",
u"Tromsø lies in Northern Norway. The municipality has a population of" u"Tromsø lies in Northern Norway. The municipality has a population of"
u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" (2015) 72,066, but with an annual influx of students it has over 75,000"
u" most of the year.", u" most of the year.",
u"The city centre of Tromsø contains the highest number of old wooden" u"The city centre of Tromsø contains the highest number of old wooden"
u" houses in Northern Norway, the oldest house dating from 1789. The Arctic" u" houses in Northern Norway, the oldest house dating from 1789. The Arctic"
u" Cathedral, a modern church from 1965, is probably the most famous landmark" u" Cathedral, a modern church from 1965, is probably the most famous landmark"
@ -109,7 +104,7 @@ class PreviewTestCase(unittest.TestCase):
u"\n" u"\n"
u"Tromsø lies in Northern Norway. The municipality has a population of" u"Tromsø lies in Northern Norway. The municipality has a population of"
u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" (2015) 72,066, but with an annual influx of students it has over 75,000"
u" most of the year." u" most of the year.",
) )
def test_small_then_large_summarize(self): def test_small_then_large_summarize(self):
@ -117,7 +112,6 @@ class PreviewTestCase(unittest.TestCase):
u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" u"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" u" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
u" Troms county, Norway.", u" Troms county, Norway.",
u"Tromsø lies in Northern Norway. The municipality has a population of" u"Tromsø lies in Northern Norway. The municipality has a population of"
u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" (2015) 72,066, but with an annual influx of students it has over 75,000"
u" most of the year." u" most of the year."
@ -138,7 +132,7 @@ class PreviewTestCase(unittest.TestCase):
u" (2015) 72,066, but with an annual influx of students it has over 75,000" u" (2015) 72,066, but with an annual influx of students it has over 75,000"
u" most of the year. The city centre of Tromsø contains the highest number" u" most of the year. The city centre of Tromsø contains the highest number"
u" of old wooden houses in Northern Norway, the oldest house dating from" u" of old wooden houses in Northern Norway, the oldest house dating from"
u" 1789. The Arctic Cathedral, a modern church from…" u" 1789. The Arctic Cathedral, a modern church from…",
) )
@ -155,10 +149,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
u"og:title": u"Foo",
u"og:description": u"Some text."
})
def test_comment(self): def test_comment(self):
html = u""" html = u"""
@ -173,10 +164,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
u"og:title": u"Foo",
u"og:description": u"Some text."
})
def test_comment2(self): def test_comment2(self):
html = u""" html = u"""
@ -194,10 +182,13 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(
og,
{
u"og:title": u"Foo", u"og:title": u"Foo",
u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text" u"og:description": u"Some text.\n\nSome more text.\n\nText\n\nMore text",
}) },
)
def test_script(self): def test_script(self):
html = u""" html = u"""
@ -212,10 +203,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": u"Foo", u"og:description": u"Some text."})
u"og:title": u"Foo",
u"og:description": u"Some text."
})
def test_missing_title(self): def test_missing_title(self):
html = u""" html = u"""
@ -228,10 +216,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."})
u"og:title": None,
u"og:description": u"Some text."
})
def test_h1_as_title(self): def test_h1_as_title(self):
html = u""" html = u"""
@ -245,10 +230,7 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": u"Title", u"og:description": u"Some text."})
u"og:title": u"Title",
u"og:description": u"Some text."
})
def test_missing_title_and_broken_h1(self): def test_missing_title_and_broken_h1(self):
html = u""" html = u"""
@ -262,7 +244,4 @@ class PreviewUrlTestCase(unittest.TestCase):
og = decode_and_calc_og(html, "http://example.com/test.html") og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEquals(og, { self.assertEquals(og, {u"og:title": None, u"og:description": u"Some text."})
u"og:title": None,
u"og:description": u"Some text."
})

View File

@ -29,8 +29,15 @@ from .utils import MockClock
_next_event_id = 1000 _next_event_id = 1000
def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, def create_event(
prev_events=[], **kwargs): name=None,
type=None,
state_key=None,
depth=2,
event_id=None,
prev_events=[],
**kwargs
):
global _next_event_id global _next_event_id
if not event_id: if not event_id:
@ -39,9 +46,9 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
if not name: if not name:
if state_key is not None: if state_key is not None:
name = "<%s-%s, %s>" % (type, state_key, event_id,) name = "<%s-%s, %s>" % (type, state_key, event_id)
else: else:
name = "<%s, %s>" % (type, event_id,) name = "<%s, %s>" % (type, event_id)
d = { d = {
"event_id": event_id, "event_id": event_id,
@ -80,8 +87,9 @@ class StateGroupStore(object):
return defer.succeed(groups) return defer.succeed(groups)
def store_state_group(self, event_id, room_id, prev_group, delta_ids, def store_state_group(
current_state_ids): self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
state_group = self._next_group state_group = self._next_group
self._next_group += 1 self._next_group += 1
@ -91,7 +99,8 @@ class StateGroupStore(object):
def get_events(self, event_ids, **kwargs): def get_events(self, event_ids, **kwargs):
return { return {
e_id: self._event_id_to_event[e_id] for e_id in event_ids e_id: self._event_id_to_event[e_id]
for e_id in event_ids
if e_id in self._event_id_to_event if e_id in self._event_id_to_event
} }
@ -129,9 +138,7 @@ class Graph(object):
prev_events = [] prev_events = []
events[event_id] = create_event( events[event_id] = create_event(
event_id=event_id, event_id=event_id, prev_events=prev_events, **fields
prev_events=prev_events,
**fields
) )
self._leaves = clobbered self._leaves = clobbered
@ -147,10 +154,15 @@ class Graph(object):
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = StateGroupStore() self.store = StateGroupStore()
hs = Mock(spec_set=[ hs = Mock(
"get_datastore", "get_auth", "get_state_handler", "get_clock", spec_set=[
"get_datastore",
"get_auth",
"get_state_handler",
"get_clock",
"get_state_resolution_handler", "get_state_resolution_handler",
]) ]
)
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
@ -164,35 +176,13 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self): def test_branch_no_conflict(self):
graph = Graph( graph = Graph(
nodes={ nodes={
"START": DictObj( "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
type=EventTypes.Create, "A": DictObj(type=EventTypes.Message, depth=2),
state_key="", "B": DictObj(type=EventTypes.Message, depth=3),
depth=1, "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
), "D": DictObj(type=EventTypes.Message, depth=4),
"A": DictObj(
type=EventTypes.Message,
depth=2,
),
"B": DictObj(
type=EventTypes.Message,
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"D": DictObj(
type=EventTypes.Message,
depth=4,
),
}, },
edges={ edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
) )
self.store.register_events(graph.walk()) self.store.register_events(graph.walk())
@ -224,27 +214,11 @@ class StateTestCase(unittest.TestCase):
membership=Membership.JOIN, membership=Membership.JOIN,
depth=2, depth=2,
), ),
"B": DictObj( "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
type=EventTypes.Name, "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
state_key="", "D": DictObj(type=EventTypes.Message, depth=5),
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
),
"D": DictObj(
type=EventTypes.Message,
depth=5,
),
}, },
edges={ edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
) )
self.store.register_events(graph.walk()) self.store.register_events(graph.walk())
@ -259,8 +233,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "C"}, {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
{e_id for e_id in prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -280,11 +253,7 @@ class StateTestCase(unittest.TestCase):
membership=Membership.JOIN, membership=Membership.JOIN,
depth=2, depth=2,
), ),
"B": DictObj( "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj( "C": DictObj(
type=EventTypes.Member, type=EventTypes.Member,
state_key="@user_id_2:example.com", state_key="@user_id_2:example.com",
@ -298,18 +267,9 @@ class StateTestCase(unittest.TestCase):
depth=4, depth=4,
sender="@user_id_2:example.com", sender="@user_id_2:example.com",
), ),
"E": DictObj( "E": DictObj(type=EventTypes.Message, depth=5),
type=EventTypes.Message,
depth=5,
),
}, },
edges={ edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
"A": ["START"],
"B": ["A"],
"C": ["B"],
"D": ["B"],
"E": ["C", "D"]
}
) )
self.store.register_events(graph.walk()) self.store.register_events(graph.walk())
@ -324,8 +284,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store) prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"START", "A", "B", "C"}, {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
{e for e in prev_state_ids.values()}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -357,30 +316,17 @@ class StateTestCase(unittest.TestCase):
state_key="", state_key="",
content={ content={
"events": {"m.room.name": 50}, "events": {"m.room.name": 50},
"users": {userid1: 100, "users": {userid1: 100, userid2: 60},
userid2: 60},
}, },
), ),
"A5": DictObj( "A5": DictObj(type=EventTypes.Name, state_key=""),
type=EventTypes.Name,
state_key="",
),
"B": DictObj( "B": DictObj(
type=EventTypes.PowerLevels, type=EventTypes.PowerLevels,
state_key="", state_key="",
content={ content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
"events": {"m.room.name": 50},
"users": {userid2: 30},
},
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
sender=userid2,
),
"D": DictObj(
type=EventTypes.Message,
), ),
"C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
"D": DictObj(type=EventTypes.Message),
} }
edges = { edges = {
"A2": ["A1"], "A2": ["A1"],
@ -389,7 +335,7 @@ class StateTestCase(unittest.TestCase):
"A5": ["A4"], "A5": ["A4"],
"B": ["A5"], "B": ["A5"],
"C": ["A5"], "C": ["A5"],
"D": ["B", "C"] "D": ["B", "C"],
} }
self._add_depths(nodes, edges) self._add_depths(nodes, edges)
graph = Graph(nodes, edges) graph = Graph(nodes, edges)
@ -406,8 +352,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store) prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual( self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"}, {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
{e for e in prev_state_ids.values()}
) )
def _add_depths(self, nodes, edges): def _add_depths(self, nodes, edges):
@ -432,9 +377,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(event, old_state=old_state)
event, old_state=old_state
)
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
@ -454,9 +397,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(event, old_state=old_state)
event, old_state=old_state
)
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
@ -468,8 +409,7 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id" prev_event_id = "prev_event_id"
event = create_event( event = create_event(
type="test_message", name="event2", type="test_message", name="event2", prev_events=[(prev_event_id, {})]
prev_events=[(prev_event_id, {})],
) )
old_state = [ old_state = [
@ -479,7 +419,10 @@ class StateTestCase(unittest.TestCase):
] ]
group_name = self.store.store_state_group( group_name = self.store.store_state_group(
prev_event_id, event.room_id, None, None, prev_event_id,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state}, {(e.type, e.state_key): e.event_id for e in old_state},
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.store.register_event_id_state_group(prev_event_id, group_name)
@ -489,8 +432,7 @@ class StateTestCase(unittest.TestCase):
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]), set(current_state_ids.values())
set(current_state_ids.values())
) )
self.assertEqual(group_name, context.state_group) self.assertEqual(group_name, context.state_group)
@ -499,8 +441,7 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
prev_event_id = "prev_event_id" prev_event_id = "prev_event_id"
event = create_event( event = create_event(
type="state", state_key="", name="event2", type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
prev_events=[(prev_event_id, {})],
) )
old_state = [ old_state = [
@ -510,7 +451,10 @@ class StateTestCase(unittest.TestCase):
] ]
group_name = self.store.store_state_group( group_name = self.store.store_state_group(
prev_event_id, event.room_id, None, None, prev_event_id,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state}, {(e.type, e.state_key): e.event_id for e in old_state},
) )
self.store.register_event_id_state_group(prev_event_id, group_name) self.store.register_event_id_state_group(prev_event_id, group_name)
@ -520,8 +464,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]), set(prev_state_ids.values())
set(prev_state_ids.values())
) )
self.assertIsNotNone(context.state_group) self.assertIsNotNone(context.state_group)
@ -531,13 +474,12 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1" prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2" prev_event_id2 = "event_id2"
event = create_event( event = create_event(
type="test_message", name="event3", type="test_message",
name="event3",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
) )
creation = create_event( creation = create_event(type=EventTypes.Create, state_key="")
type=EventTypes.Create, state_key=""
)
old_state_1 = [ old_state_1 = [
creation, creation,
@ -557,7 +499,7 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(old_state_2) self.store.register_events(old_state_2)
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
) )
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
@ -571,13 +513,13 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1" prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2" prev_event_id2 = "event_id2"
event = create_event( event = create_event(
type="test4", state_key="", name="event", type="test4",
state_key="",
name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
) )
creation = create_event( creation = create_event(type=EventTypes.Create, state_key="")
type=EventTypes.Create, state_key=""
)
old_state_1 = [ old_state_1 = [
creation, creation,
@ -599,7 +541,7 @@ class StateTestCase(unittest.TestCase):
self.store.get_events = store.get_events self.store.get_events = store.get_events
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
) )
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
@ -613,29 +555,25 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1" prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2" prev_event_id2 = "event_id2"
event = create_event( event = create_event(
type="test4", name="event", type="test4",
name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})], prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
) )
member_event = create_event( member_event = create_event(
type=EventTypes.Member, type=EventTypes.Member,
state_key="@user_id:example.com", state_key="@user_id:example.com",
content={ content={"membership": Membership.JOIN},
"membership": Membership.JOIN,
}
) )
power_levels = create_event( power_levels = create_event(
type=EventTypes.PowerLevels, state_key="", type=EventTypes.PowerLevels,
content={"users": { state_key="",
"@foo:bar": "100", content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
"@user_id:example.com": "100",
}}
) )
creation = create_event( creation = create_event(
type=EventTypes.Create, state_key="", type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
content={"creator": "@foo:bar"}
) )
old_state_1 = [ old_state_1 = [
@ -658,14 +596,12 @@ class StateTestCase(unittest.TestCase):
self.store.get_events = store.get_events self.store.get_events = store.get_events
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
) )
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
old_state_2[3].event_id, current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
# during state resolution. # during state resolution.
@ -688,25 +624,30 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2) store.register_events(old_state_2)
context = yield self._get_context( context = yield self._get_context(
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2, event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
) )
current_state_ids = yield context.get_current_state_ids(self.store) current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual( self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
old_state_1[3].event_id, current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, def _get_context(
old_state_2): self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
sg1 = self.store.store_state_group( sg1 = self.store.store_state_group(
prev_event_id_1, event.room_id, None, None, prev_event_id_1,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state_1}, {(e.type, e.state_key): e.event_id for e in old_state_1},
) )
self.store.register_event_id_state_group(prev_event_id_1, sg1) self.store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = self.store.store_state_group( sg2 = self.store.store_state_group(
prev_event_id_2, event.room_id, None, None, prev_event_id_2,
event.room_id,
None,
None,
{(e.type, e.state_key): e.event_id for e in old_state_2}, {(e.type, e.state_key): e.event_id for e in old_state_2},
) )
self.store.register_event_id_state_group(prev_event_id_2, sg2) self.store.register_event_id_state_group(prev_event_id_2, sg2)

View File

@ -18,7 +18,6 @@ from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase): class MockClockTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.clock = MockClock() self.clock = MockClock()
@ -34,10 +33,12 @@ class MockClockTestCase(unittest.TestCase):
def _cb0(): def _cb0():
invoked[0] = 1 invoked[0] = 1
self.clock.call_later(10, _cb0) self.clock.call_later(10, _cb0)
def _cb1(): def _cb1():
invoked[1] = 1 invoked[1] = 1
self.clock.call_later(20, _cb1) self.clock.call_later(20, _cb1)
self.assertFalse(invoked[0]) self.assertFalse(invoked[0])
@ -56,10 +57,12 @@ class MockClockTestCase(unittest.TestCase):
def _cb0(): def _cb0():
invoked[0] = 1 invoked[0] = 1
t0 = self.clock.call_later(10, _cb0) t0 = self.clock.call_later(10, _cb0)
def _cb1(): def _cb1():
invoked[1] = 1 invoked[1] = 1
self.clock.call_later(20, _cb1) self.clock.call_later(20, _cb1)
self.clock.cancel_call_later(t0) self.clock.cancel_call_later(t0)

View File

@ -69,10 +69,7 @@ class GroupIDTestCase(unittest.TestCase):
self.assertEqual("my.domain", group_id.domain) self.assertEqual("my.domain", group_id.domain)
def test_validate(self): def test_validate(self):
bad_ids = [ bad_ids = ["$badsigil:domain", "+:empty"] + [
"$badsigil:domain",
"+:empty",
] + [
"+group" + c + ":domain" for c in "A%?æ£" "+group" + c + ":domain" for c in "A%?æ£"
] ]
for id_string in bad_ids: for id_string in bad_ids:

View File

@ -54,14 +54,12 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter = [] events_to_filter = []
for i in range(0, 10): for i in range(0, 10):
user = "@user%i:%s" % ( user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
i, "test_server" if i == 5 else "other_server"
)
evt = yield self.inject_room_member(user, extra_content={"a": "b"}) evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt) events_to_filter.append(evt)
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter, self.store, "test_server", events_to_filter
) )
# the result should be 5 redacted events, and 5 unredacted events. # the result should be 5 redacted events, and 5 unredacted events.
@ -100,19 +98,21 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# ... and the filtering happens. # ... and the filtering happens.
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
self.store, "test_server", events_to_filter, self.store, "test_server", events_to_filter
) )
for i in range(0, len(events_to_filter)): for i in range(0, len(events_to_filter)):
self.assertEqual( self.assertEqual(
events_to_filter[i].event_id, filtered[i].event_id, events_to_filter[i].event_id,
"Unexpected event at result position %i" % (i, ) filtered[i].event_id,
"Unexpected event at result position %i" % (i,),
) )
for i in (0, 3): for i in (0, 3):
self.assertEqual( self.assertEqual(
events_to_filter[i].content["body"], filtered[i].content["body"], events_to_filter[i].content["body"],
"Unexpected event content at result position %i" % (i,) filtered[i].content["body"],
"Unexpected event content at result position %i" % (i,),
) )
for i in (1, 4): for i in (1, 4):
@ -121,13 +121,15 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_visibility(self, user_id, visibility): def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility} content = {"history_visibility": visibility}
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": "m.room.history_visibility", "type": "m.room.history_visibility",
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": content, "content": content,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -139,13 +141,15 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_room_member(self, user_id, membership="join", extra_content={}): def inject_room_member(self, user_id, membership="join", extra_content={}):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content)
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": "m.room.member", "type": "m.room.member",
"sender": user_id, "sender": user_id,
"state_key": user_id, "state_key": user_id,
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": content, "content": content,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -158,12 +162,14 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_message(self, user_id, content=None): def inject_message(self, user_id, content=None):
if content is None: if content is None:
content = {"body": "testytest"} content = {"body": "testytest"}
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new(
{
"type": "m.room.message", "type": "m.room.message",
"sender": user_id, "sender": user_id,
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": content, "content": content,
}) }
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
@ -192,56 +198,54 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
# history_visibility event. # history_visibility event.
room_state = [] room_state = []
history_visibility_evt = FrozenEvent({ history_visibility_evt = FrozenEvent(
{
"event_id": "$history_vis", "event_id": "$history_vis",
"type": "m.room.history_visibility", "type": "m.room.history_visibility",
"sender": "@resident_user_0:test.com", "sender": "@resident_user_0:test.com",
"state_key": "", "state_key": "",
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": {"history_visibility": "joined"}, "content": {"history_visibility": "joined"},
}) }
)
room_state.append(history_visibility_evt) room_state.append(history_visibility_evt)
test_store.add_event(history_visibility_evt) test_store.add_event(history_visibility_evt)
for i in range(0, 100000): for i in range(0, 100000):
user = "@resident_user_%i:test.com" % (i,) user = "@resident_user_%i:test.com" % (i,)
evt = FrozenEvent({ evt = FrozenEvent(
{
"event_id": "$res_event_%i" % (i,), "event_id": "$res_event_%i" % (i,),
"type": "m.room.member", "type": "m.room.member",
"state_key": user, "state_key": user,
"sender": user, "sender": user,
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": { "content": {"membership": "join", "extra": "zzz,"},
"membership": "join", }
"extra": "zzz," )
},
})
room_state.append(evt) room_state.append(evt)
test_store.add_event(evt) test_store.add_event(evt)
events_to_filter = [] events_to_filter = []
for i in range(0, 10): for i in range(0, 10):
user = "@user%i:%s" % ( user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
i, "test_server" if i == 5 else "other_server" evt = FrozenEvent(
) {
evt = FrozenEvent({
"event_id": "$evt%i" % (i,), "event_id": "$evt%i" % (i,),
"type": "m.room.member", "type": "m.room.member",
"state_key": user, "state_key": user,
"sender": user, "sender": user,
"room_id": TEST_ROOM_ID, "room_id": TEST_ROOM_ID,
"content": { "content": {"membership": "join", "extra": "zzz"},
"membership": "join", }
"extra": "zzz", )
},
})
events_to_filter.append(evt) events_to_filter.append(evt)
room_state.append(evt) room_state.append(evt)
test_store.add_event(evt) test_store.add_event(evt)
test_store.set_state_ids_for_event(evt, { test_store.set_state_ids_for_event(
(e.type, e.state_key): e.event_id for e in room_state evt, {(e.type, e.state_key): e.event_id for e in room_state}
}) )
pr = cProfile.Profile() pr = cProfile.Profile()
pr.enable() pr.enable()
@ -249,7 +253,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
logger.info("Starting filtering") logger.info("Starting filtering")
start = time.time() start = time.time()
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter, test_store, "test_server", events_to_filter
) )
logger.info("Filtering took %f seconds", time.time() - start) logger.info("Filtering took %f seconds", time.time() - start)
@ -275,6 +279,7 @@ class _TestStore(object):
filter_events_for_server filter_events_for_server
""" """
def __init__(self): def __init__(self):
# data for get_events: a map from event_id to event # data for get_events: a map from event_id to event
self.events = {} self.events = {}
@ -298,8 +303,8 @@ class _TestStore(object):
continue continue
if type != "m.room.member" or state_key is not None: if type != "m.room.member" or state_key is not None:
raise RuntimeError( raise RuntimeError(
"Unimplemented: get_state_ids with type (%s, %s)" % "Unimplemented: get_state_ids with type (%s, %s)"
(type, state_key), % (type, state_key)
) )
include_memberships = True include_memberships = True
@ -316,9 +321,7 @@ class _TestStore(object):
return succeed(res) return succeed(res)
def get_events(self, events): def get_events(self, events):
return succeed({ return succeed({event_id: self.events[event_id] for event_id in events})
event_id: self.events[event_id] for event_id in events
})
def are_users_erased(self, users): def are_users_erased(self, users):
return succeed({u: False for u in users}) return succeed({u: False for u in users})

View File

@ -56,6 +56,7 @@ def around(target):
def method_name(orig, *args, **kwargs): def method_name(orig, *args, **kwargs):
return orig(*args, **kwargs) return orig(*args, **kwargs)
""" """
def _around(code): def _around(code):
name = code.__name__ name = code.__name__
orig = getattr(target, name) orig = getattr(target, name)
@ -89,6 +90,7 @@ class TestCase(unittest.TestCase):
old_level = logging.getLogger().level old_level = logging.getLogger().level
if old_level != level: if old_level != level:
@around(self) @around(self)
def tearDown(orig): def tearDown(orig):
ret = orig() ret = orig()
@ -117,8 +119,9 @@ class TestCase(unittest.TestCase):
actual (dict): The test result. Extra keys will not be checked. actual (dict): The test result. Extra keys will not be checked.
""" """
for key in required: for key in required:
self.assertEquals(required[key], actual[key], self.assertEquals(
msg="%s mismatch. %s" % (key, actual)) required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
)
def DEBUG(target): def DEBUG(target):

View File

@ -67,12 +67,8 @@ class CacheTestCase(unittest.TestCase):
self.assertIsNone(cache.get("key2", None)) self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked # both callbacks should have been callbacked
self.assertTrue( self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
callback_record[0], "Invalidation callback for key1 not called", self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
)
self.assertTrue(
callback_record[1], "Invalidation callback for key2 not called",
)
# letting the other lookup complete should do nothing # letting the other lookup complete should do nothing
d1.callback("result1") d1.callback("result1")
@ -168,8 +164,7 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1: with logcontext.LoggingContext() as c1:
c1.name = "c1" c1.name = "c1"
r = yield obj.fn(1) r = yield obj.fn(1)
self.assertEqual(logcontext.LoggingContext.current_context(), self.assertEqual(logcontext.LoggingContext.current_context(), c1)
c1)
defer.returnValue(r) defer.returnValue(r)
def check_result(r): def check_result(r):
@ -179,14 +174,18 @@ class DescriptorTestCase(unittest.TestCase):
# set off a deferred which will do a cache lookup # set off a deferred which will do a cache lookup
d1 = do_lookup() d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(), self.assertEqual(
logcontext.LoggingContext.sentinel) logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
d1.addCallback(check_result) d1.addCallback(check_result)
# and another # and another
d2 = do_lookup() d2 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(), self.assertEqual(
logcontext.LoggingContext.sentinel) logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
d2.addCallback(check_result) d2.addCallback(check_result)
# let the lookup complete # let the lookup complete
@ -224,15 +223,16 @@ class DescriptorTestCase(unittest.TestCase):
except SynapseError: except SynapseError:
pass pass
self.assertEqual(logcontext.LoggingContext.current_context(), self.assertEqual(logcontext.LoggingContext.current_context(), c1)
c1)
obj = Cls() obj = Cls()
# set off a deferred which will do a cache lookup # set off a deferred which will do a cache lookup
d1 = do_lookup() d1 = do_lookup()
self.assertEqual(logcontext.LoggingContext.current_context(), self.assertEqual(
logcontext.LoggingContext.sentinel) logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
return d1 return d1
@ -288,14 +288,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True) @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2): def list_fn(self, args1, arg2):
assert ( assert logcontext.LoggingContext.current_context().request == "c1"
logcontext.LoggingContext.current_context().request == "c1"
)
# we want this to behave like an asynchronous function # we want this to behave like an asynchronous function
yield run_on_reactor() yield run_on_reactor()
assert ( assert logcontext.LoggingContext.current_context().request == "c1"
logcontext.LoggingContext.current_context().request == "c1"
)
defer.returnValue(self.mock(args1, arg2)) defer.returnValue(self.mock(args1, arg2))
with logcontext.LoggingContext() as c1: with logcontext.LoggingContext() as c1:
@ -308,10 +304,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
logcontext.LoggingContext.sentinel, logcontext.LoggingContext.sentinel,
) )
r = yield d1 r = yield d1
self.assertEqual( self.assertEqual(logcontext.LoggingContext.current_context(), c1)
logcontext.LoggingContext.current_context(),
c1
)
obj.mock.assert_called_once_with([10, 20], 2) obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: 'fish', 20: 'chips'}) self.assertEqual(r, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock() obj.mock.reset_mock()
@ -337,6 +330,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self):
"""Make sure that invalidation callbacks are called.""" """Make sure that invalidation callbacks are called."""
class Cls(object): class Cls(object):
def __init__(self): def __init__(self):
self.mock = mock.Mock() self.mock = mock.Mock()

View File

@ -20,7 +20,6 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase): class DictCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.cache = DictionaryCache("foobar") self.cache = DictionaryCache("foobar")
@ -41,9 +40,7 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_partial" key = "test_simple_cache_hit_partial"
seq = self.cache.sequence seq = self.cache.sequence
test_value = { test_value = {"test": "test_simple_cache_hit_partial"}
"test": "test_simple_cache_hit_partial"
}
self.cache.update(seq, key, test_value) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"]) c = self.cache.get(key, ["test"])
@ -53,9 +50,7 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_miss_partial" key = "test_simple_cache_miss_partial"
seq = self.cache.sequence seq = self.cache.sequence
test_value = { test_value = {"test": "test_simple_cache_miss_partial"}
"test": "test_simple_cache_miss_partial"
}
self.cache.update(seq, key, test_value) self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"]) c = self.cache.get(key, ["test2"])
@ -79,15 +74,11 @@ class DictCacheTestCase(unittest.TestCase):
key = "test_simple_cache_hit_miss_partial" key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence seq = self.cache.sequence
test_value_1 = { test_value_1 = {"test": "test_simple_cache_hit_miss_partial"}
"test": "test_simple_cache_hit_miss_partial",
}
self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence seq = self.cache.sequence
test_value_2 = { test_value_2 = {"test2": "test_simple_cache_hit_miss_partial2"}
"test2": "test_simple_cache_hit_miss_partial2",
}
self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key) c = self.cache.get(key)
@ -96,5 +87,5 @@ class DictCacheTestCase(unittest.TestCase):
"test": "test_simple_cache_hit_miss_partial", "test": "test_simple_cache_hit_miss_partial",
"test2": "test_simple_cache_hit_miss_partial2", "test2": "test_simple_cache_hit_miss_partial2",
}, },
c.value c.value,
) )

View File

@ -22,7 +22,6 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.TestCase): class ExpiringCacheTestCase(unittest.TestCase):
def test_get_set(self): def test_get_set(self):
clock = MockClock() clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1) cache = ExpiringCache("test", clock, max_len=1)

View File

@ -27,7 +27,6 @@ from tests import unittest
class FileConsumerTests(unittest.TestCase): class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_pull_consumer(self): def test_pull_consumer(self):
string_file = StringIO() string_file = StringIO()
@ -87,7 +86,9 @@ class FileConsumerTests(unittest.TestCase):
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred = defer.Deferred() resume_deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None
)
consumer.registerProducer(producer, True) consumer.registerProducer(producer, True)

View File

@ -26,7 +26,6 @@ from tests import unittest
class LinearizerTestCase(unittest.TestCase): class LinearizerTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_linearizer(self): def test_linearizer(self):
linearizer = Linearizer() linearizer = Linearizer()
@ -54,13 +53,11 @@ class LinearizerTestCase(unittest.TestCase):
def func(i, sleep=False): def func(i, sleep=False):
with logcontext.LoggingContext("func(%s)" % i) as lc: with logcontext.LoggingContext("func(%s)" % i) as lc:
with (yield linearizer.queue("")): with (yield linearizer.queue("")):
self.assertEqual( self.assertEqual(logcontext.LoggingContext.current_context(), lc)
logcontext.LoggingContext.current_context(), lc)
if sleep: if sleep:
yield Clock(reactor).sleep(0) yield Clock(reactor).sleep(0)
self.assertEqual( self.assertEqual(logcontext.LoggingContext.current_context(), lc)
logcontext.LoggingContext.current_context(), lc)
func(0, sleep=True) func(0, sleep=True)
for i in range(1, 100): for i in range(1, 100):

View File

@ -8,11 +8,8 @@ from .. import unittest
class LoggingContextTestCase(unittest.TestCase): class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value): def _check_test_key(self, value):
self.assertEquals( self.assertEquals(LoggingContext.current_context().request, value)
LoggingContext.current_context().request, value
)
def test_with_context(self): def test_with_context(self):
with LoggingContext() as context_one: with LoggingContext() as context_one:
@ -50,6 +47,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one") self._check_test_key("one")
callback_completed[0] = True callback_completed[0] = True
return res return res
d.addCallback(cb) d.addCallback(cb)
return d return d
@ -74,8 +72,7 @@ class LoggingContextTestCase(unittest.TestCase):
# make sure that the context was reset before it got thrown back # make sure that the context was reset before it got thrown back
# into the reactor # into the reactor
try: try:
self.assertIs(LoggingContext.current_context(), self.assertIs(LoggingContext.current_context(), sentinel_context)
sentinel_context)
d2.callback(None) d2.callback(None)
except BaseException: except BaseException:
d2.errback(twisted.python.failure.Failure()) d2.errback(twisted.python.failure.Failure())
@ -104,9 +101,7 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which looks like it has been # a function which returns a deferred which looks like it has been
# called, but is actually paused # called, but is actually paused
def testfunc(): def testfunc():
return logcontext.make_deferred_yieldable( return logcontext.make_deferred_yieldable(_chained_deferred_function())
_chained_deferred_function()
)
return self._test_run_in_background(testfunc) return self._test_run_in_background(testfunc)
@ -175,5 +170,6 @@ def _chained_deferred_function():
d2 = defer.Deferred() d2 = defer.Deferred()
reactor.callLater(0, d2.callback, res) reactor.callLater(0, d2.callback, res)
return d2 return d2
d.addCallback(cb) d.addCallback(cb)
return d return d

View File

@ -23,7 +23,6 @@ from .. import unittest
class LruCacheTestCase(unittest.TestCase): class LruCacheTestCase(unittest.TestCase):
def test_get_set(self): def test_get_set(self):
cache = LruCache(1) cache = LruCache(1)
cache["key"] = "value" cache["key"] = "value"
@ -235,7 +234,6 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
class LruCacheSizedTestCase(unittest.TestCase): class LruCacheSizedTestCase(unittest.TestCase):
def test_evict(self): def test_evict(self):
cache = LruCache(5, size_callback=len) cache = LruCache(5, size_callback=len)
cache["key1"] = [0] cache["key1"] = [0]

View File

@ -20,7 +20,6 @@ from tests import unittest
class ReadWriteLockTestCase(unittest.TestCase): class ReadWriteLockTestCase(unittest.TestCase):
def _assert_called_before_not_after(self, lst, first_false): def _assert_called_before_not_after(self, lst, first_false):
for i, d in enumerate(lst[:first_false]): for i, d in enumerate(lst[:first_false]):
self.assertTrue(d.called, msg="%d was unexpectedly false" % i) self.assertTrue(d.called, msg="%d was unexpectedly false" % i)

View File

@ -22,7 +22,6 @@ from .. import unittest
class SnapshotCacheTestCase(unittest.TestCase): class SnapshotCacheTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.cache = SnapshotCache() self.cache = SnapshotCache()
self.cache.DURATION_MS = 1 self.cache.DURATION_MS = 1

View File

@ -181,17 +181,8 @@ class StreamChangeCacheTests(unittest.TestCase):
# Query a subset of the entries mid-way through the stream. We should # Query a subset of the entries mid-way through the stream. We should
# only get back the subset. # only get back the subset.
self.assertEqual( self.assertEqual(
cache.get_entities_changed( cache.get_entities_changed(["bar@baz.net"], stream_pos=2),
[ set(["bar@baz.net"]),
"bar@baz.net",
],
stream_pos=2,
),
set(
[
"bar@baz.net",
]
),
) )
def test_max_pos(self): def test_max_pos(self):

View File

@ -38,8 +38,9 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None, def setup_test_homeserver(
**kargs): name="test", datastore=None, config=None, reactor=None, **kargs
):
"""Setup a homeserver suitable for running tests against. Keyword arguments """Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS. datastore backed by an in-memory sqlite db will be given to the HS.
@ -96,20 +97,12 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
if USE_POSTGRES_FOR_TESTS: if USE_POSTGRES_FOR_TESTS:
config.database_config = { config.database_config = {
"name": "psycopg2", "name": "psycopg2",
"args": { "args": {"database": "synapse_test", "cp_min": 1, "cp_max": 5},
"database": "synapse_test",
"cp_min": 1,
"cp_max": 5,
},
} }
else: else:
config.database_config = { config.database_config = {
"name": "sqlite3", "name": "sqlite3",
"args": { "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
"database": ":memory:",
"cp_min": 1,
"cp_max": 1,
},
} }
db_engine = create_engine(config.database_config) db_engine = create_engine(config.database_config)
@ -121,7 +114,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
if datastore is None: if datastore is None:
hs = HomeServer( hs = HomeServer(
name, config=config, name,
config=config,
db_config=config.database_config, db_config=config.database_config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=db_engine, database_engine=db_engine,
@ -143,7 +137,10 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
hs.setup() hs.setup()
else: else:
hs = HomeServer( hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config, name,
db_pool=None,
datastore=datastore,
config=config,
version_string="Synapse/tests", version_string="Synapse/tests",
database_engine=db_engine, database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
@ -158,8 +155,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
# because AuthHandler's constructor requires the HS, so we can't make one # because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg) # beforehand and pass it in to the HS's constructor (chicken / egg)
hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest() hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5( hs.get_auth_handler().validate_hash = (
p.encode('utf8')).hexdigest() == h lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
)
fed = kargs.get("resource_for_federation", None) fed = kargs.get("resource_for_federation", None)
if fed: if fed:
@ -173,7 +171,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
sleep_limit=hs.config.federation_rc_sleep_limit, sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay, sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit, reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent concurrent_requests=hs.config.federation_rc_concurrent,
), ),
) )
@ -199,7 +197,6 @@ def mock_getRawHeaders(headers=None):
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource(HttpServer):
def __init__(self, prefix=""): def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix self.prefix = prefix
@ -263,15 +260,9 @@ class MockHttpResource(HttpServer):
matcher = pattern.match(path) matcher = pattern.match(path)
if matcher: if matcher:
try: try:
args = [ args = [urlparse.unquote(u) for u in matcher.groups()]
urlparse.unquote(u)
for u in matcher.groups()
]
(code, response) = yield func( (code, response) = yield func(mock_request, *args)
mock_request,
*args
)
defer.returnValue((code, response)) defer.returnValue((code, response))
except CodeMessageException as e: except CodeMessageException as e:
defer.returnValue((e.code, cs_error(e.msg, code=e.errcode))) defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
@ -372,8 +363,7 @@ class MockClock(object):
def _format_call(args, kwargs): def _format_call(args, kwargs):
return ", ".join( return ", ".join(
["%r" % (a) for a in args] + ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
["%s=%r" % (k, v) for k, v in kwargs.items()]
) )
@ -391,8 +381,9 @@ class DeferredMockCallable(object):
self.calls.append((args, kwargs)) self.calls.append((args, kwargs))
if not self.expectations: if not self.expectations:
raise ValueError("%r has no pending calls to handle call(%s)" % ( raise ValueError(
self, _format_call(args, kwargs)) "%r has no pending calls to handle call(%s)"
% (self, _format_call(args, kwargs))
) )
for (call, result, d) in self.expectations: for (call, result, d) in self.expectations:
@ -400,9 +391,9 @@ class DeferredMockCallable(object):
d.callback(None) d.callback(None)
return result return result
failure = AssertionError("Was not expecting call(%s)" % ( failure = AssertionError(
_format_call(args, kwargs) "Was not expecting call(%s)" % (_format_call(args, kwargs))
)) )
for _, _, d in self.expectations: for _, _, d in self.expectations:
try: try:
@ -418,17 +409,19 @@ class DeferredMockCallable(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def await_calls(self, timeout=1000): def await_calls(self, timeout=1000):
deferred = defer.DeferredList( deferred = defer.DeferredList(
[d for _, _, d in self.expectations], [d for _, _, d in self.expectations], fireOnOneErrback=True
fireOnOneErrback=True
) )
timer = reactor.callLater( timer = reactor.callLater(
timeout / 1000, timeout / 1000,
deferred.errback, deferred.errback,
AssertionError("%d pending calls left: %s" % ( AssertionError(
"%d pending calls left: %s"
% (
len([e for e in self.expectations if not e[2].called]), len([e for e in self.expectations if not e[2].called]),
[e for e in self.expectations if not e[2].called] [e for e in self.expectations if not e[2].called],
)) )
),
) )
yield deferred yield deferred
@ -443,7 +436,6 @@ class DeferredMockCallable(object):
self.calls = [] self.calls = []
raise AssertionError( raise AssertionError(
"Expected not to received any calls, got:\n" + "\n".join([ "Expected not to received any calls, got:\n"
"call(%s)" % _format_call(c[0], c[1]) for c in calls + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
])
) )