Fix unicode database support

This commit is contained in:
Erik Johnston 2015-03-25 17:15:20 +00:00
parent 0ba393924a
commit 0e8f5095c7
15 changed files with 88 additions and 44 deletions

View File

@ -110,14 +110,12 @@ class SynapseHomeServer(HomeServer):
return None return None
def build_db_pool(self): def build_db_pool(self):
name = self.db_config.pop("name", None) name = self.db_config["name"]
if name == "MySQLdb":
return adbapi.ConnectionPool(
name,
**self.db_config
)
raise RuntimeError("Unsupported database type") return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
)
def create_resource_tree(self, redirect_root_to_web_client): def create_resource_tree(self, redirect_root_to_web_client):
"""Create the resource tree for this Home Server. """Create the resource tree for this Home Server.
@ -323,7 +321,7 @@ def change_resource_limit(soft_file_no):
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard)) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no) logger.info("Set file limit to: %d", soft_file_no)
except (ValueError, resource.error) as e: except ( ValueError, resource.error) as e:
logger.warn("Failed to set file limit: %s", e) logger.warn("Failed to set file limit: %s", e)
@ -363,20 +361,33 @@ def setup(config_options):
if config.database_config: if config.database_config:
with open(config.database_config, 'r') as f: with open(config.database_config, 'r') as f:
db_config = yaml.safe_load(f) db_config = yaml.safe_load(f)
name = db_config.get("name", None)
if name == "MySQLdb":
db_config.update({
"sql_mode": "TRADITIONAL",
"charset": "utf8",
"use_unicode": True,
})
else: else:
db_config = { db_config = {
"name": "sqlite3", "name": "sqlite3",
"database": config.database_path, "database": config.database_path,
} }
db_config = {
k: v for k, v in db_config.items()
if not k.startswith("cp_")
}
name = db_config.get("name", None)
if name in ["MySQLdb", "mysql.connector"]:
db_config.setdefault("args", {}).update({
"sql_mode": "TRADITIONAL",
"charset": "utf8",
"use_unicode": True,
})
elif name == "sqlite3":
db_config.setdefault("args", {}).update({
"cp_min": 1,
"cp_max": 1,
"cp_openfun": prepare_database,
})
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port, domain_with_port=domain_with_port,
@ -401,8 +412,8 @@ def setup(config_options):
# with sqlite3.connect(db_name) as db_conn: # with sqlite3.connect(db_name) as db_conn:
# prepare_sqlite3_database(db_conn) # prepare_sqlite3_database(db_conn)
# prepare_database(db_conn) # prepare_database(db_conn)
import MySQLdb import mysql.connector
db_conn = MySQLdb.connect(**db_config) db_conn = mysql.connector.connect(**db_config.get("args", {}))
prepare_database(db_conn) prepare_database(db_conn)
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(

View File

@ -57,7 +57,7 @@ class LoginHandler(BaseHandler):
logger.warn("Attempted to login as %s but they do not exist", user) logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"] stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash): if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it. # generate an access token and store it.
token = self.reg_handler._generate_token(user) token = self.reg_handler._generate_token(user)

View File

@ -19,9 +19,13 @@ from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_pattern
from synapse.types import UserID from synapse.types import UserID
import logging
import simplejson as json import simplejson as json
logger = logging.getLogger(__name__)
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
@ -47,7 +51,8 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, auth_user, new_name) user, auth_user, new_name
)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -410,10 +410,14 @@ def executescript(txn, schema_path):
def _get_or_create_schema_state(txn): def _get_or_create_schema_state(txn):
schema_path = os.path.join( try:
dir_path, "schema", "schema_version.sql", # Bluntly try creating the schema_version tables.
) schema_path = os.path.join(
executescript(txn, schema_path) dir_path, "schema", "schema_version.sql",
)
executescript(txn, schema_path)
except:
pass
txn.execute("SELECT version, upgraded FROM schema_version") txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone() row = txn.fetchone()

View File

@ -755,6 +755,8 @@ class SQLBaseStore(object):
return None return None
internal_metadata, js, redacted, rejected_reason = res internal_metadata, js, redacted, rejected_reason = res
js = js.decode("utf8")
internal_metadata = internal_metadata.decode("utf8")
start_time = update_counter("select_event", start_time) start_time = update_counter("select_event", start_time)
@ -779,9 +781,11 @@ class SQLBaseStore(object):
sql_getevents_timer.inc_by(curr_time - last_time, desc) sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time return curr_time
logger.debug("Got js: %r", js)
d = json.loads(js) d = json.loads(js)
start_time = update_counter("decode_json", start_time) start_time = update_counter("decode_json", start_time)
logger.debug("Got internal_metadata: %r", internal_metadata)
internal_metadata = json.loads(internal_metadata) internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time) start_time = update_counter("decode_internal", start_time)

View File

@ -294,15 +294,17 @@ class EventsStore(SQLBaseStore):
) )
if is_new_state and not context.rejected: if is_new_state and not context.rejected:
self._simple_insert_txn( self._simple_upsert_txn(
txn, txn,
"current_state_events", "current_state_events",
{ keyvalues={
"event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"type": event.type, "type": event.type,
"state_key": event.state_key, "state_key": event.state_key,
}, },
values={
"event_id": event.event_id,
}
) )
for e_id, h in event.prev_state: for e_id, h in event.prev_state:

