Port storage/ to Python 3 (#3725)

Amber Brown 2018-08-31 00:19:58 +10:00 committed by GitHub
parent 475253a88e
commit 14e4d4f4bf
17 changed files with 208 additions and 36 deletions

View file

@@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
     "affinity": {
         "affinity": ["affinity"],
     },
+    "postgres": {
+        "psycopg2>=2.6": ["psycopg2"]
+    }
 }
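Note: each entry in CONDITIONAL_REQUIREMENTS maps a pip specifier to the module names that prove it is installed. A minimal sketch of how such a map could be checked at runtime (a hypothetical helper for illustration, not Synapse's actual loader):

    import importlib

    def check_extra(extra, conditional_requirements):
        # Report the pip specifiers under `extra` whose modules fail to import.
        missing = []
        for specifier, modules in conditional_requirements.get(extra, {}).items():
            for module in modules:
                try:
                    importlib.import_module(module)
                except ImportError:
                    missing.append(specifier)
        return missing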

View file

@@ -17,9 +17,10 @@ import sys
 import threading
 import time
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
 from six.moves import intern, range
 from canonicaljson import json
 from prometheus_client import Histogram
 from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
     something went wrong.
     """
     pass
+
+
+def db_to_json(db_content):
+    """
+    Take some data from a database row and return a JSON-decoded object.
+
+    Args:
+        db_content (memoryview|buffer|bytes|bytearray|unicode)
+    """
+    # psycopg2 on Python 3 returns memoryview objects, which we need to
+    # cast to bytes to decode
+    if isinstance(db_content, memoryview):
+        db_content = db_content.tobytes()
+
+    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+    # bytes to decode
+    if PY2 and isinstance(db_content, buffer):
+        db_content = bytes(db_content)
+
+    # Decode it to a Unicode string before feeding it to json.loads, so we
+    # consistently get a Unicode-containing object out.
+    if isinstance(db_content, (bytes, bytearray)):
+        db_content = db_content.decode('utf8')
+
+    try:
+        return json.loads(db_content)
+    except Exception:
+        logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+        raise
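Note: a quick illustration of what the new helper normalises (a hypothetical snippet, not part of the commit). psycopg2 on Python 3 hands back memoryview for bytea columns, while SQLite hands back text, and both now decode to the same object:

    row_pg = memoryview(b'{"device_id": "ABC"}')  # postgres-style bytea value
    row_lite = u'{"device_id": "ABC"}'            # sqlite-style text value

    assert db_to_json(row_pg) == db_to_json(row_lite) == {"device_id": "ABC"}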

View file

@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
-            devices = messages_by_device.keys()
+            devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
                 sql = (
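Note: the list() wrapper matters because dict.keys() returns a view object on Python 3, which does not support indexing (an illustrative snippet, not part of the commit):

    messages_by_device = {"*": {"body": "hi"}}
    devices = messages_by_device.keys()
    # devices[0] raises TypeError on Python 3: 'dict_keys' object is not subscriptable
    devices = list(messages_by_device.keys())
    assert devices[0] == "*"  # works on both Python 2 and 3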

View file

@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
 logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
             if device is not None:
                 key_json = device.get("key_json", None)
                 if key_json:
-                    result["keys"] = json.loads(key_json)
+                    result["keys"] = db_to_json(key_json)
                 device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
             retcol="content",
             desc="_get_cached_user_device",
         )
-        defer.returnValue(json.loads(content))
+        defer.returnValue(db_to_json(content))

     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
             desc="_get_cached_devices_for_user",
         )
         defer.returnValue({
-            device["device_id"]: json.loads(device["content"])
+            device["device_id"]: db_to_json(device["content"])
             for device in devices
         })
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
             key_json = device.get("key_json", None)
             if key_json:
-                result["keys"] = json.loads(key_json)
+                result["keys"] = db_to_json(key_json)
             device_display_name = device.get("device_display_name", None)
             if device_display_name:
                 result["device_display_name"] = device_display_name

View file

@@ -14,13 +14,13 @@
 # limitations under the License.
 from six import iteritems
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json


 class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
-                device_info["keys"] = json.loads(device_info.pop("key_json"))
+                device_info["keys"] = db_to_json(device_info.pop("key_json"))
         defer.returnValue(results)

View file

@@ -41,13 +41,18 @@ class PostgresEngine(object):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
+
+        # Set the bytea output to escape, vs the default of hex
+        cursor = db_conn.cursor()
+        cursor.execute("SET bytea_output TO escape")

         # Asynchronous commit, don't wait for the server to call fsync before
         # ending the transaction.
         # https://www.postgresql.org/docs/current/static/wal-async-commit.html
         if not self.synchronous_commit:
-            cursor = db_conn.cursor()
             cursor.execute("SET synchronous_commit TO OFF")
-            cursor.close()
+
+        cursor.close()

     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
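Note: in isolation, the connection setup above amounts to something like the following (a minimal sketch against a plain psycopg2 connection; the DSN is illustrative and this is not Synapse's actual wiring):

    import psycopg2

    conn = psycopg2.connect("dbname=synapse")  # illustrative DSN
    cursor = conn.cursor()
    # bytea_output controls how the server renders bytea values as text:
    # "escape" here, versus the "hex" default.
    cursor.execute("SET bytea_output TO escape")
    # Optional: trade durability for latency, as the commit does.
    cursor.execute("SET synchronous_commit TO OFF")
    cursor.close()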

View file

@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict, deque, namedtuple
 from functools import wraps
-from six import iteritems
+from six import iteritems, text_type
 from six.moves import range
 from canonicaljson import json
@@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                 "sender": event.sender,
                 "contains_url": (
                     "url" in event.content
-                    and isinstance(event.content["url"], basestring)
+                    and isinstance(event.content["url"], text_type)
                 ),
             }
             for event, _ in events_and_contexts
@@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                     contains_url = "url" in content
                     if contains_url:
-                        contains_url &= isinstance(content["url"], basestring)
+                        contains_url &= isinstance(content["url"], text_type)
                 except (KeyError, AttributeError):
                     # If the event is missing a necessary field then
                     # skip over it.
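Note: six.text_type is unicode on Python 2 and str on Python 3, so it replaces the Python-2-only basestring in these checks (an illustrative snippet, not part of the commit). One behavioural nuance: on Python 2, basestring also matched byte strings, which text_type deliberately does not.

    from six import text_type

    content = {"url": u"mxc://example.org/abc"}
    contains_url = "url" in content and isinstance(content["url"], text_type)
    assert contains_url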
@@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             (room_id,)
         )
         rows = txn.fetchall()
-        max_depth = max(row[0] for row in rows)
+        max_depth = max(row[1] for row in rows)

-        if max_depth <= token.topological:
+        if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
             # having any backwards extremities)

View file

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import logging
 from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                event_id_lists = zip(*event_list)[0]
+                event_id_lists = list(zip(*event_list))[0]
                 event_ids = [
                     item for sublist in event_id_lists for item in sublist
                 ]
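Note: similarly, zip() returns a lazy iterator on Python 3, so it has to be materialised before it can be indexed (an illustrative snippet, not part of the commit):

    event_list = [(["$a:hs", "$b:hs"], None), (["$c:hs"], None)]
    # zip(*event_list)[0] raises TypeError on Python 3: 'zip' object is not subscriptable
    event_id_lists = list(zip(*event_list))[0]
    assert event_id_lists == (["$a:hs", "$b:hs"], ["$c:hs"])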
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")

                 # We only want to resolve deferreds from the main thread
-                def fire(evs):
+                def fire(evs, exc):
                     for _, d in evs:
                         if not d.called:
                             with PreserveLoggingContext():
-                                d.errback(e)
+                                d.errback(exc)

                 with PreserveLoggingContext():
-                    self.hs.get_reactor().callFromThread(fire, event_list)
+                    self.hs.get_reactor().callFromThread(fire, event_list, e)

     @defer.inlineCallbacks
     def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
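Note: the extra exc parameter is needed because Python 3 unbinds the except-clause variable when the handler block exits, so a callback that closed over e would hit a NameError by the time the reactor ran it (an illustrative reduction, not part of the commit):

    def run():
        try:
            raise ValueError("boom")
        except Exception as e:
            captured = lambda: e        # closes over the variable itself
            passed = lambda exc=e: exc  # binds the exception value now
        # Python 3 implicitly deletes 'e' here, so captured() would raise
        # NameError; passed() still returns the original ValueError.
        return passed()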

View file

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json


 class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
             desc="get_user_filter",
         )
-        defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+        defer.returnValue(db_to_json(def_json))

     def add_user_filter(self, user_localpart, user_filter):
         def_json = encode_canonical_json(user_filter)
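Note: the stored filter round-trips cleanly through the new helper, with db_to_json as imported above: encode_canonical_json produces UTF-8 bytes, and db_to_json accepts bytes, memoryview, or text (an illustrative snippet, not part of the commit):

    from canonicaljson import encode_canonical_json

    user_filter = {"room": {"timeline": {"limit": 10}}}
    def_json = encode_canonical_json(user_filter)           # UTF-8 bytes
    assert db_to_json(def_json) == user_filter              # sqlite-style
    assert db_to_json(memoryview(def_json)) == user_filter  # postgres-style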

View file

@@ -15,7 +15,8 @@
 # limitations under the License.
 import logging
-import types
+
+import six

 from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
 logger = logging.getLogger(__name__)

+if six.PY2:
+    db_binary_type = buffer
+else:
+    db_binary_type = memoryview
+

 class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
             dataJson = r['data']
             r['data'] = None
             try:
-                if isinstance(dataJson, types.BufferType):
+                if isinstance(dataJson, db_binary_type):
                     dataJson = str(dataJson).decode("UTF8")

                 r['data'] = json.loads(dataJson)
             except Exception as e:
                 logger.warn(
                     "Invalid JSON in data for pusher %d: %s, %s",
-                    r['id'], dataJson, e.message,
+                    r['id'], dataJson, e.args[0],
                 )
                 pass

-            if isinstance(r['pushkey'], types.BufferType):
+            if isinstance(r['pushkey'], db_binary_type):
                 r['pushkey'] = str(r['pushkey']).decode("UTF8")

         return rows
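Note: the db_binary_type alias smooths over the two drivers' binary return types (buffer on Python 2, memoryview on Python 3); bytes() is the driver-agnostic way to get raw bytes out of either before decoding (a hypothetical helper for illustration, not part of the commit):

    import six

    if six.PY2:
        db_binary_type = buffer  # noqa: F821 -- buffer only exists on Python 2
    else:
        db_binary_type = memoryview

    def to_text(value):
        # bytes(buffer) on Python 2 and bytes(memoryview) on Python 3 both
        # yield the raw byte string, which then decodes to text.
        if isinstance(value, db_binary_type):
            value = bytes(value)
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value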

View file

@@ -18,14 +18,14 @@ from collections import namedtuple

 import six

-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json

 from twisted.internet import defer

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached

-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json

 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
 # despite being deprecated and removed in favor of memoryview
@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
         )
         if result and result["response_code"]:
-            return result["response_code"], json.loads(str(result["response_json"]))
+            return result["response_code"], db_to_json(result["response_json"])
         else:
             return None