Make scripts/ and scripts-dev/ pass pyflakes (and the rest of the codebase on py3) (#4068)

This commit is contained in:
Amber Brown 2018-10-20 11:16:55 +11:00 committed by GitHub
parent 81d4f51524
commit e1728dfcbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 511 additions and 518 deletions

View File

@ -14,7 +14,7 @@ matrix:
- python: 2.7 - python: 2.7
env: TOX_ENV=packaging env: TOX_ENV=packaging
- python: 2.7 - python: 3.6
env: TOX_ENV=pep8 env: TOX_ENV=pep8
- python: 2.7 - python: 2.7

1
changelog.d/4068.misc Normal file
View File

@ -0,0 +1 @@
Make the Python scripts in the top-level scripts folders meet pep8 and pass flake8.

View File

@ -1,21 +1,20 @@
from synapse.events import FrozenEvent from __future__ import print_function
from synapse.api.auth import Auth
from mock import Mock
import argparse import argparse
import itertools import itertools
import json import json
import sys import sys
from mock import Mock
from synapse.api.auth import Auth
from synapse.events import FrozenEvent
def check_auth(auth, auth_chain, events): def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
auth_map = { auth_map = {e.event_id: e for e in auth_chain}
e.event_id: e
for e in auth_chain
}
create_events = {} create_events = {}
for e in auth_chain: for e in auth_chain:
@ -25,31 +24,26 @@ def check_auth(auth, auth_chain, events):
for e in itertools.chain(auth_chain, events): for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events] auth_events_list = [auth_map[i] for i, _ in e.auth_events]
auth_events = { auth_events = {(e.type, e.state_key): e for e in auth_events_list}
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events[("m.room.create", "")] = create_events[e.room_id] auth_events[("m.room.create", "")] = create_events[e.room_id]
try: try:
auth.check(e, auth_events=auth_events) auth.check(e, auth_events=auth_events)
except Exception as ex: except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key print("Failed:", e.event_id, e.type, e.state_key)
print "Auth_events:", auth_events print("Auth_events:", auth_events)
print ex print(ex)
print json.dumps(e.get_dict(), sort_keys=True, indent=4) print(json.dumps(e.get_dict(), sort_keys=True, indent=4))
# raise # raise
print "Success:", e.event_id, e.type, e.state_key print("Success:", e.event_id, e.type, e.state_key)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
'json', 'json', nargs='?', type=argparse.FileType('r'), default=sys.stdin
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,10 +1,15 @@
from synapse.crypto.event_signing import *
from unpaddedbase64 import encode_base64
import argparse import argparse
import hashlib import hashlib
import sys
import json import json
import logging
import sys
from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import (
check_event_content_hash,
compute_event_reference_hash,
)
class dictobj(dict): class dictobj(dict):
@ -24,27 +29,26 @@ class dictobj(dict):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), parser.add_argument(
default=sys.stdin) "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig() logging.basicConfig()
event_json = dictobj(json.load(args.input_json)) event_json = dictobj(json.load(args.input_json))
algorithms = { algorithms = {"sha256": hashlib.sha256}
"sha256": hashlib.sha256,
}
for alg_name in event_json.hashes: for alg_name in event_json.hashes:
if check_event_content_hash(event_json, algorithms[alg_name]): if check_event_content_hash(event_json, algorithms[alg_name]):
print "PASS content hash %s" % (alg_name,) print("PASS content hash %s" % (alg_name,))
else: else:
print "FAIL content hash %s" % (alg_name,) print("FAIL content hash %s" % (alg_name,))
for algorithm in algorithms.values(): for algorithm in algorithms.values():
name, h_bytes = compute_event_reference_hash(event_json, algorithm) name, h_bytes = compute_event_reference_hash(event_json, algorithm)
print "Reference hash %s: %s" % (name, encode_base64(h_bytes)) print("Reference hash %s: %s" % (name, encode_base64(h_bytes)))
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,15 +1,15 @@
from signedjson.sign import verify_signed_json import argparse
import json
import logging
import sys
import urllib2
import dns.resolver
from signedjson.key import decode_verify_key_bytes, write_signing_keys from signedjson.key import decode_verify_key_bytes, write_signing_keys
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import urllib2
import json
import sys
import dns.resolver
import pprint
import argparse
import logging
def get_targets(server_name): def get_targets(server_name):
if ":" in server_name: if ":" in server_name:
@ -23,6 +23,7 @@ def get_targets(server_name):
except dns.resolver.NXDOMAIN: except dns.resolver.NXDOMAIN:
yield (server_name, 8448) yield (server_name, 8448)
def get_server_keys(server_name, target, port): def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port) url = "https://%s:%i/_matrix/key/v1" % (target, port)
keys = json.load(urllib2.urlopen(url)) keys = json.load(urllib2.urlopen(url))
@ -33,12 +34,14 @@ def get_server_keys(server_name, target, port):
verify_keys[key_id] = verify_key verify_keys[key_id] = verify_key
return verify_keys return verify_keys
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("signature_name") parser.add_argument("signature_name")
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'), parser.add_argument(
default=sys.stdin) "input_json", nargs="?", type=argparse.FileType('r'), default=sys.stdin
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig() logging.basicConfig()
@ -48,24 +51,23 @@ def main():
for target, port in get_targets(server_name): for target, port in get_targets(server_name):
try: try:
keys = get_server_keys(server_name, target, port) keys = get_server_keys(server_name, target, port)
print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port) print("Using keys from https://%s:%s/_matrix/key/v1" % (target, port))
write_signing_keys(sys.stdout, keys.values()) write_signing_keys(sys.stdout, keys.values())
break break
except: except Exception:
logging.exception("Error talking to %s:%s", target, port) logging.exception("Error talking to %s:%s", target, port)
json_to_check = json.load(args.input_json) json_to_check = json.load(args.input_json)
print "Checking JSON:" print("Checking JSON:")
for key_id in json_to_check["signatures"][args.signature_name]: for key_id in json_to_check["signatures"][args.signature_name]:
try: try:
key = keys[key_id] key = keys[key_id]
verify_signed_json(json_to_check, args.signature_name, key) verify_signed_json(json_to_check, args.signature_name, key)
print "PASS %s" % (key_id,) print("PASS %s" % (key_id,))
except: except Exception:
logging.exception("Check for key %s failed" % (key_id,)) logging.exception("Check for key %s failed" % (key_id,))
print "FAIL %s" % (key_id,) print("FAIL %s" % (key_id,))
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,13 +1,21 @@
import hashlib
import json
import sys
import time
import six
import psycopg2 import psycopg2
import yaml import yaml
import sys from canonicaljson import encode_canonical_json
import json
import time
import hashlib
from unpaddedbase64 import encode_base64
from signedjson.key import read_signing_keys from signedjson.key import read_signing_keys
from signedjson.sign import sign_json from signedjson.sign import sign_json
from canonicaljson import encode_canonical_json from unpaddedbase64 import encode_base64
if six.PY2:
db_type = six.moves.builtins.buffer
else:
db_type = memoryview
def select_v1_keys(connection): def select_v1_keys(connection):
@ -39,7 +47,9 @@ def select_v2_json(connection):
cursor.close() cursor.close()
results = {} results = {}
for server_name, key_id, key_json in rows: for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8")) results.setdefault(server_name, {})[key_id] = json.loads(
str(key_json).decode("utf-8")
)
return results return results
@ -47,10 +57,7 @@ def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return { return {
"old_verify_keys": {}, "old_verify_keys": {},
"server_name": server_name, "server_name": server_name,
"verify_keys": { "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until, "valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)], "tls_fingerprints": [fingerprint(certificate)],
} }
@ -65,7 +72,7 @@ def rows_v2(server, json):
valid_until = json["valid_until_ts"] valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json) key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]: for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json)) yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
def main(): def main():
@ -87,7 +94,7 @@ def main():
result = {} result = {}
for server in keys: for server in keys:
if not server in json: if server not in json:
v2_json = convert_v1_to_v2( v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server] server, valid_until, keys[server], certificates[server]
) )
@ -96,10 +103,7 @@ def main():
yaml.safe_dump(result, sys.stdout, default_flow_style=False) yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list( rows = list(row for server, json in result.items() for row in rows_v2(server, json))
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor() cursor = connection.cursor()
cursor.executemany( cursor.executemany(
@ -107,7 +111,7 @@ def main():
" server_name, key_id, from_server," " server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json" " ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)", ") VALUES (%s, %s, %s, %s, %s, %s)",
rows rows,
) )
connection.commit() connection.commit()

View File

@ -1,8 +1,16 @@
#! /usr/bin/python #! /usr/bin/python
from __future__ import print_function
import argparse
import ast import ast
import os
import re
import sys
import yaml import yaml
class DefinitionVisitor(ast.NodeVisitor): class DefinitionVisitor(ast.NodeVisitor):
def __init__(self): def __init__(self):
super(DefinitionVisitor, self).__init__() super(DefinitionVisitor, self).__init__()
@ -42,15 +50,18 @@ def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()} functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()} classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {} result = {}
if functions: result['def'] = functions if functions:
if classes: result['class'] = classes result['def'] = functions
if classes:
result['class'] = classes
names = defs['names'] names = defs['names']
uses = [] uses = []
for name in names.get('Load', ()): for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()): if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name) uses.append(name)
uses.extend(defs['attrs']) uses.extend(defs['attrs'])
if uses: result['uses'] = uses if uses:
result['uses'] = uses
result['names'] = names result['names'] = names
result['attrs'] = defs['attrs'] result['attrs'] = defs['attrs']
return result return result
@ -95,7 +106,6 @@ def used_names(prefix, item, defs, names):
if __name__ == '__main__': if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.') parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument( parser.add_argument(
@ -105,24 +115,28 @@ if __name__ == '__main__':
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern" "--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
) )
parser.add_argument( parser.add_argument(
"--pattern", action="append", metavar="REGEXP", "--pattern", action="append", metavar="REGEXP", help="Search for a pattern"
help="Search for a pattern"
) )
parser.add_argument( parser.add_argument(
"directories", nargs='+', metavar="DIR", "directories",
help="Directories to search for definitions" nargs='+',
metavar="DIR",
help="Directories to search for definitions",
) )
parser.add_argument( parser.add_argument(
"--referrers", default=0, type=int, "--referrers",
help="Include referrers up to the given depth" default=0,
type=int,
help="Include referrers up to the given depth",
) )
parser.add_argument( parser.add_argument(
"--referred", default=0, type=int, "--referred",
help="Include referred down to the given depth" default=0,
type=int,
help="Include referred down to the given depth",
) )
parser.add_argument( parser.add_argument(
"--format", default="yaml", "--format", default="yaml", help="Output format, one of 'yaml' or 'dot'"
help="Output format, one of 'yaml' or 'dot'"
) )
args = parser.parse_args() args = parser.parse_args()
@ -162,7 +176,7 @@ if __name__ == '__main__':
for used_by in entry.get("used", ()): for used_by in entry.get("used", ()):
referrers.add(used_by) referrers.add(used_by)
for name, definition in names.items(): for name, definition in names.items():
if not name in referrers: if name not in referrers:
continue continue
if ignore and any(pattern.match(name) for pattern in ignore): if ignore and any(pattern.match(name) for pattern in ignore):
continue continue
@ -176,7 +190,7 @@ if __name__ == '__main__':
for uses in entry.get("uses", ()): for uses in entry.get("uses", ()):
referred.add(uses) referred.add(uses)
for name, definition in names.items(): for name, definition in names.items():
if not name in referred: if name not in referred:
continue continue
if ignore and any(pattern.match(name) for pattern in ignore): if ignore and any(pattern.match(name) for pattern in ignore):
continue continue
@ -185,12 +199,12 @@ if __name__ == '__main__':
if args.format == 'yaml': if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False) yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot': elif args.format == 'dot':
print "digraph {" print("digraph {")
for name, entry in result.items(): for name, entry in result.items():
print name print(name)
for used_by in entry.get("used", ()): for used_by in entry.get("used", ()):
if used_by in result: if used_by in result:
print used_by, "->", name print(used_by, "->", name)
print "}" print("}")
else: else:
raise ValueError("Unknown format %r" % (args.format)) raise ValueError("Unknown format %r" % (args.format))

