mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Port storage/ to Python 3 (#3725)
This commit is contained in:
parent
475253a88e
commit
14e4d4f4bf
1
changelog.d/3725.misc
Normal file
1
changelog.d/3725.misc
Normal file
@ -0,0 +1 @@
|
||||
The synapse.storage module has been ported to Python 3.
|
@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
|
||||
$TOX_BIN/pip install 'pip>=10'
|
||||
|
||||
{ python synapse/python_dependencies.py
|
||||
echo lxml psycopg2
|
||||
echo lxml
|
||||
} | xargs $TOX_BIN/pip install
|
||||
|
@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"affinity": {
|
||||
"affinity": ["affinity"],
|
||||
},
|
||||
"postgres": {
|
||||
"psycopg2>=2.6": ["psycopg2"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -17,9 +17,10 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
from six import iteritems, iterkeys, itervalues
|
||||
from six import PY2, iteritems, iterkeys, itervalues
|
||||
from six.moves import intern, range
|
||||
|
||||
from canonicaljson import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from twisted.internet import defer
|
||||
@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
|
||||
something went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def db_to_json(db_content):
|
||||
"""
|
||||
Take some data from a database row and return a JSON-decoded object.
|
||||
|
||||
Args:
|
||||
db_content (memoryview|buffer|bytes|bytearray|unicode)
|
||||
"""
|
||||
# psycopg2 on Python 3 returns memoryview objects, which we need to
|
||||
# cast to bytes to decode
|
||||
if isinstance(db_content, memoryview):
|
||||
db_content = db_content.tobytes()
|
||||
|
||||
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
|
||||
# bytes to decode
|
||||
if PY2 and isinstance(db_content, buffer):
|
||||
db_content = bytes(db_content)
|
||||
|
||||
# Decode it to a Unicode string before feeding it to json.loads, so we
|
||||
# consistenty get a Unicode-containing object out.
|
||||
if isinstance(db_content, (bytes, bytearray)):
|
||||
db_content = db_content.decode('utf8')
|
||||
|
||||
try:
|
||||
return json.loads(db_content)
|
||||
except Exception:
|
||||
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
|
||||
raise
|
||||
|
@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||
local_by_user_then_device = {}
|
||||
for user_id, messages_by_device in messages_by_user_then_device.items():
|
||||
messages_json_for_user = {}
|
||||
devices = messages_by_device.keys()
|
||||
devices = list(messages_by_device.keys())
|
||||
if len(devices) == 1 and devices[0] == "*":
|
||||
# Handle wildcard device_ids.
|
||||
sql = (
|
||||
|
@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
|
||||
from ._base import Cache, SQLBaseStore
|
||||
from ._base import Cache, SQLBaseStore, db_to_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
|
||||
if device is not None:
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = json.loads(key_json)
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
|
||||
retcol="content",
|
||||
desc="_get_cached_user_device",
|
||||
)
|
||||
defer.returnValue(json.loads(content))
|
||||
defer.returnValue(db_to_json(content))
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def _get_cached_devices_for_user(self, user_id):
|
||||
@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
|
||||
desc="_get_cached_devices_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
device["device_id"]: json.loads(device["content"])
|
||||
device["device_id"]: db_to_json(device["content"])
|
||||
for device in devices
|
||||
})
|
||||
|
||||
@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = json.loads(key_json)
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
@ -14,13 +14,13 @@
|
||||
# limitations under the License.
|
||||
from six import iteritems
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, db_to_json
|
||||
|
||||
|
||||
class EndToEndKeyStore(SQLBaseStore):
|
||||
@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||
|
||||
for user_id, device_keys in iteritems(results):
|
||||
for device_id, device_info in iteritems(device_keys):
|
||||
device_info["keys"] = json.loads(device_info.pop("key_json"))
|
||||
device_info["keys"] = db_to_json(device_info.pop("key_json"))
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
|
@ -41,13 +41,18 @@ class PostgresEngine(object):
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
|
||||
# Set the bytea output to escape, vs the default of hex
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute("SET bytea_output TO escape")
|
||||
|
||||
# Asynchronous commit, don't wait for the server to call fsync before
|
||||
# ending the transaction.
|
||||
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
|
||||
if not self.synchronous_commit:
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute("SET synchronous_commit TO OFF")
|
||||
cursor.close()
|
||||
|
||||
cursor.close()
|
||||
|
||||
def is_deadlock(self, error):
|
||||
if isinstance(error, self.module.DatabaseError):
|
||||
|
@ -19,7 +19,7 @@ import logging
|
||||
from collections import OrderedDict, deque, namedtuple
|
||||
from functools import wraps
|
||||
|
||||
from six import iteritems
|
||||
from six import iteritems, text_type
|
||||
from six.moves import range
|
||||
|
||||
from canonicaljson import json
|
||||
@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||
"sender": event.sender,
|
||||
"contains_url": (
|
||||
"url" in event.content
|
||||
and isinstance(event.content["url"], basestring)
|
||||
and isinstance(event.content["url"], text_type)
|
||||
),
|
||||
}
|
||||
for event, _ in events_and_contexts
|
||||
@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||
|
||||
contains_url = "url" in content
|
||||
if contains_url:
|
||||
contains_url &= isinstance(content["url"], basestring)
|
||||
contains_url &= isinstance(content["url"], text_type)
|
||||
except (KeyError, AttributeError):
|
||||
# If the event is missing a necessary field then
|
||||
# skip over it.
|
||||
@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
|
||||
(room_id,)
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
max_depth = max(row[0] for row in rows)
|
||||
max_depth = max(row[1] for row in rows)
|
||||
|
||||
if max_depth <= token.topological:
|
||||
if max_depth < token.topological:
|
||||
# We need to ensure we don't delete all the events from the database
|
||||
# otherwise we wouldn't be able to send any events (due to not
|
||||
# having any backwards extremeties)
|
||||
|
@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
with Measure(self._clock, "_fetch_event_list"):
|
||||
try:
|
||||
event_id_lists = zip(*event_list)[0]
|
||||
event_id_lists = list(zip(*event_list))[0]
|
||||
event_ids = [
|
||||
item for sublist in event_id_lists for item in sublist
|
||||
]
|
||||
@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
logger.exception("do_fetch")
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire(evs):
|
||||
def fire(evs, exc):
|
||||
for _, d in evs:
|
||||
if not d.called:
|
||||
with PreserveLoggingContext():
|
||||
d.errback(e)
|
||||
d.errback(exc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_reactor().callFromThread(fire, event_list)
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
|
||||
|
@ -13,14 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, db_to_json
|
||||
|
||||
|
||||
class FilteringStore(SQLBaseStore):
|
||||
@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
|
||||
desc="get_user_filter",
|
||||
)
|
||||
|
||||
defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
|
||||
defer.returnValue(db_to_json(def_json))
|
||||
|
||||
def add_user_filter(self, user_localpart, user_filter):
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
@ -15,7 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import types
|
||||
|
||||
import six
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
|
||||
@ -27,6 +28,11 @@ from ._base import SQLBaseStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if six.PY2:
|
||||
db_binary_type = buffer
|
||||
else:
|
||||
db_binary_type = memoryview
|
||||
|
||||
|
||||
class PusherWorkerStore(SQLBaseStore):
|
||||
def _decode_pushers_rows(self, rows):
|
||||
@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
dataJson = r['data']
|
||||
r['data'] = None
|
||||
try:
|
||||
if isinstance(dataJson, types.BufferType):
|
||||
if isinstance(dataJson, db_binary_type):
|
||||
dataJson = str(dataJson).decode("UTF8")
|
||||
|
||||
r['data'] = json.loads(dataJson)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Invalid JSON in data for pusher %d: %s, %s",
|
||||
r['id'], dataJson, e.message,
|
||||
r['id'], dataJson, e.args[0],
|
||||
)
|
||||
pass
|
||||
|
||||
if isinstance(r['pushkey'], types.BufferType):
|
||||
if isinstance(r['pushkey'], db_binary_type):
|
||||
r['pushkey'] = str(r['pushkey']).decode("UTF8")
|
||||
|
||||
return rows
|
||||
|
@ -18,14 +18,14 @@ from collections import namedtuple
|
||||
|
||||
import six
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, db_to_json
|
||||
|
||||
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
|
||||
# despite being deprecated and removed in favor of memoryview
|
||||
@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
if result and result["response_code"]:
|
||||
return result["response_code"], json.loads(str(result["response_json"]))
|
||||
return result["response_code"], db_to_json(result["response_json"])
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -240,7 +240,6 @@ class RestHelper(object):
|
||||
self.assertEquals(200, code)
|
||||
defer.returnValue(response)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
@ -248,9 +247,16 @@ class RestHelper(object):
|
||||
body = "body_text_here"
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||
content = '{"msgtype":"m.text","body":"%s"}' % body
|
||||
content = {"msgtype": "m.text", "body": body}
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
|
||||
self.assertEquals(expect_code, code, msg=str(response))
|
||||
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
assert int(channel.result["code"]) == expect_code, (
|
||||
"Expected: %d, got: %d, resp: %r"
|
||||
% (expect_code, int(channel.result["code"]), channel.result["body"])
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
106
tests/storage/test_purge.py
Normal file
106
tests/storage/test_purge.py
Normal file
@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.rest.client.v1 import room
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class PurgeTests(HomeserverTestCase):
|
||||
|
||||
user_id = "@red:server"
|
||||
servlets = [room.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver("server", http_client=None)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_purge(self):
|
||||
"""
|
||||
Purging a room will delete everything before the topological point.
|
||||
"""
|
||||
# Send four messages to the room
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
second = self.helper.send(self.room_id, body="test2")
|
||||
third = self.helper.send(self.room_id, body="test3")
|
||||
last = self.helper.send(self.room_id, body="test4")
|
||||
|
||||
storage = self.hs.get_datastore()
|
||||
|
||||
# Get the topological token
|
||||
event = storage.get_topological_token_for_event(last["event_id"])
|
||||
self.pump()
|
||||
event = self.successResultOf(event)
|
||||
|
||||
# Purge everything before this topological token
|
||||
purge = storage.purge_history(self.room_id, event, True)
|
||||
self.pump()
|
||||
self.assertEqual(self.successResultOf(purge), None)
|
||||
|
||||
# Try and get the events
|
||||
get_first = storage.get_event(first["event_id"])
|
||||
get_second = storage.get_event(second["event_id"])
|
||||
get_third = storage.get_event(third["event_id"])
|
||||
get_last = storage.get_event(last["event_id"])
|
||||
self.pump()
|
||||
|
||||
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
|
||||
# and last is not.
|
||||
self.failureResultOf(get_first)
|
||||
self.failureResultOf(get_second)
|
||||
self.failureResultOf(get_third)
|
||||
self.successResultOf(get_last)
|
||||
|
||||
def test_purge_wont_delete_extrems(self):
|
||||
"""
|
||||
Purging a room will delete everything before the topological point.
|
||||
"""
|
||||
# Send four messages to the room
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
second = self.helper.send(self.room_id, body="test2")
|
||||
third = self.helper.send(self.room_id, body="test3")
|
||||
last = self.helper.send(self.room_id, body="test4")
|
||||
|
||||
storage = self.hs.get_datastore()
|
||||
|
||||
# Set the topological token higher than it should be
|
||||
event = storage.get_topological_token_for_event(last["event_id"])
|
||||
self.pump()
|
||||
event = self.successResultOf(event)
|
||||
event = "t{}-{}".format(
|
||||
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
|
||||
)
|
||||
|
||||
# Purge everything before this topological token
|
||||
purge = storage.purge_history(self.room_id, event, True)
|
||||
self.pump()
|
||||
f = self.failureResultOf(purge)
|
||||
self.assertIn("greater than forward", f.value.args[0])
|
||||
|
||||
# Try and get the events
|
||||
get_first = storage.get_event(first["event_id"])
|
||||
get_second = storage.get_event(second["event_id"])
|
||||
get_third = storage.get_event(third["event_id"])
|
||||
get_last = storage.get_event(last["event_id"])
|
||||
self.pump()
|
||||
|
||||
# Nothing is deleted.
|
||||
self.successResultOf(get_first)
|
||||
self.successResultOf(get_second)
|
||||
self.successResultOf(get_third)
|
||||
self.successResultOf(get_last)
|
@ -151,6 +151,7 @@ class HomeserverTestCase(TestCase):
|
||||
hijack_auth (bool): Whether to hijack auth to return the user specified
|
||||
in user_id.
|
||||
"""
|
||||
|
||||
servlets = []
|
||||
hijack_auth = True
|
||||
|
||||
@ -279,3 +280,13 @@ class HomeserverTestCase(TestCase):
|
||||
kwargs = dict(kwargs)
|
||||
kwargs.update(self._hs_args)
|
||||
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
|
||||
|
||||
def pump(self):
|
||||
"""
|
||||
Pump the reactor enough that Deferreds will fire.
|
||||
"""
|
||||
self.reactor.pump([0.0] * 100)
|
||||
|
||||
def get_success(self, d):
|
||||
self.pump()
|
||||
return self.successResultOf(d)
|
||||
|
@ -147,6 +147,8 @@ def setup_test_homeserver(
|
||||
config.max_mau_value = 50
|
||||
config.mau_limits_reserved_threepids = []
|
||||
config.admin_contact = None
|
||||
config.rc_messages_per_second = 10000
|
||||
config.rc_message_burst_count = 10000
|
||||
|
||||
# we need a sane default_room_version, otherwise attempts to create rooms will
|
||||
# fail.
|
||||
|
Loading…
Reference in New Issue
Block a user