View File

@ -64,7 +64,7 @@ class KeyStore(SQLBaseStore):
"fingerprint": fingerprint, "fingerprint": fingerprint,
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"tls_certificate": buffer(tls_certificate_bytes), "tls_certificate": tls_certificate_bytes,
}, },
) )
@ -113,6 +113,6 @@ class KeyStore(SQLBaseStore):
"key_id": "%s:%s" % (verify_key.alg, verify_key.version), "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
"from_server": from_server, "from_server": from_server,
"ts_added_ms": time_now_ms, "ts_added_ms": time_now_ms,
"verify_key": buffer(verify_key.encode()), "verify_key": verify_key.encode(),
}, },
) )

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -24,19 +26,25 @@ class ProfileStore(SQLBaseStore):
desc="create_profile", desc="create_profile",
) )
@defer.inlineCallbacks
def get_profile_displayname(self, user_localpart): def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol( name = yield self._simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcol="displayname", retcol="displayname",
desc="get_profile_displayname", desc="get_profile_displayname",
) )
if name:
name = name.decode("utf8")
defer.returnValue(name)
def set_profile_displayname(self, user_localpart, new_displayname): def set_profile_displayname(self, user_localpart, new_displayname):
return self._simple_update_one( return self._simple_update_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname}, updatevalues={"displayname": new_displayname.encode("utf8")},
desc="set_profile_displayname", desc="set_profile_displayname",
) )

View File

@ -81,13 +81,23 @@ class RegistrationStore(SQLBaseStore):
txn.execute("INSERT INTO access_tokens(user_id, token) " + txn.execute("INSERT INTO access_tokens(user_id, token) " +
"VALUES (?,?)", [user_id, token]) "VALUES (?,?)", [user_id, token])
@defer.inlineCallbacks
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
query = ("SELECT users.name, users.password_hash FROM users" user_info = yield self._simple_select_one(
" WHERE users.name = ?") table="users",
return self._execute( keyvalues={
"get_user_by_id", self.cursor_to_dict, query, user_id "name": user_id,
},
retcols=["name", "password_hash"],
allow_none=True,
) )
if user_info:
user_info["password_hash"] = user_info["password_hash"].decode("utf8")
defer.returnValue(user_info)
@cached() @cached()
# TODO(paul): Currently there's no code to invalidate this cache. That # TODO(paul): Currently there's no code to invalidate this cache. That
# means if/when we ever add internal ways to invalidate access tokens or # means if/when we ever add internal ways to invalidate access tokens or

View File

@ -72,6 +72,7 @@ class RoomStore(SQLBaseStore):
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcols=RoomsTable.fields, retcols=RoomsTable.fields,
desc="get_room", desc="get_room",
allow_none=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -65,4 +65,4 @@ CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
) ENGINE = INNODB; ) ENGINE = INNODB;
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON local_media_repository_thumbnails (media_id); ON remote_media_cache_thumbnails (media_id);

View File

@ -14,7 +14,7 @@
*/ */
CREATE TABLE IF NOT EXISTS profiles( CREATE TABLE IF NOT EXISTS profiles(
user_id VARCHAR(255) NOT NULL, user_id VARCHAR(255) NOT NULL,
displayname VARCHAR(255), displayname VARBINARY(255),
avatar_url VARCHAR(255), avatar_url VARCHAR(255),
UNIQUE(user_id) UNIQUE(user_id)
) ENGINE = INNODB; ) ENGINE = INNODB;

View File

@ -38,7 +38,6 @@ CREATE TABLE IF NOT EXISTS sent_transactions(
) ENGINE = INNODB; ) ENGINE = INNODB;
CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination); CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(destination);
CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id); CREATE INDEX IF NOT EXISTS sent_transaction_txn_id ON sent_transactions(transaction_id);
-- So that we can do an efficient look up of all transactions that have yet to be successfully -- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent. -- sent.

View File

@ -54,7 +54,7 @@ class SignatureStore(SQLBaseStore):
{ {
"event_id": event_id, "event_id": event_id,
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": hash_bytes,
}, },
) )
@ -99,7 +99,7 @@ class SignatureStore(SQLBaseStore):
" WHERE event_id = ?" " WHERE event_id = ?"
) )
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return dict(txn.fetchall()) return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes): hash_bytes):
@ -116,7 +116,7 @@ class SignatureStore(SQLBaseStore):
{ {
"event_id": event_id, "event_id": event_id,
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": hash_bytes,
}, },
) )
@ -160,7 +160,7 @@ class SignatureStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"signature_name": signature_name, "signature_name": signature_name,
"key_id": key_id, "key_id": key_id,
"signature": buffer(signature_bytes), "signature": signature_bytes,
}, },
) )
@ -193,6 +193,6 @@ class SignatureStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"prev_event_id": prev_event_id, "prev_event_id": prev_event_id,
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": hash_bytes,
}, },
) )

View File

@ -282,7 +282,7 @@ class TransactionStore(SQLBaseStore):
query = ( query = (
"UPDATE destinations" "UPDATE destinations"
" SET retry_last_ts = ?, retry_interval = ?" " SET retry_last_ts = ?, retry_interval = ?"
" WHERE destinations = ?" " WHERE destination = ?"
) )
txn.execute( txn.execute(