View File

@ -1,8 +1,11 @@
#!/usr/bin/env python2 #!/usr/bin/env python2
import pymacaroons from __future__ import print_function
import sys import sys
import pymacaroons
if len(sys.argv) == 1: if len(sys.argv) == 1:
sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],)) sys.stderr.write("usage: %s macaroon [key]\n" % (sys.argv[0],))
sys.exit(1) sys.exit(1)
@ -11,14 +14,14 @@ macaroon_string = sys.argv[1]
key = sys.argv[2] if len(sys.argv) > 2 else None key = sys.argv[2] if len(sys.argv) > 2 else None
macaroon = pymacaroons.Macaroon.deserialize(macaroon_string) macaroon = pymacaroons.Macaroon.deserialize(macaroon_string)
print macaroon.inspect() print(macaroon.inspect())
print "" print("")
verifier = pymacaroons.Verifier() verifier = pymacaroons.Verifier()
verifier.satisfy_general(lambda c: True) verifier.satisfy_general(lambda c: True)
try: try:
verifier.verify(macaroon, key) verifier.verify(macaroon, key)
print "Signature is correct" print("Signature is correct")
except Exception as e: except Exception as e:
print str(e) print(str(e))

View File

@ -18,22 +18,22 @@
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import base64
import json
import sys
from urlparse import urlparse, urlunparse from urlparse import urlparse, urlunparse
import nacl.signing import nacl.signing
import json
import base64
import requests import requests
import sys
from requests.adapters import HTTPAdapter
import srvlookup import srvlookup
import yaml import yaml
from requests.adapters import HTTPAdapter
# uncomment the following to enable debug logging of http requests # uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection # from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1 # HTTPConnection.debuglevel = 1
def encode_base64(input_bytes): def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding.""" """Encode bytes as a base64 string without any padding."""
@ -88,6 +88,7 @@ def sign_json(json_object, signing_key, signing_name):
NACL_ED25519 = "ed25519" NACL_ED25519 = "ed25519"
def decode_signing_key_base64(algorithm, version, key_base64): def decode_signing_key_base64(algorithm, version, key_base64):
"""Decode a base64 encoded signing key """Decode a base64 encoded signing key
Args: Args:
@ -143,9 +144,7 @@ def request_json(method, origin_name, origin_key, destination, path, content):
authorization_headers = [] authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items(): for key, sig in signed_json["signatures"][origin_name].items():
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig)
origin_name, key, sig,
)
authorization_headers.append(bytes(header)) authorization_headers.append(bytes(header))
print("Authorization: %s" % header, file=sys.stderr) print("Authorization: %s" % header, file=sys.stderr)
@ -158,10 +157,7 @@ def request_json(method, origin_name, origin_key, destination, path, content):
result = s.request( result = s.request(
method=method, method=method,
url=dest, url=dest,
headers={ headers={"Host": destination, "Authorization": authorization_headers[0]},
"Host": destination,
"Authorization": authorization_headers[0]
},
verify=False, verify=False,
data=content, data=content,
) )
@ -171,50 +167,50 @@ def request_json(method, origin_name, origin_key, destination, path, content):
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description= description="Signs and sends a federation request to a matrix homeserver"
"Signs and sends a federation request to a matrix homeserver",
) )
parser.add_argument( parser.add_argument(
"-N", "--server-name", "-N",
"--server-name",
help="Name to give as the local homeserver. If unspecified, will be " help="Name to give as the local homeserver. If unspecified, will be "
"read from the config file.", "read from the config file.",
) )
parser.add_argument( parser.add_argument(
"-k", "--signing-key-path", "-k",
"--signing-key-path",
help="Path to the file containing the private ed25519 key to sign the " help="Path to the file containing the private ed25519 key to sign the "
"request with.", "request with.",
) )
parser.add_argument( parser.add_argument(
"-c", "--config", "-c",
"--config",
default="homeserver.yaml", default="homeserver.yaml",
help="Path to server config file. Ignored if --server-name and " help="Path to server config file. Ignored if --server-name and "
"--signing-key-path are both given.", "--signing-key-path are both given.",
) )
parser.add_argument( parser.add_argument(
"-d", "--destination", "-d",
"--destination",
default="matrix.org", default="matrix.org",
help="name of the remote homeserver. We will do SRV lookups and " help="name of the remote homeserver. We will do SRV lookups and "
"connect appropriately.", "connect appropriately.",
) )
parser.add_argument( parser.add_argument(
"-X", "--method", "-X",
"--method",
help="HTTP method to use for the request. Defaults to GET if --data is" help="HTTP method to use for the request. Defaults to GET if --data is"
"unspecified, POST if it is." "unspecified, POST if it is.",
) )
parser.add_argument( parser.add_argument("--body", help="Data to send as the body of the HTTP request")
"--body",
help="Data to send as the body of the HTTP request"
)
parser.add_argument( parser.add_argument(
"path", "path", help="request path. We will add '/_matrix/federation/v1/' to this."
help="request path. We will add '/_matrix/federation/v1/' to this."
) )
args = parser.parse_args() args = parser.parse_args()
@ -227,7 +223,9 @@ def main():
result = request_json( result = request_json(
args.method, args.method,
args.server_name, key, args.destination, args.server_name,
key,
args.destination,
"/_matrix/federation/v1/" + args.path, "/_matrix/federation/v1/" + args.path,
content=args.body, content=args.body,
) )
@ -263,7 +261,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
try: try:
srv = srvlookup.lookup("matrix", "tcp", s)[0] srv = srvlookup.lookup("matrix", "tcp", s)[0]
return srv.host, srv.port return srv.host, srv.port
except: except Exception:
return s, 8448 return s, 8448
def get_connection(self, url, proxies=None): def get_connection(self, url, proxies=None):
@ -272,10 +270,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
(host, port) = self.lookup(parsed.netloc) (host, port) = self.lookup(parsed.netloc)
netloc = "%s:%d" % (host, port) netloc = "%s:%d" % (host, port)
print("Connecting to %s" % (netloc,), file=sys.stderr) print("Connecting to %s" % (netloc,), file=sys.stderr)
url = urlunparse(( url = urlunparse(
"https", netloc, parsed.path, parsed.params, parsed.query, ("https", netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)
parsed.fragment, )
))
return super(MatrixConnectionAdapter, self).get_connection(url, proxies) return super(MatrixConnectionAdapter, self).get_connection(url, proxies)

View File

@ -1,23 +1,31 @@
from synapse.storage.pdu import PduStore from __future__ import print_function
from synapse.storage.signatures import SignatureStore
from synapse.storage._base import SQLBaseStore
from synapse.federation.units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from unpaddedbase64 import encode_base64, decode_base64
from canonicaljson import encode_canonical_json
import sqlite3 import sqlite3
import sys import sys
from unpaddedbase64 import decode_base64, encode_base64
from synapse.crypto.event_signing import (
add_event_pdu_content_hash,
compute_pdu_event_reference_hash,
)
from synapse.federation.units import Pdu
from synapse.storage._base import SQLBaseStore
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore
class Store(object): class Store(object):
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"] _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"] _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"] _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"] _get_pdu_origin_signatures_txn = SignatureStore.__dict__[
"_get_pdu_origin_signatures_txn"
]
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"] _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"] _store_pdu_reference_hash_txn = SignatureStore.__dict__[
"_store_pdu_reference_hash_txn"
]
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"] _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"] _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
@ -26,9 +34,7 @@ store = Store()
def select_pdus(cursor): def select_pdus(cursor):
cursor.execute( cursor.execute("SELECT pdu_id, origin FROM pdus ORDER BY depth ASC")
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
)
ids = cursor.fetchall() ids = cursor.fetchall()
@ -41,23 +47,30 @@ def select_pdus(cursor):
for pdu in pdus: for pdu in pdus:
try: try:
if pdu.prev_pdus: if pdu.prev_pdus:
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus print("PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
for pdu_id, origin, hashes in pdu.prev_pdus: for pdu_id, origin, hashes in pdu.prev_pdus:
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)] ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
hashes[ref_alg] = encode_base64(ref_hsh) hashes[ref_alg] = encode_base64(ref_hsh)
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh) store._store_prev_pdu_hash_txn(
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh
)
print("SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
pdu = add_event_pdu_content_hash(pdu) pdu = add_event_pdu_content_hash(pdu)
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu) ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh) reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh) store._store_pdu_reference_hash_txn(
cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh
)
for alg, hsh_base64 in pdu.hashes.items(): for alg, hsh_base64 in pdu.hashes.items():
print alg, hsh_base64 print(alg, hsh_base64)
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)) store._store_pdu_content_hash_txn(
cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64)
)
except Exception:
print("FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus)
except:
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
def main(): def main():
conn = sqlite3.connect(sys.argv[1]) conn = sqlite3.connect(sys.argv[1])
@ -65,5 +78,6 @@ def main():
select_pdus(cursor) select_pdus(cursor)
conn.commit() conn.commit()
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,18 +1,17 @@
#! /usr/bin/python #! /usr/bin/python
import ast
import argparse import argparse
import ast
import os import os
import sys import sys
import yaml import yaml
PATTERNS_V1 = [] PATTERNS_V1 = []
PATTERNS_V2 = [] PATTERNS_V2 = []
RESULT = { RESULT = {"v1": PATTERNS_V1, "v2": PATTERNS_V2}
"v1": PATTERNS_V1,
"v2": PATTERNS_V2,
}
class CallVisitor(ast.NodeVisitor): class CallVisitor(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
@ -21,7 +20,6 @@ class CallVisitor(ast.NodeVisitor):
else: else:
return return
if name == "client_path_patterns": if name == "client_path_patterns":
PATTERNS_V1.append(node.args[0].s) PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns": elif name == "client_v2_patterns":
@ -42,8 +40,10 @@ def find_patterns_in_file(filepath):
parser = argparse.ArgumentParser(description='Find url patterns.') parser = argparse.ArgumentParser(description='Find url patterns.')
parser.add_argument( parser.add_argument(
"directories", nargs='+', metavar="DIR", "directories",
help="Directories to search for definitions" nargs='+',
metavar="DIR",
help="Directories to search for definitions",
) )
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,8 +1,9 @@
import requests
import collections import collections
import json
import sys import sys
import time import time
import json
import requests
Entry = collections.namedtuple("Entry", "name position rows") Entry = collections.namedtuple("Entry", "name position rows")
@ -30,11 +31,11 @@ def parse_response(content):
def replicate(server, streams): def replicate(server, streams):
return parse_response(requests.get( return parse_response(
server + "/_synapse/replication", requests.get(
verify=False, server + "/_synapse/replication", verify=False, params=streams
params=streams ).content
).content) )
def main(): def main():
@ -53,7 +54,7 @@ def main():
while True: while True:
try: try:
results = replicate(server, streams) results = replicate(server, streams)
except: except Exception:
sys.stdout.write("connection_lost(" + repr(streams) + ")\n") sys.stdout.write("connection_lost(" + repr(streams) + ")\n")
break break
for update in results.values(): for update in results.values():
@ -62,6 +63,5 @@ def main():
streams[update.name] = update.position streams[update.name] = update.position
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -1,12 +1,10 @@
#!/usr/bin/env python #!/usr/bin/env python
import argparse import argparse
import getpass
import sys import sys
import bcrypt import bcrypt
import getpass
import yaml import yaml
bcrypt_rounds=12 bcrypt_rounds=12
@ -52,4 +50,3 @@ if __name__ == "__main__":
password = prompt_for_pass() password = prompt_for_pass()
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))

View File

@ -36,12 +36,9 @@ from __future__ import print_function
import argparse import argparse
import logging import logging
import sys
import os import os
import shutil import shutil
import sys
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
@ -77,24 +74,23 @@ def move_media(origin_server, file_id, src_paths, dest_paths):
if not os.path.exists(original_file): if not os.path.exists(original_file):
logger.warn( logger.warn(
"Original for %s/%s (%s) does not exist", "Original for %s/%s (%s) does not exist",
origin_server, file_id, original_file, origin_server,
file_id,
original_file,
) )
else: else:
mkdir_and_move( mkdir_and_move(
original_file, original_file, dest_paths.remote_media_filepath(origin_server, file_id)
dest_paths.remote_media_filepath(origin_server, file_id),
) )
# now look for thumbnails # now look for thumbnails
original_thumb_dir = src_paths.remote_media_thumbnail_dir( original_thumb_dir = src_paths.remote_media_thumbnail_dir(origin_server, file_id)
origin_server, file_id,
)
if not os.path.exists(original_thumb_dir): if not os.path.exists(original_thumb_dir):
return return
mkdir_and_move( mkdir_and_move(
original_thumb_dir, original_thumb_dir,
dest_paths.remote_media_thumbnail_dir(origin_server, file_id) dest_paths.remote_media_thumbnail_dir(origin_server, file_id),
) )
@ -109,24 +105,16 @@ def mkdir_and_move(original_file, dest_file):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description=__doc__, description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
formatter_class = argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"-v", action='store_true', help='enable debug logging')
parser.add_argument(
"src_repo",
help="Path to source content repo",
)
parser.add_argument(
"dest_repo",
help="Path to source content repo",
) )
parser.add_argument("-v", action='store_true', help='enable debug logging')
parser.add_argument("src_repo", help="Path to source content repo")
parser.add_argument("dest_repo", help="Path to source content repo")
args = parser.parse_args() args = parser.parse_args()
logging_config = { logging_config = {
"level": logging.DEBUG if args.v else logging.INFO, "level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} }
logging.basicConfig(**logging_config) logging.basicConfig(**logging_config)

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import argparse import argparse
import getpass import getpass
@ -22,19 +23,23 @@ import hmac
import json import json
import sys import sys
import urllib2 import urllib2
from six import input
import yaml import yaml
def request_registration(user, password, server_location, shared_secret, admin=False): def request_registration(user, password, server_location, shared_secret, admin=False):
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,), "%s/_matrix/client/r0/admin/register" % (server_location,),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'},
) )
try: try:
if sys.version_info[:3] >= (2, 7, 9): if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs # As of version 2.7.9, urllib2 now checks SSL certs
import ssl import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else: else:
f = urllib2.urlopen(req) f = urllib2.urlopen(req)
@ -42,18 +47,15 @@ def request_registration(user, password, server_location, shared_secret, admin=F
f.close() f.close()
nonce = json.loads(body)["nonce"] nonce = json.loads(body)["nonce"]
except urllib2.HTTPError as e: except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,) print("ERROR! Received %d %s" % (e.code, e.reason))
if 400 <= e.code < 500: if 400 <= e.code < 500:
if e.info().type == "application/json": if e.info().type == "application/json":
resp = json.load(e) resp = json.load(e)
if "error" in resp: if "error" in resp:
print resp["error"] print(resp["error"])
sys.exit(1) sys.exit(1)
mac = hmac.new( mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
key=shared_secret,
digestmod=hashlib.sha1,
)
mac.update(nonce) mac.update(nonce)
mac.update("\x00") mac.update("\x00")
@ -75,30 +77,31 @@ def request_registration(user, password, server_location, shared_secret, admin=F
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
print "Sending registration request..." print("Sending registration request...")
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/r0/admin/register" % (server_location,), "%s/_matrix/client/r0/admin/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'},
) )
try: try:
if sys.version_info[:3] >= (2, 7, 9): if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs # As of version 2.7.9, urllib2 now checks SSL certs
import ssl import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23)) f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else: else:
f = urllib2.urlopen(req) f = urllib2.urlopen(req)
f.read() f.read()
f.close() f.close()
print "Success." print("Success.")
except urllib2.HTTPError as e: except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,) print("ERROR! Received %d %s" % (e.code, e.reason))
if 400 <= e.code < 500: if 400 <= e.code < 500:
if e.info().type == "application/json": if e.info().type == "application/json":
resp = json.load(e) resp = json.load(e)
if "error" in resp: if "error" in resp:
print resp["error"] print(resp["error"])
sys.exit(1) sys.exit(1)
@ -106,35 +109,35 @@ def register_new_user(user, password, server_location, shared_secret, admin):
if not user: if not user:
try: try:
default_user = getpass.getuser() default_user = getpass.getuser()
except: except Exception:
default_user = None default_user = None
if default_user: if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,)) user = input("New user localpart [%s]: " % (default_user,))
if not user: if not user:
user = default_user user = default_user
else: else:
user = raw_input("New user localpart: ") user = input("New user localpart: ")
if not user: if not user:
print "Invalid user name" print("Invalid user name")
sys.exit(1) sys.exit(1)
if not password: if not password:
password = getpass.getpass("Password: ") password = getpass.getpass("Password: ")
if not password: if not password:
print "Password cannot be blank." print("Password cannot be blank.")
sys.exit(1) sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ") confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password: if password != confirm_password:
print "Passwords do not match" print("Passwords do not match")
sys.exit(1) sys.exit(1)
if admin is None: if admin is None:
admin = raw_input("Make admin [no]: ") admin = input("Make admin [no]: ")
if admin in ("y", "yes", "true"): if admin in ("y", "yes", "true"):
admin = True admin = True
else: else:
@ -148,40 +151,49 @@ if __name__ == "__main__":
description="Used to register new users with a given home server when" description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be" " registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option" " configured with the 'registration_shared_secret' option"
" set.", " set."
) )
parser.add_argument( parser.add_argument(
"-u", "--user", "-u",
"--user",
default=None, default=None,
help="Local part of the new user. Will prompt if omitted.", help="Local part of the new user. Will prompt if omitted.",
) )
parser.add_argument( parser.add_argument(
"-p", "--password", "-p",
"--password",
default=None, default=None,
help="New password for user. Will prompt if omitted.", help="New password for user. Will prompt if omitted.",
) )
admin_group = parser.add_mutually_exclusive_group() admin_group = parser.add_mutually_exclusive_group()
admin_group.add_argument( admin_group.add_argument(
"-a", "--admin", "-a",
"--admin",
action="store_true", action="store_true",
help="Register new user as an admin. Will prompt if --no-admin is not set either.", help=(
"Register new user as an admin. "
"Will prompt if --no-admin is not set either."
),
) )
admin_group.add_argument( admin_group.add_argument(
"--no-admin", "--no-admin",
action="store_true", action="store_true",
help="Register new user as a regular user. Will prompt if --admin is not set either.", help=(
"Register new user as a regular user. "
"Will prompt if --admin is not set either."
),
) )
group = parser.add_mutually_exclusive_group(required=True) group = parser.add_mutually_exclusive_group(required=True)
group.add_argument( group.add_argument(
"-c", "--config", "-c",
"--config",
type=argparse.FileType('r'), type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.", help="Path to server config file. Used to read in shared secret.",
) )
group.add_argument( group.add_argument(
"-k", "--shared-secret", "-k", "--shared-secret", help="Shared secret as defined in server config file."
help="Shared secret as defined in server config file.",
) )
parser.add_argument( parser.add_argument(
@ -198,7 +210,7 @@ if __name__ == "__main__":
config = yaml.safe_load(args.config) config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None) secret = config.get("registration_shared_secret", None)
if not secret: if not secret:
print "No 'registration_shared_secret' defined in config." print("No 'registration_shared_secret' defined in config.")
sys.exit(1) sys.exit(1)
else: else:
secret = args.shared_secret secret = args.shared_secret

View File

@ -15,23 +15,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse import argparse
import curses import curses
import logging import logging
import sys import sys
import time import time
import traceback import traceback
import yaml
from six import string_types from six import string_types
import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("synapse_port_db") logger = logging.getLogger("synapse_port_db")
@ -105,6 +105,7 @@ class Store(object):
*All* database interactions should go through this object. *All* database interactions should go through this object.
""" """
def __init__(self, db_pool, engine): def __init__(self, db_pool, engine):
self.db_pool = db_pool self.db_pool = db_pool
self.database_engine = engine self.database_engine = engine
@ -135,7 +136,8 @@ class Store(object):
txn = conn.cursor() txn = conn.cursor()
return func( return func(
LoggingTransaction(txn, desc, self.database_engine, [], []), LoggingTransaction(txn, desc, self.database_engine, [], []),
*args, **kwargs *args,
**kwargs
) )
except self.database_engine.module.DatabaseError as e: except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e): if self.database_engine.is_deadlock(e):
@ -158,22 +160,20 @@ class Store(object):
def r(txn): def r(txn):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
return self.runInteraction("execute_sql", r) return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows): def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % ( sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table, table,
", ".join(k for k in headers), ", ".join(k for k in headers),
", ".join("%s" for _ in headers) ", ".join("%s" for _ in headers),
) )
try: try:
txn.executemany(sql, rows) txn.executemany(sql, rows)
except: except Exception:
logger.exception( logger.exception("Failed to insert: %s", table)
"Failed to insert: %s",
table,
)
raise raise
@ -206,7 +206,7 @@ class Porter(object):
"table_name": table, "table_name": table,
"forward_rowid": 1, "forward_rowid": 1,
"backward_rowid": 0, "backward_rowid": 0,
} },
) )
forward_chunk = 1 forward_chunk = 1
@ -221,10 +221,10 @@ class Porter(object):
table, forward_chunk, backward_chunk table, forward_chunk, backward_chunk
) )
else: else:
def delete_all(txn): def delete_all(txn):
txn.execute( txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s", "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
(table,)
) )
txn.execute("TRUNCATE %s CASCADE" % (table,)) txn.execute("TRUNCATE %s CASCADE" % (table,))
@ -232,11 +232,7 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={ values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
forward_chunk = 1 forward_chunk = 1
@ -251,12 +247,16 @@ class Porter(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk, def handle_table(
backward_chunk): self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
logger.info( logger.info(
"Table %s: %i/%i (rows %i-%i) already ported", "Table %s: %i/%i (rows %i-%i) already ported",
table, postgres_size, table_size, table,
backward_chunk+1, forward_chunk-1, postgres_size,
table_size,
backward_chunk + 1,
forward_chunk - 1,
) )
if not table_size: if not table_size:
@ -271,7 +271,9 @@ class Porter(object):
return return
if table in ( if table in (
"user_directory", "user_directory_search", "users_who_share_rooms", "user_directory",
"user_directory_search",
"users_who_share_rooms",
"users_in_pubic_room", "users_in_pubic_room",
): ):
# We don't port these tables, as they're a faff and we can regenreate # We don't port these tables, as they're a faff and we can regenreate
@ -283,37 +285,35 @@ class Porter(object):
# We need to make sure there is a single row, `(X, null), as that is # We need to make sure there is a single row, `(X, null), as that is
# what synapse expects to be there. # what synapse expects to be there.
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table=table, table=table, values={"stream_id": None}
values={"stream_id": None},
) )
self.progress.update(table, table_size) # Mark table as done self.progress.update(table, table_size) # Mark table as done
return return
forward_select = ( forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
% (table,)
) )
backward_select = ( backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
% (table,)
) )
do_forward = [True] do_forward = [True]
do_backward = [True] do_backward = [True]
while True: while True:
def r(txn): def r(txn):
forward_rows = [] forward_rows = []
backward_rows = [] backward_rows = []
if do_forward[0]: if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,)) txn.execute(forward_select, (forward_chunk, self.batch_size))
forward_rows = txn.fetchall() forward_rows = txn.fetchall()
if not forward_rows: if not forward_rows:
do_forward[0] = False do_forward[0] = False
if do_backward[0]: if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,)) txn.execute(backward_select, (backward_chunk, self.batch_size))
backward_rows = txn.fetchall() backward_rows = txn.fetchall()
if not backward_rows: if not backward_rows:
do_backward[0] = False do_backward[0] = False
@ -325,9 +325,7 @@ class Porter(object):
return headers, forward_rows, backward_rows return headers, forward_rows, backward_rows
headers, frows, brows = yield self.sqlite_store.runInteraction( headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
"select", r
)
if frows or brows: if frows or brows:
if frows: if frows:
@ -339,9 +337,7 @@ class Porter(object):
rows = self._convert_rows(table, headers, rows) rows = self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):
self.postgres_store.insert_many_txn( self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn( self.postgres_store._simple_update_one_txn(
txn, txn,
@ -362,8 +358,9 @@ class Porter(object):
return return
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, forward_chunk, def handle_search_table(
backward_chunk): self, postgres_size, table_size, forward_chunk, backward_chunk
):
select = ( select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es" " FROM event_search as es"
@ -373,8 +370,9 @@ class Porter(object):
) )
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (forward_chunk, self.batch_size,)) txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers = [column[0] for column in txn.description]
@ -402,7 +400,9 @@ class Porter(object):
else: else:
rows_dict.append(d) rows_dict.append(d)
txn.executemany(sql, [ txn.executemany(
sql,
[
( (
row["event_id"], row["event_id"],
row["room_id"], row["room_id"],
@ -413,7 +413,8 @@ class Porter(object):
row["stream_ordering"], row["stream_ordering"],
) )
for row in rows_dict for row in rows_dict
]) ],
)
self.postgres_store._simple_update_one_txn( self.postgres_store._simple_update_one_txn(
txn, txn,
@ -437,7 +438,8 @@ class Porter(object):
def setup_db(self, db_config, database_engine): def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect( db_conn = database_engine.module.connect(
**{ **{
k: v for k, v in db_config.get("args", {}).items() k: v
for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_") if not k.startswith("cp_")
} }
) )
@ -450,13 +452,11 @@ class Porter(object):
def run(self): def run(self):
try: try:
sqlite_db_pool = adbapi.ConnectionPool( sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"], self.sqlite_config["name"], **self.sqlite_config["args"]
**self.sqlite_config["args"]
) )
postgres_db_pool = adbapi.ConnectionPool( postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"], self.postgres_config["name"], **self.postgres_config["args"]
**self.postgres_config["args"]
) )
sqlite_engine = create_engine(sqlite_config) sqlite_engine = create_engine(sqlite_config)
@ -465,9 +465,7 @@ class Porter(object):
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine) self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine) self.postgres_store = Store(postgres_db_pool, postgres_engine)
yield self.postgres_store.execute( yield self.postgres_store.execute(postgres_engine.check_database)
postgres_engine.check_database
)
# Step 1. Set up databases. # Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3") self.progress.set_state("Preparing SQLite3")
@ -477,6 +475,7 @@ class Porter(object):
self.setup_db(postgres_config, postgres_engine) self.setup_db(postgres_config, postgres_engine)
self.progress.set_state("Creating port tables") self.progress.set_state("Creating port tables")
def create_port_table(txn): def create_port_table(txn):
txn.execute( txn.execute(
"CREATE TABLE IF NOT EXISTS port_from_sqlite3 (" "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
@ -501,9 +500,7 @@ class Porter(object):
) )
try: try:
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction("alter_table", alter_table)
"alter_table", alter_table
)
except Exception as e: except Exception as e:
pass pass
@ -514,11 +511,7 @@ class Porter(object):
# Step 2. Get tables. # Step 2. Get tables.
self.progress.set_state("Fetching tables") self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol( sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master", table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
keyvalues={
"type": "table",
},
retcol="name",
) )
postgres_tables = yield self.postgres_store._simple_select_onecol( postgres_tables = yield self.postgres_store._simple_select_onecol(
@ -545,18 +538,14 @@ class Porter(object):
# Step 4. Do the copying. # Step 4. Do the copying.
self.progress.set_state("Copying to postgres") self.progress.set_state("Copying to postgres")
yield defer.gatherResults( yield defer.gatherResults(
[ [self.handle_table(*res) for res in setup_res], consumeErrors=True
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
) )
# Step 5. Do final post-processing # Step 5. Do final post-processing
yield self._setup_state_group_id_seq() yield self._setup_state_group_id_seq()
self.progress.done() self.progress.done()
except: except Exception:
global end_error_exec_info global end_error_exec_info
end_error_exec_info = sys.exc_info() end_error_exec_info = sys.exc_info()
logger.exception("") logger.exception("")
@ -566,9 +555,7 @@ class Porter(object):
def _convert_rows(self, table, headers, rows): def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, []) bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [ bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
i for i, h in enumerate(headers) if h in bool_col_names
]
class BadValueException(Exception): class BadValueException(Exception):
pass pass
@ -577,18 +564,21 @@ class Porter(object):
if j in bool_cols: if j in bool_cols:
return bool(col) return bool(col)
elif isinstance(col, string_types) and "\0" in col: elif isinstance(col, string_types) and "\0" in col:
logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col) logger.warn(
raise BadValueException(); "DROPPING ROW: NUL value in table %s col %s: %r",
table,
headers[j],
col,
)
raise BadValueException()
return col return col
outrows = [] outrows = []
for i, row in enumerate(rows): for i, row in enumerate(rows):
try: try:
outrows.append(tuple( outrows.append(
conv(j, col) tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
for j, col in enumerate(row) )
if j > 0
))
except BadValueException: except BadValueException:
pass pass
@ -616,9 +606,7 @@ class Porter(object):
return headers, [r for r in rows if r[ts_ind] < yesterday] return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction( headers, rows = yield self.sqlite_store.runInteraction("select", r)
"select", r,
)
rows = self._convert_rows("sent_transactions", headers, rows) rows = self._convert_rows("sent_transactions", headers, rows)
@ -639,7 +627,7 @@ class Porter(object):
txn.execute( txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?" "SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1", " ORDER BY rowid ASC LIMIT 1",
(yesterday,) (yesterday,),
) )
rows = txn.fetchall() rows = txn.fetchall()
@ -657,21 +645,17 @@ class Porter(object):
"table_name": "sent_transactions", "table_name": "sent_transactions",
"forward_rowid": next_chunk, "forward_rowid": next_chunk,
"backward_rowid": 0, "backward_rowid": 0,
} },
) )
def get_sent_table_size(txn): def get_sent_table_size(txn):
txn.execute( txn.execute(
"SELECT count(*) FROM sent_transactions" "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
" WHERE ts >= ?",
(yesterday,)
) )
size, = txn.fetchone() size, = txn.fetchone()
return int(size) return int(size)
remaining_count = yield self.sqlite_store.execute( remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
get_sent_table_size
)
total_count = remaining_count + inserted_rows total_count = remaining_count + inserted_rows
@ -680,13 +664,11 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = yield self.sqlite_store.execute_sql( frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
forward_chunk,
) )
brows = yield self.sqlite_store.execute_sql( brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
backward_chunk,
) )
defer.returnValue(frows[0][0] + brows[0][0]) defer.returnValue(frows[0][0] + brows[0][0])
@ -694,7 +676,7 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_already_ported_count(self, table): def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql( rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,), "SELECT count(*) FROM %s" % (table,)
) )
defer.returnValue(rows[0][0]) defer.returnValue(rows[0][0])
@ -718,21 +700,20 @@ class Porter(object):
def r(txn): def r(txn):
txn.execute("SELECT MAX(id) FROM state_groups") txn.execute("SELECT MAX(id) FROM state_groups")
next_id = txn.fetchone()[0] + 1 next_id = txn.fetchone()[0] + 1
txn.execute( txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
"ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
(next_id,),
)
return self.postgres_store.runInteraction("setup_state_group_id_seq", r) return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
############################################## ##############################################
###### The following is simply UI stuff ###### # The following is simply UI stuff
############################################## ##############################################
class Progress(object): class Progress(object):
"""Used to report progress of the port """Used to report progress of the port
""" """
def __init__(self): def __init__(self):
self.tables = {} self.tables = {}
@ -758,6 +739,7 @@ class Progress(object):
class CursesProgress(Progress): class CursesProgress(Progress):
"""Reports progress to a curses window """Reports progress to a curses window
""" """
def __init__(self, stdscr): def __init__(self, stdscr):
self.stdscr = stdscr self.stdscr = stdscr
@ -801,7 +783,7 @@ class CursesProgress(Progress):
duration = int(now) - int(self.start_time) duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60) minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,) duration_str = '%02dm %02ds' % (minutes, seconds)
if self.finished: if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,) status = "Time spent: %s (Done!)" % (duration_str,)
@ -814,16 +796,12 @@ class CursesProgress(Progress):
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60) est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else: else:
est_remaining_str = "Unknown" est_remaining_str = "Unknown"
status = ( status = "Time spent: %s (est. remaining: %s)" % (
"Time spent: %s (est. remaining: %s)" duration_str,
% (duration_str, est_remaining_str,) est_remaining_str,
) )
self.stdscr.addstr( self.stdscr.addstr(0, 0, status, curses.A_BOLD)
0, 0,
status,
curses.A_BOLD,
)
max_len = max([len(t) for t in self.tables.keys()]) max_len = max([len(t) for t in self.tables.keys()])
@ -831,9 +809,7 @@ class CursesProgress(Progress):
middle_space = 1 middle_space = 1
items = self.tables.items() items = self.tables.items()
items.sort( items.sort(key=lambda i: (i[1]["perc"], i[0]))
key=lambda i: (i[1]["perc"], i[0]),
)
for i, (table, data) in enumerate(items): for i, (table, data) in enumerate(items):
if i + 2 >= rows: if i + 2 >= rows:
@ -844,9 +820,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr( self.stdscr.addstr(
i + 2, left_margin + max_len - len(table), i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
table,
curses.A_BOLD | color,
) )
size = 20 size = 20
@ -857,15 +831,13 @@ class CursesProgress(Progress):
) )
self.stdscr.addstr( self.stdscr.addstr(
i + 2, left_margin + max_len + middle_space, i + 2,
left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
) )
if self.finished: if self.finished:
self.stdscr.addstr( self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")
rows - 1, 0,
"Press any key to exit...",
)
self.stdscr.refresh() self.stdscr.refresh()
self.last_update = time.time() self.last_update = time.time()
@ -877,29 +849,25 @@ class CursesProgress(Progress):
def set_state(self, state): def set_state(self, state):
self.stdscr.clear() self.stdscr.clear()
self.stdscr.addstr( self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.refresh() self.stdscr.refresh()
class TerminalProgress(Progress): class TerminalProgress(Progress):
"""Just prints progress to the terminal """Just prints progress to the terminal
""" """
def update(self, table, num_done): def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done) super(TerminalProgress, self).update(table, num_done)
data = self.tables[table] data = self.tables[table]
print "%s: %d%% (%d/%d)" % ( print(
table, data["perc"], "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
data["num_done"], data["total"],
) )
def set_state(self, state): def set_state(self, state):
print state + "..." print(state + "...")
############################################## ##############################################
@ -913,21 +881,25 @@ if __name__ == "__main__":
) )
parser.add_argument("-v", action='store_true') parser.add_argument("-v", action='store_true')
parser.add_argument( parser.add_argument(
"--sqlite-database", required=True, "--sqlite-database",
required=True,
help="The snapshot of the SQLite database file. This must not be" help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server" " currently used by a running synapse server",
) )
parser.add_argument( parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True, "--postgres-config",
help="The database config file for the PostgreSQL database" type=argparse.FileType('r'),
required=True,
help="The database config file for the PostgreSQL database",
) )
parser.add_argument( parser.add_argument(
"--curses", action='store_true', "--curses", action='store_true', help="display a curses based progress UI"
help="display a curses based progress UI"
) )
parser.add_argument( parser.add_argument(
"--batch-size", type=int, default=1000, "--batch-size",
type=int,
default=1000,
help="The number of rows to select from the SQLite table each" help="The number of rows to select from the SQLite table each"
" iteration [default=1000]", " iteration [default=1000]",
) )
@ -936,7 +908,7 @@ if __name__ == "__main__":
logging_config = { logging_config = {
"level": logging.DEBUG if args.v else logging.INFO, "level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
} }
if args.curses: if args.curses:

View File

@ -106,10 +106,7 @@ class Config(object):
@classmethod @classmethod
def check_file(cls, file_path, config_name): def check_file(cls, file_path, config_name):
if file_path is None: if file_path is None:
raise ConfigError( raise ConfigError("Missing config for %s." % (config_name,))
"Missing config for %s."
% (config_name,)
)
try: try:
os.stat(file_path) os.stat(file_path)
except OSError as e: except OSError as e:
@ -128,9 +125,7 @@ class Config(object):
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):
raise ConfigError( raise ConfigError("%s is not a directory" % (dir_path,))
"%s is not a directory" % (dir_path,)
)
return dir_path return dir_path
@classmethod @classmethod
@ -156,21 +151,20 @@ class Config(object):
return results return results
def generate_config( def generate_config(
self, self, config_dir_path, server_name, is_generating_file, report_stats=None
config_dir_path,
server_name,
is_generating_file,
report_stats=None,
): ):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
"default_config", "default_config",
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=is_generating_file, is_generating_file=is_generating_file,
report_stats=report_stats, report_stats=report_stats,
)) )
)
config = yaml.load(default_config) config = yaml.load(default_config)
@ -178,15 +172,14 @@ class Config(object):
@classmethod @classmethod
def load_config(cls, description, argv): def load_config(cls, description, argv):
config_parser = argparse.ArgumentParser( config_parser = argparse.ArgumentParser(description=description)
description=description,
)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c",
"--config-path",
action="append", action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and" help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files." " may specify directories containing *.yaml files.",
) )
config_parser.add_argument( config_parser.add_argument(
@ -203,9 +196,7 @@ class Config(object):
obj = cls() obj = cls()
obj.read_config_files( obj.read_config_files(
config_files, config_files, keys_directory=config_args.keys_directory, generate_keys=False
keys_directory=config_args.keys_directory,
generate_keys=False,
) )
return obj return obj
@ -213,38 +204,38 @@ class Config(object):
def load_or_generate_config(cls, description, argv): def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False) config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c",
"--config-path",
action="append", action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and" help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files." " may specify directories containing *.yaml files.",
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-config", "--generate-config",
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name",
) )
config_parser.add_argument( config_parser.add_argument(
"--report-stats", "--report-stats",
action="store", action="store",
help="Whether the generated config reports anonymized usage statistics", help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"] choices=["yes", "no"],
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-keys", "--generate-keys",
action="store_true", action="store_true",
help="Generate any missing key files then exit" help="Generate any missing key files then exit",
) )
config_parser.add_argument( config_parser.add_argument(
"--keys-directory", "--keys-directory",
metavar="DIRECTORY", metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as" help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly" " certs and signing keys should be stored in, unless explicitly"
" specified in the config." " specified in the config.",
) )
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name", help="The server name to generate a config file for"
help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
@ -257,8 +248,8 @@ class Config(object):
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None: if config_args.report_stats is None:
config_parser.error( config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" + "Please specify either --report-stats=yes or --report-stats=no\n\n"
MISSING_REPORT_STATS_SPIEL + MISSING_REPORT_STATS_SPIEL
) )
if not config_files: if not config_files:
config_parser.error( config_parser.error(
@ -287,26 +278,32 @@ class Config(object):
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
report_stats=(config_args.report_stats == "yes"), report_stats=(config_args.report_stats == "yes"),
is_generating_file=True is_generating_file=True,
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_str) config_file.write(config_str)
print(( print(
(
"A config file has been generated in %r for server name" "A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed" " %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it" " certificates. Please review this file and customise it"
" to your needs." " to your needs."
) % (config_path, server_name)) )
% (config_path, server_name)
)
print( print(
"If this server name is incorrect, you will need to" "If this server name is incorrect, you will need to"
" regenerate the SSL certificates" " regenerate the SSL certificates"
) )
return return
else: else:
print(( print(
(
"Config file %r already exists. Generating any missing key" "Config file %r already exists. Generating any missing key"
" files." " files."
) % (config_path,)) )
% (config_path,)
)
generate_keys = True generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -338,8 +335,7 @@ class Config(object):
return obj return obj
def read_config_files(self, config_files, keys_directory=None, def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
generate_keys=False):
if not keys_directory: if not keys_directory:
keys_directory = os.path.dirname(config_files[-1]) keys_directory = os.path.dirname(config_files[-1])
@ -364,8 +360,9 @@ class Config(object):
if "report_stats" not in config: if "report_stats" not in config:
raise ConfigError( raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
MISSING_REPORT_STATS_SPIEL + "\n"
+ MISSING_REPORT_STATS_SPIEL
) )
if generate_keys: if generate_keys:
@ -399,16 +396,16 @@ def find_config_files(search_paths):
for entry in os.listdir(config_path): for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry) entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path): if not os.path.isfile(entry_path):
print ( err = "Found subdirectory in config directory: %r. IGNORING."
"Found subdirectory in config directory: %r. IGNORING." print(err % (entry_path,))
) % (entry_path, )
continue continue
if not entry.endswith(".yaml"): if not entry.endswith(".yaml"):
print ( err = (
"Found file in config directory that does not" "Found file in config directory that does not end in "
" end in '.yaml': %r. IGNORING." "'.yaml': %r. IGNORING."
) % (entry_path, ) )
print(err % (entry_path,))
continue continue
files.append(entry_path) files.append(entry_path)

View File

@ -18,7 +18,7 @@ import threading
import time import time
from six import PY2, iteritems, iterkeys, itervalues from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range from six.moves import builtins, intern, range
from canonicaljson import json from canonicaljson import json
from prometheus_client import Histogram from prometheus_client import Histogram
@ -1233,7 +1233,7 @@ def db_to_json(db_content):
# psycopg2 on Python 2 returns buffer objects, which we need to cast to # psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode # bytes to decode
if PY2 and isinstance(db_content, buffer): if PY2 and isinstance(db_content, builtins.buffer):
db_content = bytes(db_content) db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we # Decode it to a Unicode string before feeding it to json.loads, so we

View File

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -29,7 +29,7 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -27,7 +27,7 @@ from ._base import SQLBaseStore
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it, # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview # despite being deprecated and removed in favor of memoryview
if six.PY2: if six.PY2:
db_binary_type = buffer db_binary_type = six.moves.builtins.buffer
else: else:
db_binary_type = memoryview db_binary_type = memoryview

View File

@ -15,6 +15,8 @@
import logging import logging
from six import integer_types
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from synapse.util import caches from synapse.util import caches
@ -47,7 +49,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos): def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos """Returns True if the entity may have been updated since stream_pos
""" """
assert type(stream_pos) is int or type(stream_pos) is long assert type(stream_pos) in integer_types
if stream_pos < self._earliest_known_stream_pos: if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses() self.metrics.inc_misses()

