mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Port storage/ to Python 3 (#3725)
This commit is contained in:
parent
475253a88e
commit
14e4d4f4bf
1
changelog.d/3725.misc
Normal file
1
changelog.d/3725.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
The synapse.storage module has been ported to Python 3.
|
@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
|
|||||||
$TOX_BIN/pip install 'pip>=10'
|
$TOX_BIN/pip install 'pip>=10'
|
||||||
|
|
||||||
{ python synapse/python_dependencies.py
|
{ python synapse/python_dependencies.py
|
||||||
echo lxml psycopg2
|
echo lxml
|
||||||
} | xargs $TOX_BIN/pip install
|
} | xargs $TOX_BIN/pip install
|
||||||
|
@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
|
|||||||
"affinity": {
|
"affinity": {
|
||||||
"affinity": ["affinity"],
|
"affinity": ["affinity"],
|
||||||
},
|
},
|
||||||
|
"postgres": {
|
||||||
|
"psycopg2>=2.6": ["psycopg2"]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,9 +17,10 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from six import iteritems, iterkeys, itervalues
|
from six import PY2, iteritems, iterkeys, itervalues
|
||||||
from six.moves import intern, range
|
from six.moves import intern, range
|
||||||
|
|
||||||
|
from canonicaljson import json
|
||||||
from prometheus_client import Histogram
|
from prometheus_client import Histogram
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
|
|||||||
something went wrong.
|
something went wrong.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def db_to_json(db_content):
|
||||||
|
"""
|
||||||
|
Take some data from a database row and return a JSON-decoded object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_content (memoryview|buffer|bytes|bytearray|unicode)
|
||||||
|
"""
|
||||||
|
# psycopg2 on Python 3 returns memoryview objects, which we need to
|
||||||
|
# cast to bytes to decode
|
||||||
|
if isinstance(db_content, memoryview):
|
||||||
|
db_content = db_content.tobytes()
|
||||||
|
|
||||||
|
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
|
||||||
|
# bytes to decode
|
||||||
|
if PY2 and isinstance(db_content, buffer):
|
||||||
|
db_content = bytes(db_content)
|
||||||
|
|
||||||
|
# Decode it to a Unicode string before feeding it to json.loads, so we
|
||||||
|
# consistenty get a Unicode-containing object out.
|
||||||
|
if isinstance(db_content, (bytes, bytearray)):
|
||||||
|
db_content = db_content.decode('utf8')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(db_content)
|
||||||
|
except Exception:
|
||||||
|
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
|
||||||
|
raise
|
||||||
|
@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
|||||||
local_by_user_then_device = {}
|
local_by_user_then_device = {}
|
||||||
for user_id, messages_by_device in messages_by_user_then_device.items():
|
for user_id, messages_by_device in messages_by_user_then_device.items():
|
||||||
messages_json_for_user = {}
|
messages_json_for_user = {}
|
||||||
devices = messages_by_device.keys()
|
devices = list(messages_by_device.keys())
|
||||||
if len(devices) == 1 and devices[0] == "*":
|
if len(devices) == 1 and devices[0] == "*":
|
||||||
# Handle wildcard device_ids.
|
# Handle wildcard device_ids.
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
|
|||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||||
|
|
||||||
from ._base import Cache, SQLBaseStore
|
from ._base import Cache, SQLBaseStore, db_to_json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
if device is not None:
|
if device is not None:
|
||||||
key_json = device.get("key_json", None)
|
key_json = device.get("key_json", None)
|
||||||
if key_json:
|
if key_json:
|
||||||
result["keys"] = json.loads(key_json)
|
result["keys"] = db_to_json(key_json)
|
||||||
device_display_name = device.get("device_display_name", None)
|
device_display_name = device.get("device_display_name", None)
|
||||||
if device_display_name:
|
if device_display_name:
|
||||||
result["device_display_name"] = device_display_name
|
result["device_display_name"] = device_display_name
|
||||||
@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
retcol="content",
|
retcol="content",
|
||||||
desc="_get_cached_user_device",
|
desc="_get_cached_user_device",
|
||||||
)
|
)
|
||||||
defer.returnValue(json.loads(content))
|
defer.returnValue(db_to_json(content))
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def _get_cached_devices_for_user(self, user_id):
|
def _get_cached_devices_for_user(self, user_id):
|
||||||
@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
desc="_get_cached_devices_for_user",
|
desc="_get_cached_devices_for_user",
|
||||||
)
|
)
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
device["device_id"]: json.loads(device["content"])
|
device["device_id"]: db_to_json(device["content"])
|
||||||
for device in devices
|
for device in devices
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
key_json = device.get("key_json", None)
|
key_json = device.get("key_json", None)
|
||||||
if key_json:
|
if key_json:
|
||||||
result["keys"] = json.loads(key_json)
|
result["keys"] = db_to_json(key_json)
|
||||||
device_display_name = device.get("device_display_name", None)
|
device_display_name = device.get("device_display_name", None)
|
||||||
if device_display_name:
|
if device_display_name:
|
||||||
result["device_display_name"] = device_display_name
|
result["device_display_name"] = device_display_name
|
||||||
|
@ -14,13 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, db_to_json
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(SQLBaseStore):
|
class EndToEndKeyStore(SQLBaseStore):
|
||||||
@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
for user_id, device_keys in iteritems(results):
|
for user_id, device_keys in iteritems(results):
|
||||||
for device_id, device_info in iteritems(device_keys):
|
for device_id, device_info in iteritems(device_keys):
|
||||||
device_info["keys"] = json.loads(device_info.pop("key_json"))
|
device_info["keys"] = db_to_json(device_info.pop("key_json"))
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
@ -41,13 +41,18 @@ class PostgresEngine(object):
|
|||||||
db_conn.set_isolation_level(
|
db_conn.set_isolation_level(
|
||||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set the bytea output to escape, vs the default of hex
|
||||||
|
cursor = db_conn.cursor()
|
||||||
|
cursor.execute("SET bytea_output TO escape")
|
||||||
|
|
||||||
# Asynchronous commit, don't wait for the server to call fsync before
|
# Asynchronous commit, don't wait for the server to call fsync before
|
||||||
# ending the transaction.
|
# ending the transaction.
|
||||||
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
|
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
|
||||||
if not self.synchronous_commit:
|
if not self.synchronous_commit:
|
||||||
cursor = db_conn.cursor()
|
|
||||||
cursor.execute("SET synchronous_commit TO OFF")
|
cursor.execute("SET synchronous_commit TO OFF")
|
||||||
cursor.close()
|
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
def is_deadlock(self, error):
|
def is_deadlock(self, error):
|
||||||
if isinstance(error, self.module.DatabaseError):
|
if isinstance(error, self.module.DatabaseError):
|
||||||
|
@ -19,7 +19,7 @@ import logging
|
|||||||
from collections import OrderedDict, deque, namedtuple
|
from collections import OrderedDict, deque, namedtuple
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from six import iteritems
|
from six import iteritems, text_type
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||||||
"sender": event.sender,
|
"sender": event.sender,
|
||||||
"contains_url": (
|
"contains_url": (
|
||||||
"url" in event.content
|
"url" in event.content
|
||||||
and isinstance(event.content["url"], basestring)
|
and isinstance(event.content["url"], text_type)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for event, _ in events_and_contexts
|
for event, _ in events_and_contexts
|
||||||
@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||||||
|
|
||||||
contains_url = "url" in content
|
contains_url = "url" in content
|
||||||
if contains_url:
|
if contains_url:
|
||||||
contains_url &= isinstance(content["url"], basestring)
|
contains_url &= isinstance(content["url"], text_type)
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
# If the event is missing a necessary field then
|
# If the event is missing a necessary field then
|
||||||
# skip over it.
|
# skip over it.
|
||||||
@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
|||||||
(room_id,)
|
(room_id,)
|
||||||
)
|
)
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
max_depth = max(row[0] for row in rows)
|
max_depth = max(row[1] for row in rows)
|
||||||
|
|
||||||
if max_depth <= token.topological:
|
if max_depth < token.topological:
|
||||||
# We need to ensure we don't delete all the events from the database
|
# We need to ensure we don't delete all the events from the database
|
||||||
# otherwise we wouldn't be able to send any events (due to not
|
# otherwise we wouldn't be able to send any events (due to not
|
||||||
# having any backwards extremeties)
|
# having any backwards extremeties)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
with Measure(self._clock, "_fetch_event_list"):
|
with Measure(self._clock, "_fetch_event_list"):
|
||||||
try:
|
try:
|
||||||
event_id_lists = zip(*event_list)[0]
|
event_id_lists = list(zip(*event_list))[0]
|
||||||
event_ids = [
|
event_ids = [
|
||||||
item for sublist in event_id_lists for item in sublist
|
item for sublist in event_id_lists for item in sublist
|
||||||
]
|
]
|
||||||
@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
logger.exception("do_fetch")
|
logger.exception("do_fetch")
|
||||||
|
|
||||||
# We only want to resolve deferreds from the main thread
|
# We only want to resolve deferreds from the main thread
|
||||||
def fire(evs):
|
def fire(evs, exc):
|
||||||
for _, d in evs:
|
for _, d in evs:
|
||||||
if not d.called:
|
if not d.called:
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
d.errback(e)
|
d.errback(exc)
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
self.hs.get_reactor().callFromThread(fire, event_list)
|
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||||
|
@ -13,14 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, db_to_json
|
||||||
|
|
||||||
|
|
||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
|
|||||||
desc="get_user_filter",
|
desc="get_user_filter",
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
|
defer.returnValue(db_to_json(def_json))
|
||||||
|
|
||||||
def add_user_filter(self, user_localpart, user_filter):
|
def add_user_filter(self, user_localpart, user_filter):
|
||||||
def_json = encode_canonical_json(user_filter)
|
def_json = encode_canonical_json(user_filter)
|
||||||
|
@ -15,7 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import types
|
|
||||||
|
import six
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json, json
|
||||||
|
|
||||||
@ -27,6 +28,11 @@ from ._base import SQLBaseStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if six.PY2:
|
||||||
|
db_binary_type = buffer
|
||||||
|
else:
|
||||||
|
db_binary_type = memoryview
|
||||||
|
|
||||||
|
|
||||||
class PusherWorkerStore(SQLBaseStore):
|
class PusherWorkerStore(SQLBaseStore):
|
||||||
def _decode_pushers_rows(self, rows):
|
def _decode_pushers_rows(self, rows):
|
||||||
@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
dataJson = r['data']
|
dataJson = r['data']
|
||||||
r['data'] = None
|
r['data'] = None
|
||||||
try:
|
try:
|
||||||
if isinstance(dataJson, types.BufferType):
|
if isinstance(dataJson, db_binary_type):
|
||||||
dataJson = str(dataJson).decode("UTF8")
|
dataJson = str(dataJson).decode("UTF8")
|
||||||
|
|
||||||
r['data'] = json.loads(dataJson)
|
r['data'] = json.loads(dataJson)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Invalid JSON in data for pusher %d: %s, %s",
|
"Invalid JSON in data for pusher %d: %s, %s",
|
||||||
r['id'], dataJson, e.message,
|
r['id'], dataJson, e.args[0],
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if isinstance(r['pushkey'], types.BufferType):
|
if isinstance(r['pushkey'], db_binary_type):
|
||||||
r['pushkey'] = str(r['pushkey']).decode("UTF8")
|
r['pushkey'] = str(r['pushkey']).decode("UTF8")
|
||||||
|
|
||||||
return rows
|
return rows
|
||||||
|
@ -18,14 +18,14 @@ from collections import namedtuple
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore, db_to_json
|
||||||
|
|
||||||
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
|
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
|
||||||
# despite being deprecated and removed in favor of memoryview
|
# despite being deprecated and removed in favor of memoryview
|
||||||
@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if result and result["response_code"]:
|
if result and result["response_code"]:
|
||||||
return result["response_code"], json.loads(str(result["response_json"]))
|
return result["response_code"], db_to_json(result["response_json"])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -240,7 +240,6 @@ class RestHelper(object):
|
|||||||
self.assertEquals(200, code)
|
self.assertEquals(200, code)
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||||
if txn_id is None:
|
if txn_id is None:
|
||||||
txn_id = "m%s" % (str(time.time()))
|
txn_id = "m%s" % (str(time.time()))
|
||||||
@ -248,9 +247,16 @@ class RestHelper(object):
|
|||||||
body = "body_text_here"
|
body = "body_text_here"
|
||||||
|
|
||||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||||
content = '{"msgtype":"m.text","body":"%s"}' % body
|
content = {"msgtype": "m.text", "body": body}
|
||||||
if tok:
|
if tok:
|
||||||
path = path + "?access_token=%s" % tok
|
path = path + "?access_token=%s" % tok
|
||||||
|
|
||||||
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
|
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
|
||||||
self.assertEquals(expect_code, code, msg=str(response))
|
render(request, self.resource, self.hs.get_reactor())
|
||||||
|
|
||||||
|
assert int(channel.result["code"]) == expect_code, (
|
||||||
|
"Expected: %d, got: %d, resp: %r"
|
||||||
|
% (expect_code, int(channel.result["code"]), channel.result["body"])
|
||||||
|
)
|
||||||
|
|
||||||
|
return channel.json_body
|
||||||
|
106
tests/storage/test_purge.py
Normal file
106
tests/storage/test_purge.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.rest.client.v1 import room
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class PurgeTests(HomeserverTestCase):
|
||||||
|
|
||||||
|
user_id = "@red:server"
|
||||||
|
servlets = [room.register_servlets]
|
||||||
|
|
||||||
|
def make_homeserver(self, reactor, clock):
|
||||||
|
hs = self.setup_test_homeserver("server", http_client=None)
|
||||||
|
return hs
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.room_id = self.helper.create_room_as(self.user_id)
|
||||||
|
|
||||||
|
def test_purge(self):
|
||||||
|
"""
|
||||||
|
Purging a room will delete everything before the topological point.
|
||||||
|
"""
|
||||||
|
# Send four messages to the room
|
||||||
|
first = self.helper.send(self.room_id, body="test1")
|
||||||
|
second = self.helper.send(self.room_id, body="test2")
|
||||||
|
third = self.helper.send(self.room_id, body="test3")
|
||||||
|
last = self.helper.send(self.room_id, body="test4")
|
||||||
|
|
||||||
|
storage = self.hs.get_datastore()
|
||||||
|
|
||||||
|
# Get the topological token
|
||||||
|
event = storage.get_topological_token_for_event(last["event_id"])
|
||||||
|
self.pump()
|
||||||
|
event = self.successResultOf(event)
|
||||||
|
|
||||||
|
# Purge everything before this topological token
|
||||||
|
purge = storage.purge_history(self.room_id, event, True)
|
||||||
|
self.pump()
|
||||||
|
self.assertEqual(self.successResultOf(purge), None)
|
||||||
|
|
||||||
|
# Try and get the events
|
||||||
|
get_first = storage.get_event(first["event_id"])
|
||||||
|
get_second = storage.get_event(second["event_id"])
|
||||||
|
get_third = storage.get_event(third["event_id"])
|
||||||
|
get_last = storage.get_event(last["event_id"])
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||||
|
# and last is not.
|
||||||
|
self.failureResultOf(get_first)
|
||||||
|
self.failureResultOf(get_second)
|
||||||
|
self.failureResultOf(get_third)
|
||||||
|
self.successResultOf(get_last)
|
||||||
|
|
||||||
|
def test_purge_wont_delete_extrems(self):
|
||||||
|
"""
|
||||||
|
Purging a room will delete everything before the topological point.
|
||||||
|
"""
|
||||||
|
# Send four messages to the room
|
||||||
|
first = self.helper.send(self.room_id, body="test1")
|
||||||
|
second = self.helper.send(self.room_id, body="test2")
|
||||||
|
third = self.helper.send(self.room_id, body="test3")
|
||||||
|
last = self.helper.send(self.room_id, body="test4")
|
||||||
|
|
||||||
|
storage = self.hs.get_datastore()
|
||||||
|
|
||||||
|
# Set the topological token higher than it should be
|
||||||
|
event = storage.get_topological_token_for_event(last["event_id"])
|
||||||
|
self.pump()
|
||||||
|
event = self.successResultOf(event)
|
||||||
|
event = "t{}-{}".format(
|
||||||
|
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Purge everything before this topological token
|
||||||
|
purge = storage.purge_history(self.room_id, event, True)
|
||||||
|
self.pump()
|
||||||
|
f = self.failureResultOf(purge)
|
||||||
|
self.assertIn("greater than forward", f.value.args[0])
|
||||||
|
|
||||||
|
# Try and get the events
|
||||||
|
get_first = storage.get_event(first["event_id"])
|
||||||
|
get_second = storage.get_event(second["event_id"])
|
||||||
|
get_third = storage.get_event(third["event_id"])
|
||||||
|
get_last = storage.get_event(last["event_id"])
|
||||||
|
self.pump()
|
||||||
|
|
||||||
|
# Nothing is deleted.
|
||||||
|
self.successResultOf(get_first)
|
||||||
|
self.successResultOf(get_second)
|
||||||
|
self.successResultOf(get_third)
|
||||||
|
self.successResultOf(get_last)
|
@ -151,6 +151,7 @@ class HomeserverTestCase(TestCase):
|
|||||||
hijack_auth (bool): Whether to hijack auth to return the user specified
|
hijack_auth (bool): Whether to hijack auth to return the user specified
|
||||||
in user_id.
|
in user_id.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
servlets = []
|
servlets = []
|
||||||
hijack_auth = True
|
hijack_auth = True
|
||||||
|
|
||||||
@ -279,3 +280,13 @@ class HomeserverTestCase(TestCase):
|
|||||||
kwargs = dict(kwargs)
|
kwargs = dict(kwargs)
|
||||||
kwargs.update(self._hs_args)
|
kwargs.update(self._hs_args)
|
||||||
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||||
|
|
||||||
|
def pump(self):
|
||||||
|
"""
|
||||||
|
Pump the reactor enough that Deferreds will fire.
|
||||||
|
"""
|
||||||
|
self.reactor.pump([0.0] * 100)
|
||||||
|
|
||||||
|
def get_success(self, d):
|
||||||
|
self.pump()
|
||||||
|
return self.successResultOf(d)
|
||||||
|
@ -147,6 +147,8 @@ def setup_test_homeserver(
|
|||||||
config.max_mau_value = 50
|
config.max_mau_value = 50
|
||||||
config.mau_limits_reserved_threepids = []
|
config.mau_limits_reserved_threepids = []
|
||||||
config.admin_contact = None
|
config.admin_contact = None
|
||||||
|
config.rc_messages_per_second = 10000
|
||||||
|
config.rc_message_burst_count = 10000
|
||||||
|
|
||||||
# we need a sane default_room_version, otherwise attempts to create rooms will
|
# we need a sane default_room_version, otherwise attempts to create rooms will
|
||||||
# fail.
|
# fail.
|
||||||
|
Loading…
Reference in New Issue
Block a user