Finish up upgrade script

This commit is contained in:
Erik Johnston 2014-12-15 16:14:34 +00:00
parent 65cdf4e724
commit b75adaedca
2 changed files with 95 additions and 31 deletions

View File

@ -4,25 +4,42 @@ from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64 from syutil.base64util import encode_base64, decode_base64
from synapse.events import FrozenEvent from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import verify_signed_json, SignatureVerifyException from syutil.crypto.jsonsign import (
from syutil.crypto.signing_key import ( verify_signed_json, SignatureVerifyException,
decode_verify_key_bytes, write_signing_keys
) )
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
import dns.resolver import dns.resolver
import hashlib import hashlib
import json import json
import sqlite3 import sqlite3
import sys import syutil
import urllib2 import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
"""
class Store(object): class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"] _get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"] _get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
@ -33,10 +50,9 @@ class Store(object):
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"] cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"] _simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows): def _generate_event_json(self, txn, rows):
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
events = [] events = []
for row in rows: for row in rows:
d = dict(row) d = dict(row)
@ -145,7 +161,9 @@ def get_key(server_name):
keys = json.load(urllib2.urlopen(url, timeout=2)) keys = json.load(urllib2.urlopen(url, timeout=2))
verify_keys = {} verify_keys = {}
for key_id, key_base64 in keys["verify_keys"].items(): for key_id, key_base64 in keys["verify_keys"].items():
verify_key = decode_verify_key_bytes(key_id, decode_base64(key_base64)) verify_key = decode_verify_key_bytes(
key_id, decode_base64(key_base64)
)
verify_signed_json(keys, server_name, verify_key) verify_signed_json(keys, server_name, verify_key)
verify_keys[key_id] = verify_key verify_keys[key_id] = verify_key
print "Got keys for: %s" % (server_name,) print "Got keys for: %s" % (server_name,)
@ -157,18 +175,11 @@ def get_key(server_name):
return {} return {}
def get_events(cursor): def reinsert_events(cursor, server_name, signing_key):
# cursor.execute( cursor.executescript(delta_sql)
# "SELECT * FROM events WHERE event_id = ? ORDER BY rowid DESC",
# ("$14182049031533SMfTT:matrix.org",)
# )
# cursor.execute(
# "SELECT * FROM events ORDER BY rowid DESC LIMIT 10000"
# )
cursor.execute( cursor.execute(
"SELECT * FROM events ORDER BY rowid DESC" "SELECT * FROM events ORDER BY rowid ASC"
) )
rows = store.cursor_to_dict(cursor) rows = store.cursor_to_dict(cursor)
@ -181,19 +192,26 @@ def get_events(cursor):
"sha256": hashlib.sha256, "sha256": hashlib.sha256,
} }
server_keys = {} key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
for event in events: for event in events:
for alg_name in event.hashes: for alg_name in event.hashes:
if check_event_content_hash(event, algorithms[alg_name]): if check_event_content_hash(event, algorithms[alg_name]):
# print "PASS content hash %s" % (alg_name,)
pass pass
else: else:
pass pass
print "FAIL content hash %s %s" % (alg_name, event.event_id, ) print "FAIL content hash %s %s" % (alg_name, event.event_id, )
# print "%s %d" % (event.event_id, event.origin_server_ts)
# print json.dumps(event.get_pdu_json(), indent=4, sort_keys=True)
have_own_correctly_signed = False
for host, sigs in event.signatures.items(): for host, sigs in event.signatures.items():
pruned = prune_event(event) pruned = prune_event(event)
@ -207,17 +225,63 @@ def get_events(cursor):
host, host,
server_keys[host][key_id] server_keys[host][key_id]
) )
except SignatureVerifyException as e:
# print e
print "FAIL signature check %s %s" % (key_id, event.event_id)
# print json.dumps(pruned.get_pdu_json(), indent=4, sort_keys=True)
def main(): if host == server_name:
conn = sqlite3.connect(sys.argv[1]) have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor() cursor = conn.cursor()
get_events(cursor) reinsert_events(cursor, server_name, signing_key)
conn.commit() conn.commit()
if __name__ == "__main__": if __name__ == "__main__":
main() parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])