76
synctl
View File

@ -76,8 +76,7 @@ def start(configfile):
try: try:
subprocess.check_call(args) subprocess.check_call(args)
write("started synapse.app.homeserver(%r)" % write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
(configfile,), colour=GREEN)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
write( write(
"error starting (exit code: %d); see above for logs" % e.returncode, "error starting (exit code: %d); see above for logs" % e.returncode,
@ -86,21 +85,15 @@ def start(configfile):
def start_worker(app, configfile, worker_configfile): def start_worker(app, configfile, worker_configfile):
args = [ args = [sys.executable, "-B", "-m", app, "-c", configfile, "-c", worker_configfile]
sys.executable, "-B",
"-m", app,
"-c", configfile,
"-c", worker_configfile
]
try: try:
subprocess.check_call(args) subprocess.check_call(args)
write("started %s(%r)" % (app, worker_configfile), colour=GREEN) write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
write( write(
"error starting %s(%r) (exit code: %d); see above for logs" % ( "error starting %s(%r) (exit code: %d); see above for logs"
app, worker_configfile, e.returncode, % (app, worker_configfile, e.returncode),
),
colour=RED, colour=RED,
) )
@ -120,9 +113,9 @@ def stop(pidfile, app):
abort("Cannot stop %s: Unknown error" % (app,)) abort("Cannot stop %s: Unknown error" % (app,))
Worker = collections.namedtuple("Worker", [ Worker = collections.namedtuple(
"app", "configfile", "pidfile", "cache_factor", "cache_factors", "Worker", ["app", "configfile", "pidfile", "cache_factor", "cache_factors"]
]) )
def main(): def main():
@ -141,12 +134,11 @@ def main():
help="the homeserver config file, defaults to homeserver.yaml", help="the homeserver config file, defaults to homeserver.yaml",
) )
parser.add_argument( parser.add_argument(
"-w", "--worker", "-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker"
metavar="WORKERCONFIG",
help="start or stop a single worker",
) )
parser.add_argument( parser.add_argument(
"-a", "--all-processes", "-a",
"--all-processes",
metavar="WORKERCONFIGDIR", metavar="WORKERCONFIGDIR",
help="start or stop all the workers in the given directory" help="start or stop all the workers in the given directory"
" and the main synapse process", " and the main synapse process",
@ -155,10 +147,7 @@ def main():
options = parser.parse_args() options = parser.parse_args()
if options.worker and options.all_processes: if options.worker and options.all_processes:
write( write('Cannot use "--worker" with "--all-processes"', stream=sys.stderr)
'Cannot use "--worker" with "--all-processes"',
stream=sys.stderr
)
sys.exit(1) sys.exit(1)
configfile = options.configfile configfile = options.configfile
@ -167,9 +156,7 @@ def main():
write( write(
"No config file found\n" "No config file found\n"
"To generate a config file, run '%s -c %s --generate-config" "To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % ( " --server-name=<server name>'\n" % (" ".join(SYNAPSE), options.configfile),
" ".join(SYNAPSE), options.configfile
),
stream=sys.stderr, stream=sys.stderr,
) )
sys.exit(1) sys.exit(1)
@ -194,8 +181,7 @@ def main():
worker_configfile = options.worker worker_configfile = options.worker
if not os.path.exists(worker_configfile): if not os.path.exists(worker_configfile):
write( write(
"No worker config found at %r" % (worker_configfile,), "No worker config found at %r" % (worker_configfile,), stream=sys.stderr
stream=sys.stderr,
) )
sys.exit(1) sys.exit(1)
worker_configfiles.append(worker_configfile) worker_configfiles.append(worker_configfile)
@ -211,9 +197,9 @@ def main():
stream=sys.stderr, stream=sys.stderr,
) )
sys.exit(1) sys.exit(1)
worker_configfiles.extend(sorted(glob.glob( worker_configfiles.extend(
os.path.join(worker_configdir, "*.yaml") sorted(glob.glob(os.path.join(worker_configdir, "*.yaml")))
))) )
workers = [] workers = []
for worker_configfile in worker_configfiles: for worker_configfile in worker_configfiles:
@ -223,14 +209,12 @@ def main():
if worker_app == "synapse.app.homeserver": if worker_app == "synapse.app.homeserver":
# We need to special case all of this to pick up options that may # We need to special case all of this to pick up options that may
# be set in the main config file or in this worker config file. # be set in the main config file or in this worker config file.
worker_pidfile = ( worker_pidfile = worker_config.get("pid_file") or pidfile
worker_config.get("pid_file") worker_cache_factor = (
or pidfile worker_config.get("synctl_cache_factor") or cache_factor
) )
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
worker_cache_factors = ( worker_cache_factors = (
worker_config.get("synctl_cache_factors") worker_config.get("synctl_cache_factors") or cache_factors
or cache_factors
) )
daemonize = worker_config.get("daemonize") or config.get("daemonize") daemonize = worker_config.get("daemonize") or config.get("daemonize")
assert daemonize, "Main process must have daemonize set to true" assert daemonize, "Main process must have daemonize set to true"
@ -239,19 +223,27 @@ def main():
for key in worker_config: for key in worker_config:
if key == "worker_app": # But we allow worker_app if key == "worker_app": # But we allow worker_app
continue continue
assert not key.startswith("worker_"), \ assert not key.startswith(
"Main process cannot use worker_* config" "worker_"
), "Main process cannot use worker_* config"
else: else:
worker_pidfile = worker_config["worker_pid_file"] worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"] worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % ( assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize") worker_configfile,
"worker_daemonize",
)
worker_cache_factor = worker_config.get("synctl_cache_factor") worker_cache_factor = worker_config.get("synctl_cache_factor")
worker_cache_factors = worker_config.get("synctl_cache_factors", {}) worker_cache_factors = worker_config.get("synctl_cache_factors", {})
workers.append(Worker( workers.append(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor, Worker(
worker_app,
worker_configfile,
worker_pidfile,
worker_cache_factor,
worker_cache_factors, worker_cache_factors,
)) )
)
action = options.action action = options.action

View File

@ -108,10 +108,10 @@ commands =
[testenv:pep8] [testenv:pep8]
skip_install = True skip_install = True
basepython = python2.7 basepython = python3.6
deps = deps =
flake8 flake8
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}" commands = /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
[testenv:check_isort] [testenv:check_isort]
skip_install = True skip_install = True