Reference Matrix Home Server

This commit is contained in:
matrix.org 2014-08-12 15:10:52 +01:00
commit 4f475c7697
217 changed files with 48447 additions and 0 deletions

117
synapse/storage/__init__.py Normal file
View file

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events.room import (
RoomMemberEvent, MessageEvent, RoomTopicEvent, FeedbackEvent,
RoomConfigEvent
)
from .directory import DirectoryStore
from .feedback import FeedbackStore
from .message import MessageStore
from .presence import PresenceStore
from .profile import ProfileStore
from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .roomdata import RoomDataStore
from .stream import StreamStore
from .pdu import StatePduStore, PduStore
from .transactions import TransactionStore
import json
import os
class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
self.event_factory = hs.get_event_factory()
self.hs = hs
def persist_event(self, event):
if event.type == MessageEvent.TYPE:
return self.store_message(
user_id=event.user_id,
room_id=event.room_id,
msg_id=event.msg_id,
content=json.dumps(event.content)
)
elif event.type == RoomMemberEvent.TYPE:
return self.store_room_member(
user_id=event.target_user_id,
sender=event.user_id,
room_id=event.room_id,
content=event.content,
membership=event.content["membership"]
)
elif event.type == FeedbackEvent.TYPE:
return self.store_feedback(
room_id=event.room_id,
msg_id=event.msg_id,
msg_sender_id=event.msg_sender_id,
fb_sender_id=event.user_id,
fb_type=event.feedback_type,
content=json.dumps(event.content)
)
elif event.type == RoomTopicEvent.TYPE:
return self.store_room_data(
room_id=event.room_id,
etype=event.type,
state_key=event.state_key,
content=json.dumps(event.content)
)
elif event.type == RoomConfigEvent.TYPE:
if "visibility" in event.content:
visibility = event.content["visibility"]
return self.store_room_config(
room_id=event.room_id,
visibility=visibility
)
else:
raise NotImplementedError(
"Don't know how to persist type=%s" % event.type
)
def schema_path(schema):
""" Get a filesystem path for the named database schema
Args:
schema: Name of the database schema.
Returns:
A filesystem path pointing at a ".sql" file.
"""
dir_path = os.path.dirname(__file__)
schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
return schemaPath
def read_schema(schema):
""" Read the named database schema.
Args:
schema: Name of the datbase schema.
Returns:
A string containing the database schema.
"""
with open(schema_path(schema)) as schema_file:
return schema_file.read()

405
synapse/storage/_base.py Normal file
View file

@ -0,0 +1,405 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
import collections
logger = logging.getLogger(__name__)
class SQLBaseStore(object):
def __init__(self, hs):
self._db_pool = hs.get_db_pool()
def cursor_to_dict(self, cursor):
"""Converts a SQL cursor into an list of dicts.
Args:
cursor : The DBAPI cursor which has executed a query.
Returns:
A list of dicts where the key is the column header.
"""
col_headers = list(column[0] for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor.fetchall()
)
return results
def _execute(self, decoder, query, *args):
"""Runs a single query for a result set.
Args:
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""
logger.debug(
"[SQL] %s Args=%s Func=%s", query, args, decoder.__name__
)
def interaction(txn):
cursor = txn.execute(query, args)
return decoder(cursor)
return self._db_pool.runInteraction(interaction)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
def _simple_insert(self, table, values):
"""Executes an INSERT query on the named table.
Args:
table : string giving the table name
values : dict of new column names and values for them
"""
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in values),
", ".join("?" for k in values)
)
def func(txn):
txn.execute(sql, values.values())
return txn.lastrowid
return self._db_pool.runInteraction(func)
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
retcols : list of strings giving the names of the columns to return
allow_none : If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self._simple_selectupdate_one(
table, keyvalues, retcols=retcols, allow_none=allow_none
)
@defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it."
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
ret = yield self._simple_select_one(
table=table,
keyvalues=keyvalues,
retcols=[retcol],
allow_none=allow_none
)
if ret:
defer.returnValue(ret[retcol])
else:
defer.returnValue(None)
@defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
Args:
table (str): table name
keyvalues (dict): column names and values to select the rows with
retcol (str): column whos value we wish to retrieve.
Returns:
Deferred: Results in a list
"""
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
"retcol": retcol,
"table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
}
def func(txn):
txn.execute(sql, keyvalues.values())
return txn.fetchall()
res = yield self._db_pool.runInteraction(func)
defer.returnValue([r[0] for r in res])
def _simple_select_list(self, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
def func(txn):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
return self._db_pool.runInteraction(func)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
updatevalues : dict giving column names and values to update
retcols : optional list of column names to return
If present, retcols gives a list of column names on which to perform
a SELECT statement *before* performing the UPDATE statement. The values
of these will be returned in a dict.
These are performed within the same transaction, allowing an atomic
get-and-set. This can be used to implement compare-and-set by putting
the update column in the 'keyvalues' dict as well.
"""
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
retcols=retcols)
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False):
""" Combined SELECT then UPDATE."""
if retcols:
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
if updatevalues:
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k) for k in updatevalues),
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
def func(txn):
ret = None
if retcols:
txn.execute(select_sql, keyvalues.values())
row = txn.fetchone()
if not row:
if allow_none:
return None
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
ret = dict(zip(retcols, row))
if updatevalues:
txn.execute(
update_sql,
updatevalues.values() + keyvalues.values()
)
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
return self._db_pool.runInteraction(func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
)
def func(txn):
txn.execute(sql, keyvalues.values())
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
return self._db_pool.runInteraction(func)
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
max value for the column "id".
Args:
table : string giving the table name
"""
sql = "SELECT MAX(id) AS id FROM %s" % table
def func(txn):
txn.execute(sql)
max_id = self.cursor_to_dict(txn)[0]["id"]
if max_id is None:
return 0
return max_id
return self._db_pool.runInteraction(func)
class Table(object):
""" A base class used to store information about a particular table.
"""
table_name = None
""" str: The name of the table """
fields = None
""" list: The field names """
EntryType = None
""" Type: A tuple type used to decode the results """
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
_insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
"""
Args:
where_clause (str): The WHERE clause to use.
Returns:
str: An SQL statement to select rows from the table with the given
WHERE clause.
"""
if where_clause:
return cls._select_where_clause % (
", ".join(cls.fields),
cls.table_name,
where_clause
)
else:
return cls._select_clause % (
", ".join(cls.fields),
cls.table_name,
)
@classmethod
def insert_statement(cls):
return cls._insert_clause % (
cls.table_name,
", ".join(cls.fields),
", ".join(["?"] * len(cls.fields)),
)
@classmethod
def decode_single_result(cls, results):
""" Given an iterable of tuples, return a single instance of
`EntryType` or None if the iterable is empty
Args:
results (list): The results list to convert to `EntryType`
Returns:
EntryType: An instance of `EntryType`
"""
results = list(results)
if results:
return cls.EntryType(*results[0])
else:
return None
@classmethod
def decode_results(cls, results):
""" Given an iterable of tuples, return a list of `EntryType`
Args:
results (list): The results list to convert to `EntryType`
Returns:
list: A list of `EntryType`
"""
return [cls.EntryType(*row) for row in results]
@classmethod
def get_fields_string(cls, prefix=None):
if prefix:
to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
else:
to_join = cls.fields
return ", ".join(to_join)
class JoinHelper(object):
""" Used to help do joins on tables by looking at the tables' fields and
creating a list of unique fields to use with SELECTs and a namedtuple
to dump the results into.
Attributes:
taples (list): List of `Table` classes
EntryType (type)
"""
def __init__(self, *tables):
self.tables = tables
res = []
for table in self.tables:
res += [f for f in table.fields if f not in res]
self.EntryType = collections.namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
statements with the given prefixes applied to each.
For example::
JoinHelper(PdusTable, StateTable).get_fields(
PdusTable="pdus",
StateTable="state"
)
"""
res = []
for field in self.EntryType._fields:
for table in self.tables:
if field in table.fields:
res.append("%s.%s" % (prefixes[table.__name__], field))
break
return ", ".join(res)
def decode_results(self, rows):
return [self.EntryType(*row) for row in rows]

View file

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
from collections import namedtuple
RoomAliasMapping = namedtuple(
"RoomAliasMapping",
("room_id", "room_alias", "servers",)
)
class DirectoryStore(SQLBaseStore):
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
""" Get's the room_id and server list for a given room_alias
Args:
room_alias (RoomAlias)
Returns:
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
room_id = yield self._simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
allow_none=True,
)
if not room_id:
defer.returnValue(None)
return
servers = yield self._simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
)
if not servers:
defer.returnValue(None)
return
defer.returnValue(
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers):
""" Creates an associatin between a room alias and room_id/servers
Args:
room_alias (RoomAlias)
room_id (str)
servers (list)
Returns:
Deferred
"""
yield self._simple_insert(
"room_aliases",
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
},
)
for server in servers:
# TODO(erikj): Fix this to bulk insert
yield self._simple_insert(
"room_alias_servers",
{
"room_alias": room_alias.to_string(),
"server": server,
}
)

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table
from synapse.api.events.room import FeedbackEvent
import collections
import json
class FeedbackStore(SQLBaseStore):
def store_feedback(self, room_id, msg_id, msg_sender_id,
fb_sender_id, fb_type, content):
return self._simple_insert(FeedbackTable.table_name, dict(
room_id=room_id,
msg_id=msg_id,
msg_sender_id=msg_sender_id,
fb_sender_id=fb_sender_id,
fb_type=fb_type,
content=content,
))
def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None,
fb_sender_id=None, fb_type=None):
query = FeedbackTable.select_statement(
"msg_sender_id = ? AND room_id = ? AND msg_id = ? " +
"AND fb_sender_id = ? AND feedback_type = ? " +
"ORDER BY id DESC LIMIT 1")
return self._execute(
FeedbackTable.decode_single_result,
query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type,
)
def get_max_feedback_id(self):
return self._simple_max_id(FeedbackTable.table_name)
class FeedbackTable(Table):
table_name = "feedback"
fields = [
"id",
"content",
"feedback_type",
"fb_sender_id",
"msg_id",
"room_id",
"msg_sender_id"
]
class EntryType(collections.namedtuple("FeedbackEntry", fields)):
def as_event(self, event_factory):
return event_factory.create_event(
etype=FeedbackEvent.TYPE,
room_id=self.room_id,
msg_id=self.msg_id,
msg_sender_id=self.msg_sender_id,
user_id=self.fb_sender_id,
feedback_type=self.feedback_type,
content=json.loads(self.content),
)

View file

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table
from synapse.api.events.room import MessageEvent
import collections
import json
class MessageStore(SQLBaseStore):
def get_message(self, user_id, room_id, msg_id):
"""Get a message from the store.
Args:
user_id (str): The ID of the user who sent the message.
room_id (str): The room the message was sent in.
msg_id (str): The unique ID for this user/room combo.
"""
query = MessagesTable.select_statement(
"user_id = ? AND room_id = ? AND msg_id = ? " +
"ORDER BY id DESC LIMIT 1")
return self._execute(
MessagesTable.decode_single_result,
query, user_id, room_id, msg_id,
)
def store_message(self, user_id, room_id, msg_id, content):
"""Store a message in the store.
Args:
user_id (str): The ID of the user who sent the message.
room_id (str): The room the message was sent in.
msg_id (str): The unique ID for this user/room combo.
content (str): The content of the message (JSON)
"""
return self._simple_insert(MessagesTable.table_name, dict(
user_id=user_id,
room_id=room_id,
msg_id=msg_id,
content=content,
))
def get_max_message_id(self):
return self._simple_max_id(MessagesTable.table_name)
class MessagesTable(Table):
table_name = "messages"
fields = [
"id",
"user_id",
"room_id",
"msg_id",
"content"
]
class EntryType(collections.namedtuple("MessageEntry", fields)):
def as_event(self, event_factory):
return event_factory.create_event(
etype=MessageEvent.TYPE,
room_id=self.room_id,
user_id=self.user_id,
msg_id=self.msg_id,
content=json.loads(self.content),
)

993
synapse/storage/pdu.py Normal file
View file

@ -0,0 +1,993 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table, JoinHelper
from synapse.util.logutils import log_function
from collections import namedtuple
import logging
logger = logging.getLogger(__name__)
class PduStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def get_pdu(self, pdu_id, origin):
"""Given a pdu_id and origin, get a PDU.
Args:
txn
pdu_id (str)
origin (str)
Returns:
PduTuple: If the pdu does not exist in the database, returns None
"""
return self._db_pool.runInteraction(
self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
return res[0] if res else None
def _get_pdu_tuples(self, txn, pdu_id_tuples):
results = []
for pdu_id, origin in pdu_id_tuples:
txn.execute(
PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
(pdu_id, origin)
)
edges = [
(r.prev_pdu_id, r.prev_origin)
for r in PduEdgesTable.decode_results(txn.fetchall())
]
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
txn.execute(query, (pdu_id, origin))
row = txn.fetchone()
if row:
results.append(PduTuple(PduEntry(*row), edges))
return results
def get_current_state_for_context(self, context):
"""Get a list of PDUs that represent the current state for a given
context
Args:
context (str)
Returns:
list: A list of PduTuples
"""
return self._db_pool.runInteraction(
self._get_current_state_for_context,
context
)
def _get_current_state_for_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s WHERE context = ?"
% CurrentStateTable.table_name
)
logger.debug("get_current_state %s, Args=%s", query, context)
txn.execute(query, (context,))
res = txn.fetchall()
logger.debug("get_current_state %d results", len(res))
return self._get_pdu_tuples(txn, res)
def persist_pdu(self, prev_pdus, **cols):
"""Inserts a (non-state) PDU into the database.
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable.
"""
return self._db_pool.runInteraction(
self._persist_pdu, prev_pdus, cols
)
def _persist_pdu(self, txn, prev_pdus, cols):
entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
txn.execute(PdusTable.insert_statement(), entry)
self._handle_prev_pdus(
txn, entry.outlier, entry.pdu_id, entry.origin,
prev_pdus, entry.context
)
def mark_pdu_as_processed(self, pdu_id, pdu_origin):
"""Mark a received PDU as processed.
Args:
txn
pdu_id (str)
pdu_origin (str)
"""
return self._db_pool.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin
)
def _mark_as_processed(self, txn, pdu_id, pdu_origin):
txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self._db_pool.runInteraction(
self._get_all_pdus_from_context, context,
)
def _get_all_pdus_from_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s "
"WHERE context = ?"
) % PdusTable.table_name
txn.execute(query, (context,))
return self._get_pdu_tuples(txn, txn.fetchall())
def get_pagination(self, context, pdu_list, limit):
"""Get a list of Pdus for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
Args:
txn
context (str)
pdu_list (list)
limit (int)
Return:
list: A list of PduTuples
"""
return self._db_pool.runInteraction(
self._get_paginate, context, pdu_list, limit
)
def _get_paginate(self, txn, context, pdu_list, limit):
logger.debug(
"paginate: %s, %s, %s",
context, repr(pdu_list), limit
)
# We seed the pdu_results with the things from the pdu_list.
pdu_results = pdu_list
front = pdu_list
query = (
"SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
"WHERE context = ? AND pdu_id = ? AND origin = ? "
"LIMIT ?"
) % {
"edges_table": PduEdgesTable.table_name,
}
# We iterate through all pdu_ids in `front` to select their previous
# pdus. These are dumped in `new_front`. We continue until we reach the
# limit *or* new_front is empty (i.e., we've run out of things to
# select
while front and len(pdu_results) < limit:
new_front = []
for pdu_id, origin in front:
logger.debug(
"_paginate_interaction: i=%s, o=%s",
pdu_id, origin
)
txn.execute(
query,
(context, pdu_id, origin, limit - len(pdu_results))
)
for row in txn.fetchall():
logger.debug(
"_paginate_interaction: got i=%s, o=%s",
*row
)
new_front.append(row)
front = new_front
pdu_results += new_front
# We also want to update the `prev_pdus` attributes before returning.
return self._get_pdu_tuples(txn, pdu_results)
def get_min_depth_for_context(self, context):
"""Get the current minimum depth for a context
Args:
txn
context (str)
"""
return self._db_pool.runInteraction(
self._get_min_depth_for_context, context
)
def _get_min_depth_for_context(self, txn, context):
return self._get_min_depth_interaction(txn, context)
def _get_min_depth_interaction(self, txn, context):
txn.execute(
"SELECT min_depth FROM %s WHERE context = ?"
% ContextDepthTable.table_name,
(context,)
)
row = txn.fetchone()
return row[0] if row else None
def update_min_depth_for_context(self, context, depth):
"""Update the minimum `depth` of the given context, which is the line
where we stop paginating backwards on.
Args:
context (str)
depth (int)
"""
return self._db_pool.runInteraction(
self._update_min_depth_for_context, context, depth
)
def _update_min_depth_for_context(self, txn, context, depth):
min_depth = self._get_min_depth_interaction(txn, context)
do_insert = depth < min_depth if min_depth else True
if do_insert:
txn.execute(
"INSERT OR REPLACE INTO %s (context, min_depth) "
"VALUES (?,?)" % ContextDepthTable.table_name,
(context, depth)
)
def get_latest_pdus_in_context(self, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
key
Args:
txn
context
"""
return self._db_pool.runInteraction(
self._get_latest_pdus_in_context, context
)
def _get_latest_pdus_in_context(self, txn, context):
query = (
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
"AND f.origin = p.origin "
"WHERE f.context = ?"
) % {
"pdus": PdusTable.table_name,
"forward": PduForwardExtremitiesTable.table_name,
}
logger.debug("get_prev query: %s", query)
txn.execute(
query,
(context, )
)
results = txn.fetchall()
return [(row[0], row[1], row[2]) for row in results]
def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we paginated beyond yet (and haven't seen).
This list is used when we want to paginate backwards and is the list we
send to the remote server.
Args:
txn
context (str)
Returns:
list: A list of PduIdTuple.
"""
return self._db_pool.runInteraction(
self._get_oldest_pdus_in_context, context
)
def _get_oldest_pdus_in_context(self, txn, context):
txn.execute(
"SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
% {"back": PduBackwardExtremitiesTable.table_name, },
(context,)
)
return [PduIdTuple(i, o) for i, o in txn.fetchall()]
def is_pdu_new(self, pdu_id, origin, context, depth):
"""For a given Pdu, try and figure out if it's 'new', i.e., if it's
not something we got randomly from the past, for example when we
request the current state of the room that will probably return a bunch
of pdus from before we joined.
Args:
txn
pdu_id (str)
origin (str)
context (str)
depth (int)
Returns:
bool
"""
return self._db_pool.runInteraction(
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
context=context,
depth=depth
)
def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
# If depth > min depth in back table, then we classify it as new.
# OR if there is nothing in the back table, then it kinda needs to
# be a new thing.
query = (
"SELECT min(p.depth) FROM %(edges)s as e "
"INNER JOIN %(back)s as b "
"ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
"INNER JOIN %(pdus)s as p "
"ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
"WHERE p.context = ?"
) % {
"pdus": PdusTable.table_name,
"edges": PduEdgesTable.table_name,
"back": PduBackwardExtremitiesTable.table_name,
}
txn.execute(query, (context,))
min_depth, = txn.fetchone()
if not min_depth or depth > int(min_depth):
logger.debug(
"is_new true: id=%s, o=%s, d=%s min_depth=%s",
pdu_id, origin, depth, min_depth
)
return True
# If this pdu is in the forwards table, then it also is a new one
query = (
"SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
) % {
"forward": PduForwardExtremitiesTable.table_name,
}
txn.execute(query, (pdu_id, origin))
# Did we get anything?
if txn.fetchall():
logger.debug(
"is_new true: id=%s, o=%s, d=%s was forward",
pdu_id, origin, depth
)
return True
logger.debug(
"is_new false: id=%s, o=%s, d=%s",
pdu_id, origin, depth
)
# FINE THEN. It's probably old.
return False
@staticmethod
@log_function
def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
context):
txn.executemany(
PduEdgesTable.insert_statement(),
[(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
)
# Update the extremities table if this is not an outlier.
if not outlier:
# First, we delete the new one from the forwards extremities table.
query = (
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
txn.executemany(query, prev_pdus)
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
query = (
"INSERT INTO %(table)s (pdu_id, origin, context) "
"SELECT ?, ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM %(pdu_edges)s WHERE "
"prev_pdu_id = ? AND prev_origin = ?"
")"
) % {
"table": PduForwardExtremitiesTable.table_name,
"pdu_edges": PduEdgesTable.table_name
}
logger.debug("query: %s", query)
txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
# Insert all the prev_pdus as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
[(i, o, context) for i, o in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
# reference pdus that we have already seen
query = (
"DELETE FROM %(pdu_back)s WHERE EXISTS ("
"SELECT 1 FROM %(pdus)s AS pdus "
"WHERE "
"%(pdu_back)s.pdu_id = pdus.pdu_id "
"AND %(pdu_back)s.origin = pdus.origin "
"AND not pdus.outlier "
")"
) % {
"pdu_back": PduBackwardExtremitiesTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(query)
class StatePduStore(SQLBaseStore):
"""A collection of queries for handling state PDUs.
"""
def persist_state(self, prev_pdus, **cols):
"""Inserts a state PDU into the database
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable and StatePdusTable
"""
return self._db_pool.runInteraction(
self._persist_state, prev_pdus, cols
)
def _persist_state(self, txn, prev_pdus, cols):
pdu_entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
state_entry = StatePdusTable.EntryType(
**{k: cols.get(k, None) for k in StatePdusTable.fields}
)
logger.debug("Inserting pdu: %s", repr(pdu_entry))
logger.debug("Inserting state: %s", repr(state_entry))
txn.execute(PdusTable.insert_statement(), pdu_entry)
txn.execute(StatePdusTable.insert_statement(), state_entry)
self._handle_prev_pdus(
txn,
pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
pdu_entry.context
)
def get_unresolved_state_tree(self, new_state_pdu):
return self._db_pool.runInteraction(
self._get_unresolved_state_tree, new_state_pdu
)
@log_function
def _get_unresolved_state_tree(self, txn, new_pdu):
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)
return_value = ReturnType([new_pdu], [])
if not current:
logger.debug("get_unresolved_state_tree No current state.")
return return_value
return_value.current_branch.append(current)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
for branch, prev_state, state in enum_branches:
if state:
return_value[branch].append(state)
else:
break
return return_value
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self._db_pool.runInteraction(
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
state_key):
query = (
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
) % {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
}
query_args = CurrentStateTable.EntryType(
pdu_id=pdu_id,
origin=origin,
context=context,
pdu_type=pdu_type,
state_key=state_key
)
txn.execute(query, query_args)
def get_current_state(self, context, pdu_type, state_key):
"""For a given context, pdu_type, state_key 3-tuple, return what is
currently considered the current state.
Args:
txn
context (str)
pdu_type (str)
state_key (str)
Returns:
PduEntry
"""
return self._db_pool.runInteraction(
self._get_current_state, context, pdu_type, state_key
)
def _get_current_state(self, txn, context, pdu_type, state_key):
return self._get_current_interaction(txn, context, pdu_type, state_key)
def _get_current_interaction(self, txn, context, pdu_type, state_key):
logger.debug(
"_get_current_interaction %s %s %s",
context, pdu_type, state_key
)
fields = _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s")
current_query = (
"SELECT %(fields)s FROM %(state)s as s "
"INNER JOIN %(pdus)s as p "
"ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
"INNER JOIN %(curr)s as c "
"ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
"WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
) % {
"fields": fields,
"curr": CurrentStateTable.table_name,
"state": StatePdusTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(
current_query,
(context, pdu_type, state_key)
)
row = txn.fetchone()
result = PduEntry(*row) if row else None
if not result:
logger.debug("_get_current_interaction not found")
else:
logger.debug(
"_get_current_interaction found %s %s",
result.pdu_id, result.origin
)
return result
def get_next_missing_pdu(self, new_pdu):
"""When we get a new state pdu we need to check whether we need to do
any conflict resolution, if we do then we need to check if we need
to go back and request some more state pdus that we haven't seen yet.
Args:
txn
new_pdu
Returns:
PduIdTuple: A pdu that we are missing, or None if we have all the
pdus required to do the conflict resolution.
"""
return self._db_pool.runInteraction(
self._get_next_missing_pdu, new_pdu
)
def _get_next_missing_pdu(self, txn, new_pdu):
logger.debug(
"get_next_missing_pdu %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
if (not current or not current.prev_state_id
or not current.prev_state_origin):
return None
# Oh look, it's a straight clobber, so wooooo almost no-op.
if (new_pdu.prev_state_id == current.pdu_id
and new_pdu.prev_state_origin == current.origin):
return None
enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
for branch, prev_state, state in enum_branches:
if not state:
return PduIdTuple(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
return None
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
Args:
new_pdu
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
return self._db_pool.runInteraction(
self._handle_new_state, new_pdu
)
def _handle_new_state(self, txn, new_pdu):
logger.debug(
"handle_new_state %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
is_current = False
if (not current or not current.prev_state_id
or not current.prev_state_origin):
# Oh, we don't have any state for this yet.
is_current = True
elif (current.pdu_id == new_pdu.prev_state_id
and current.origin == new_pdu.prev_state_origin):
# Oh! A direct clobber. Just do it.
is_current = True
else:
##
# Ok, now loop through until we get to a common ancestor.
max_new = int(new_pdu.power_level)
max_current = int(current.power_level)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
for branch, prev_state, state in enum_branches:
if not state:
raise RuntimeError(
"Could not find state_pdu %s %s" %
(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
)
if branch == 0:
max_new = max(int(state.depth), max_new)
else:
max_current = max(int(state.depth), max_current)
is_current = max_new > max_current
if is_current:
logger.debug("handle_new_state make current")
# Right, this is a new thing, so woo, just insert it.
txn.execute(
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
% {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
},
CurrentStateTable.EntryType(
*(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
)
)
else:
logger.debug("handle_new_state not current")
logger.debug("handle_new_state done")
return is_current
@classmethod
@log_function
def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
get_query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
# Woo! We found a common ancestor
logger.debug("_enumerate_state_branches Found common ancestor")
break
do_branch_a = (
hasattr(branch_a, "prev_state_id") and
branch_a.prev_state_id
)
do_branch_b = (
hasattr(branch_b, "prev_state_id") and
branch_b.prev_state_id
)
logger.debug(
"do_branch_a=%s, do_branch_b=%s",
do_branch_a, do_branch_b
)
if do_branch_a and do_branch_b:
do_branch_a = int(branch_a.depth) > int(branch_b.depth)
if do_branch_a:
pdu_tuple = PduIdTuple(
branch_a.prev_state_id,
branch_a.prev_state_origin
)
logger.debug("getting branch_a prev %s", pdu_tuple)
txn.execute(get_query, pdu_tuple)
prev_branch = branch_a
res = txn.fetchone()
branch_a = PduEntry(*res) if res else None
logger.debug("branch_a=%s", branch_a)
yield (0, prev_branch, branch_a)
if not branch_a:
break
elif do_branch_b:
pdu_tuple = PduIdTuple(
branch_b.prev_state_id,
branch_b.prev_state_origin
)
txn.execute(get_query, pdu_tuple)
logger.debug("getting branch_b prev %s", pdu_tuple)
prev_branch = branch_b
res = txn.fetchone()
branch_b = PduEntry(*res) if res else None
logger.debug("branch_b=%s", branch_b)
yield (1, prev_branch, branch_b)
if not branch_b:
break
else:
break
class PdusTable(Table):
table_name = "pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"ts",
"depth",
"is_state",
"content_json",
"unrecognized_keys",
"outlier",
"have_processed",
]
EntryType = namedtuple("PdusEntry", fields)
class PduDestinationsTable(Table):
table_name = "pdu_destinations"
fields = [
"pdu_id",
"origin",
"destination",
"delivered_ts",
]
EntryType = namedtuple("PduDestinationsEntry", fields)
class PduEdgesTable(Table):
table_name = "pdu_edges"
fields = [
"pdu_id",
"origin",
"prev_pdu_id",
"prev_origin",
"context"
]
EntryType = namedtuple("PduEdgesEntry", fields)
class PduForwardExtremitiesTable(Table):
table_name = "pdu_forward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
class PduBackwardExtremitiesTable(Table):
table_name = "pdu_backward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
class ContextDepthTable(Table):
table_name = "context_depth"
fields = [
"context",
"min_depth",
]
EntryType = namedtuple("ContextDepthEntry", fields)
class StatePdusTable(Table):
table_name = "state_pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
]
EntryType = namedtuple("StatePdusEntry", fields)
class CurrentStateTable(Table):
table_name = "current_state"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
]
EntryType = namedtuple("CurrentStateEntry", fields)
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.
This does not include a prev_pdus key.
"""
PduTuple = namedtuple(
"PduTuple",
("pdu_entry", "prev_pdu_list")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""

103
synapse/storage/presence.py Normal file
View file

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart):
return self._simple_insert(
table="presence",
values={"user_id": user_localpart},
)
def has_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["user_id"],
allow_none=True,
)
def get_presence_state(self, user_localpart):
return self._simple_select_one(
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg"],
)
def set_presence_state(self, user_localpart, new_state):
return self._simple_update_one(
table="presence",
keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"],
"status_msg": new_state["status_msg"]},
retcols=["state"],
)
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
table="presence_allow_inbound",
values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_delete_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
)
def is_presence_visible(self, observed_localpart, observer_userid):
return self._simple_select_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
allow_none=True,
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert(
table="presence_list",
values={"user_id": observer_localpart,
"observed_user_id": observed_userid,
"accepted": False},
)
def set_presence_list_accepted(self, observer_localpart, observed_userid):
return self._simple_update_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
)
def get_presence_list(self, observer_localpart, accepted=None):
keyvalues = {"user_id": observer_localpart}
if accepted is not None:
keyvalues["accepted"] = accepted
return self._simple_select_list(
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
)
def del_presence_list(self, observer_localpart, observed_userid):
return self._simple_delete_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
)

View file

@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
class ProfileStore(SQLBaseStore):
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
values={"user_id": user_localpart},
)
def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
)
def set_profile_displayname(self, user_localpart, new_displayname):
return self._simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
)
def get_profile_avatar_url(self, user_localpart):
return self._simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
return self._simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
)

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from sqlite3 import IntegrityError
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
class RegistrationStore(SQLBaseStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
Raises:
StoreError if there was a problem adding this.
"""
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
if not row:
raise StoreError(400, "Bad user ID supplied.")
row_id = row["id"]
yield self._simple_insert(
"access_tokens",
{
"user_id": row_id,
"token": token
}
)
@defer.inlineCallbacks
def register(self, user_id, token, password_hash):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user.
password_hash (str): Optional. The password hash for this user.
Raises:
StoreError if the user_id could not be registered.
"""
yield self._db_pool.runInteraction(self._register, user_id, token,
password_hash)
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
try:
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
"VALUES (?,?,?)",
[user_id, password_hash, now])
except IntegrityError:
raise StoreError(400, "User ID already taken.")
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute("INSERT INTO access_tokens(user_id, token) " +
"VALUES (?,?)", [txn.lastrowid, token])
def get_user_by_id(self, user_id):
query = ("SELECT users.name, users.password_hash FROM users "
"WHERE users.name = ?")
return self._execute(
self.cursor_to_dict,
query, user_id
)
@defer.inlineCallbacks
def get_user_by_token(self, token):
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
Returns:
str: The user ID of the user.
Raises:
StoreError if no user was found.
"""
user_id = yield self._db_pool.runInteraction(self._query_for_auth,
token)
defer.returnValue(user_id)
def _query_for_auth(self, txn, token):
txn.execute("SELECT users.name FROM access_tokens LEFT JOIN users" +
" ON users.id = access_tokens.user_id WHERE token = ?",
[token])
row = txn.fetchone()
if row:
return row[0]
raise StoreError(404, "Token not found.")

129
synapse/storage/room.py Normal file
View file

@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from sqlite3 import IntegrityError
from synapse.api.errors import StoreError
from synapse.api.events.room import RoomTopicEvent
from ._base import SQLBaseStore, Table
import collections
import json
import logging
logger = logging.getLogger(__name__)
class RoomStore(SQLBaseStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
Args:
room_id (str): The desired room ID, can be None.
room_creator_user_id (str): The user ID of the room creator.
is_public (bool): True to indicate that this room should appear in
public room lists.
Raises:
StoreError if the room could not be stored.
"""
try:
yield self._simple_insert(RoomsTable.table_name, dict(
room_id=room_id,
creator=room_creator_user_id,
is_public=is_public
))
except IntegrityError:
raise StoreError(409, "Room ID in use.")
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
def store_room_config(self, room_id, visibility):
return self._simple_update_one(
table=RoomsTable.table_name,
keyvalues={"room_id": room_id},
updatevalues={"is_public": visibility}
)
def get_room(self, room_id):
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
Returns:
A namedtuple containing the room information, or an empty list.
"""
query = RoomsTable.select_statement("room_id=?")
return self._execute(
RoomsTable.decode_single_result, query, room_id,
)
@defer.inlineCallbacks
def get_rooms(self, is_public, with_topics):
"""Retrieve a list of all public rooms.
Args:
is_public (bool): True if the rooms returned should be public.
with_topics (bool): True to include the current topic for the room
in the response.
Returns:
A list of room dicts containing at least a "room_id" key, and a
"topic" key if one is set and with_topic=True.
"""
room_data_type = RoomTopicEvent.TYPE
public = 1 if is_public else 0
latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE "
+ "room_data.type = ? GROUP BY room_id")
query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN "
+ "room_data ON rooms.room_id = room_data.room_id WHERE "
+ "(room_data.id IN (" + latest_topic + ") "
+ "OR room_data.id IS NULL) AND rooms.is_public = ?")
res = yield self._execute(
self.cursor_to_dict, query, room_data_type, public
)
# return only the keys the specification expects
ret_keys = ["room_id", "topic"]
# extract topic from the json (icky) FIXME
for i, room_row in enumerate(res):
try:
content_json = json.loads(room_row["content"])
room_row["topic"] = content_json["topic"]
except:
pass # no topic set
# filter the dict based on ret_keys
res[i] = {k: v for k, v in room_row.iteritems() if k in ret_keys}
defer.returnValue(res)
class RoomsTable(Table):
table_name = "rooms"
fields = [
"room_id",
"is_public",
"creator"
]
EntryType = collections.namedtuple("RoomEntry", fields)

View file

@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table
import collections
import json
class RoomDataStore(SQLBaseStore):
"""Provides various CRUD operations for Room Events. """
def get_room_data(self, room_id, etype, state_key=""):
"""Retrieve the data stored under this type and state_key.
Args:
room_id (str)
etype (str)
state_key (str)
Returns:
namedtuple: Or None if nothing exists at this path.
"""
query = RoomDataTable.select_statement(
"room_id = ? AND type = ? AND state_key = ? "
"ORDER BY id DESC LIMIT 1"
)
return self._execute(
RoomDataTable.decode_single_result,
query, room_id, etype, state_key,
)
def store_room_data(self, room_id, etype, state_key="", content=None):
"""Stores room specific data.
Args:
room_id (str)
etype (str)
state_key (str)
data (str)- The data to store for this path in JSON.
Returns:
The store ID for this data.
"""
return self._simple_insert(RoomDataTable.table_name, dict(
etype=etype,
state_key=state_key,
room_id=room_id,
content=content,
))
def get_max_room_data_id(self):
return self._simple_max_id(RoomDataTable.table_name)
class RoomDataTable(Table):
table_name = "room_data"
fields = [
"id",
"room_id",
"type",
"state_key",
"content"
]
class EntryType(collections.namedtuple("RoomDataEntry", fields)):
def as_event(self, event_factory):
return event_factory.create_event(
etype=self.type,
room_id=self.room_id,
content=json.loads(self.content),
)

View file

@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from ._base import SQLBaseStore, Table
import collections
import json
import logging
logger = logging.getLogger(__name__)
class RoomMemberStore(SQLBaseStore):
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
Args:
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
namedtuple: The room member from the database, or None if this
member does not exist.
"""
query = RoomMemberTable.select_statement(
"room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
return self._execute(
RoomMemberTable.decode_single_result,
query, room_id, user_id,
)
def store_room_member(self, user_id, sender, room_id, membership, content):
"""Store a room member in the database.
Args:
user_id (str): The member's user ID.
room_id (str): The room in relation to the member.
membership (synapse.api.constants.Membership): The new membership
state.
content (dict): The content of the membership (JSON).
"""
content_json = json.dumps(content)
return self._simple_insert(RoomMemberTable.table_name, dict(
user_id=user_id,
sender=sender,
room_id=room_id,
membership=membership,
content=content_json,
))
@defer.inlineCallbacks
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
Args:
room_id (str): The room to get the list of members.
membership (synapse.api.constants.Membership): The filter to apply
to this list, or None to return all members with some state
associated with this room.
Returns:
list of namedtuples representing the members in this room.
"""
query = RoomMemberTable.select_statement(
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ " WHERE room_id = ? GROUP BY user_id)"
)
res = yield self._execute(
RoomMemberTable.decode_results, query, room_id,
)
# strip memberships which don't match
if membership:
res = [entry for entry in res if entry.membership == membership]
defer.returnValue(res)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
matches one in the membership list.
Args:
user_id (str): The user ID.
membership_list (list): A list of synapse.api.constants.Membership
values which the user must be in.
Returns:
A list of dicts with "room_id" and "membership" keys.
"""
if not membership_list:
return defer.succeed(None)
args = [user_id]
membership_placeholder = ["membership=?"] * len(membership_list)
where_membership = "(" + " OR ".join(membership_placeholder) + ")"
for membership in membership_list:
args.append(membership)
query = ("SELECT room_id, membership FROM room_memberships"
+ " WHERE user_id=? AND " + where_membership
+ " GROUP BY room_id ORDER BY id DESC")
return self._execute(
self.cursor_to_dict, query, *args
)
@defer.inlineCallbacks
def get_joined_hosts_for_room(self, room_id):
query = RoomMemberTable.select_statement(
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ " WHERE room_id = ? GROUP BY user_id)"
)
res = yield self._execute(
RoomMemberTable.decode_results, query, room_id,
)
def host_from_user_id_string(user_id):
domain = UserID.from_string(entry.user_id, self.hs).domain
return domain
# strip memberships which don't match
hosts = [
host_from_user_id_string(entry.user_id)
for entry in res
if entry.membership == Membership.JOIN
]
logger.debug("Returning hosts: %s from results: %s", hosts, res)
defer.returnValue(hosts)
def get_max_room_member_id(self):
return self._simple_max_id(RoomMemberTable.table_name)
class RoomMemberTable(Table):
table_name = "room_memberships"
fields = [
"id",
"user_id",
"sender",
"room_id",
"membership",
"content"
]
class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
def as_event(self, event_factory):
return event_factory.create_event(
etype=RoomMemberEvent.TYPE,
room_id=self.room_id,
target_user_id=self.user_id,
user_id=self.sender,
content=json.loads(self.content),
)

View file

@ -0,0 +1,31 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS context_edge_pdus(
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);
CREATE TABLE IF NOT EXISTS origin_edge_pdus(
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
pdu_id TEXT,
origin TEXT,
CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);
CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);

View file

@ -0,0 +1,54 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public INTEGER,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_memberships(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server
sender TEXT NOT NULL,
room_id TEXT NOT NULL,
membership TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS messages(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT,
room_id TEXT,
msg_id TEXT,
content TEXT
);
CREATE TABLE IF NOT EXISTS feedback(
id INTEGER PRIMARY KEY AUTOINCREMENT,
content TEXT,
feedback_type TEXT,
fb_sender_id TEXT,
msg_id TEXT,
room_id TEXT,
msg_sender_id TEXT
);
CREATE TABLE IF NOT EXISTS room_data(
id INTEGER PRIMARY KEY AUTOINCREMENT,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
content TEXT
);

View file

@ -0,0 +1,106 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores pdus and their content
CREATE TABLE IF NOT EXISTS pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
ts INTEGER,
depth INTEGER DEFAULT 0 NOT NULL,
is_state BOOL,
content_json TEXT,
unrecognized_keys TEXT,
outlier BOOL NOT NULL,
have_processed BOOL,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
);
-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
CREATE TABLE IF NOT EXISTS state_pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
power_level TEXT,
prev_state_id TEXT,
prev_state_origin TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
);
CREATE TABLE IF NOT EXISTS current_state(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
);
-- Stores where each pdu we want to send should be sent and the delivery status.
create TABLE IF NOT EXISTS pdu_destinations(
pdu_id TEXT,
origin TEXT,
destination TEXT,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_edges(
pdu_id TEXT,
origin TEXT,
prev_pdu_id TEXT,
prev_origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
);
CREATE TABLE IF NOT EXISTS context_depth(
context TEXT,
min_depth INTEGER,
CONSTRAINT uniqueness UNIQUE (context)
);
CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);

View file

@ -0,0 +1,37 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS presence(
user_id INTEGER NOT NULL,
state INTEGER,
status_msg TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
-- For each of /my/ users which possibly-remote users are allowed to see their
-- presence state
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
observed_user_id INTEGER NOT NULL,
observer_user_id TEXT, -- a UserID,
FOREIGN KEY(observed_user_id) REFERENCES users(id)
);
-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
user_id INTEGER NOT NULL,
observed_user_id TEXT, -- a UserID,
accepted BOOLEAN,
FOREIGN KEY(user_id) REFERENCES users(id)
);

View file

@ -0,0 +1,20 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS profiles(
user_id INTEGER NOT NULL,
displayname TEXT,
avatar_url TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);

View file

@ -0,0 +1,12 @@
CREATE TABLE IF NOT EXISTS room_aliases(
room_alias TEXT NOT NULL,
room_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS room_alias_servers(
room_alias TEXT NOT NULL,
server TEXT NOT NULL
);

View file

@ -0,0 +1,61 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores what transaction ids we have received and what our response was
CREATE TABLE IF NOT EXISTS received_transactions(
transaction_id TEXT,
origin TEXT,
ts INTEGER,
response_code INTEGER,
response_json TEXT,
has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
);
CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
-- since referenced the transaction in another outgoing transaction
CREATE TABLE IF NOT EXISTS sent_transactions(
id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering
transaction_id TEXT,
destination TEXT,
response_code INTEGER DEFAULT 0,
response_json TEXT,
ts INTEGER
);
CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
destination
);
-- So that we can do an efficient look up of all transactions that have yet to be successfully
-- sent.
CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
-- For sent transactions only.
CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
transaction_id INTEGER,
destination TEXT,
pdu_id TEXT,
pdu_origin TEXT
);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);

View file

@ -0,0 +1,31 @@
/* Copyright 2014 matrix.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS users(
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
password_hash TEXT,
creation_ts INTEGER,
UNIQUE(name) ON CONFLICT ROLLBACK
);
CREATE TABLE IF NOT EXISTS access_tokens(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
device_id TEXT,
token TEXT NOT NULL,
last_used INTEGER,
FOREIGN KEY(user_id) REFERENCES users(id),
UNIQUE(token) ON CONFLICT ROLLBACK
);

282
synapse/storage/stream.py Normal file
View file

@ -0,0 +1,282 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from .message import MessagesTable
from .feedback import FeedbackTable
from .roomdata import RoomDataTable
from .roommember import RoomMemberTable
import json
import logging
logger = logging.getLogger(__name__)
class StreamStore(SQLBaseStore):
def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0,
with_feedback=False):
"""Get all messages for this user between the given keys.
Args:
user_id (str): The user who is requesting messages.
from_key (int): The ID to start returning results from (exclusive).
to_key (int): The ID to stop returning results (exclusive).
room_id (str): Gets messages only for this room. Can be None, in
which case all room messages will be returned.
Returns:
A tuple of rows (list of namedtuples), new_id(int)
"""
if with_feedback and room_id: # with fb MUST specify a room ID
return self._db_pool.runInteraction(
self._get_message_rows_with_feedback,
user_id, from_key, to_key, room_id, limit
)
else:
return self._db_pool.runInteraction(
self._get_message_rows,
user_id, from_key, to_key, room_id, limit
)
def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
limit):
# work out which rooms this user is joined in on and join them with
# the room id on the messages table, bounded by the specified pkeys
# get all messages where the *current* membership state is 'join' for
# this user in that room.
query = ("SELECT messages.* FROM messages WHERE ? IN"
+ " (SELECT membership from room_memberships WHERE user_id=?"
+ " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)")
query_args = ["join", user_id]
if room_id:
query += " AND messages.room_id=?"
query_args.append(room_id)
(query, query_args) = self._append_stream_operations(
"messages", query, query_args, from_pkey, to_pkey, limit=limit
)
logger.debug("[SQL] %s : %s", query, query_args)
cursor = txn.execute(query, query_args)
return self._as_events(cursor, MessagesTable, from_pkey)
def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey,
room_id, limit):
# this col represents the compressed feedback JSON as per spec
compressed_feedback_col = (
"'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id"
+ " || '\",\"feedback_type\":\"' || f.feedback_type"
+ " || '\",\"content\":' || f.content || '}') || ']'"
)
global_msg_id_join = ("f.room_id = messages.room_id"
+ " and f.msg_id = messages.msg_id"
+ " and messages.user_id = f.msg_sender_id")
select_query = (
"SELECT messages.*, f.content AS fb_content, f.fb_sender_id"
+ ", " + compressed_feedback_col + " AS compressed_fb"
+ " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join)
current_membership_sub_query = (
"(SELECT membership from room_memberships rm"
+ " WHERE user_id=? AND room_id = rm.room_id"
+ " ORDER BY id DESC LIMIT 1)")
where = (" WHERE ? IN " + current_membership_sub_query
+ " AND messages.room_id=?")
query = select_query + where
query_args = ["join", user_id, room_id]
(query, query_args) = self._append_stream_operations(
"messages", query, query_args, from_pkey, to_pkey,
limit=limit, group_by=" GROUP BY messages.id "
)
logger.debug("[SQL] %s : %s", query, query_args)
cursor = txn.execute(query, query_args)
# convert the result set into events
entries = self.cursor_to_dict(cursor)
events = []
for entry in entries:
# TODO we should spec the cursor > event mapping somewhere else.
event = {}
straight_mappings = ["msg_id", "user_id", "room_id"]
for key in straight_mappings:
event[key] = entry[key]
event["content"] = json.loads(entry["content"])
if entry["compressed_fb"]:
event["feedback"] = json.loads(entry["compressed_fb"])
events.append(event)
latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"]
return (events, latest_pkey)
def get_room_member_stream(self, user_id, from_key, to_key):
"""Get all room membership events for this user between the given keys.
Args:
user_id (str): The user who is requesting membership events.
from_key (int): The ID to start returning results from (exclusive).
to_key (int): The ID to stop returning results (exclusive).
Returns:
A tuple of rows (list of namedtuples), new_id(int)
"""
return self._db_pool.runInteraction(
self._get_room_member_rows, user_id, from_key, to_key
)
def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey):
# get all room membership events for rooms which the user is
# *currently* joined in on, or all invite events for this user.
current_membership_sub_query = (
"(SELECT membership FROM room_memberships"
+ " WHERE user_id=? AND room_id = rm.room_id"
+ " ORDER BY id DESC LIMIT 1)")
query = ("SELECT rm.* FROM room_memberships rm "
# all membership events for rooms you've currently joined.
+ " WHERE (? IN " + current_membership_sub_query
# all invite membership events for this user
+ " OR rm.membership=? AND user_id=?)"
+ " AND rm.id > ?")
query_args = ["join", user_id, "invite", user_id, from_pkey]
if to_pkey != -1:
query += " AND rm.id < ?"
query_args.append(to_pkey)
cursor = txn.execute(query, query_args)
return self._as_events(cursor, RoomMemberTable, from_pkey)
def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0):
return self._db_pool.runInteraction(
self._get_feedback_rows,
user_id, from_key, to_key, room_id, limit
)
def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
limit):
# work out which rooms this user is joined in on and join them with
# the room id on the feedback table, bounded by the specified pkeys
# get all messages where the *current* membership state is 'join' for
# this user in that room.
query = (
"SELECT feedback.* FROM feedback WHERE ? IN "
+ "(SELECT membership from room_memberships WHERE user_id=?"
+ " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)")
query_args = ["join", user_id]
if room_id:
query += " AND feedback.room_id=?"
query_args.append(room_id)
(query, query_args) = self._append_stream_operations(
"feedback", query, query_args, from_pkey, to_pkey, limit=limit
)
logger.debug("[SQL] %s : %s", query, query_args)
cursor = txn.execute(query, query_args)
return self._as_events(cursor, FeedbackTable, from_pkey)
def get_room_data_stream(self, user_id, from_key, to_key, room_id,
limit=0):
return self._db_pool.runInteraction(
self._get_room_data_rows,
user_id, from_key, to_key, room_id, limit
)
def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
limit):
# work out which rooms this user is joined in on and join them with
# the room id on the feedback table, bounded by the specified pkeys
# get all messages where the *current* membership state is 'join' for
# this user in that room.
query = (
"SELECT room_data.* FROM room_data WHERE ? IN "
+ "(SELECT membership from room_memberships WHERE user_id=?"
+ " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)")
query_args = ["join", user_id]
if room_id:
query += " AND room_data.room_id=?"
query_args.append(room_id)
(query, query_args) = self._append_stream_operations(
"room_data", query, query_args, from_pkey, to_pkey, limit=limit
)
logger.debug("[SQL] %s : %s", query, query_args)
cursor = txn.execute(query, query_args)
return self._as_events(cursor, RoomDataTable, from_pkey)
def _append_stream_operations(self, table_name, query, query_args,
from_pkey, to_pkey, limit=None,
group_by=""):
LATEST_ROW = -1
order_by = ""
if to_pkey > from_pkey:
if from_pkey != LATEST_ROW:
# e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9
query += (" AND %s.id > ? AND %s.id < ?" %
(table_name, table_name))
query_args.append(from_pkey)
query_args.append(to_pkey)
else:
# e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC
query += " AND %s.id > ? " % table_name
order_by = "ORDER BY id DESC"
query_args.append(to_pkey)
elif from_pkey > to_pkey:
if to_pkey != LATEST_ROW:
# from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC
query += (" AND %s.id > ? AND %s.id < ? " %
(table_name, table_name))
order_by = "ORDER BY id DESC"
query_args.append(to_pkey)
query_args.append(from_pkey)
else:
# from=5 to=-1 >> from 5 to now >> id>5
query += " AND %s.id > ?" % table_name
query_args.append(from_pkey)
query += group_by + order_by
if limit and limit > 0:
query += " LIMIT ?"
query_args.append(str(limit))
return (query, query_args)
def _as_events(self, cursor, table, from_pkey):
data_entries = table.decode_results(cursor)
last_pkey = from_pkey
if data_entries:
last_pkey = data_entries[-1].id
events = [
entry.as_event(self.event_factory).get_dict()
for entry in data_entries
]
return (events, last_pkey)

View file

@ -0,0 +1,287 @@
# -*- coding: utf-8 -*-
# Copyright 2014 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Table
from .pdu import PdusTable
from collections import namedtuple
import logging
logger = logging.getLogger(__name__)
class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
body (as a dict).
Args:
transaction_id (str)
origin(str)
Returns:
tuple: None if we have not previously responded to
this transaction or a 2-tuple of (int, dict)
"""
return self._db_pool.runInteraction(
self._get_received_txn_response, transaction_id, origin
)
def _get_received_txn_response(self, txn, transaction_id, origin):
where_clause = "transaction_id = ? AND origin = ?"
query = ReceivedTransactionsTable.select_statement(where_clause)
txn.execute(query, (transaction_id, origin))
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
if results and results[0].response_code:
return (results[0].response_code, results[0].response_json)
else:
return None
def set_received_txn_response(self, transaction_id, origin, code,
response_dict):
"""Persist the response we returened for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
Args:
txn
transaction_id (str)
origin (str)
code (int)
response_json (str)
"""
return self._db_pool.runInteraction(
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
def _set_received_txn_response(self, txn, transaction_id, origin, code,
response_json):
query = (
"UPDATE %s "
"SET response_code = ?, response_json = ? "
"WHERE transaction_id = ? AND origin = ?"
) % ReceivedTransactionsTable.table_name
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
This should be called before sending the transaction so that it has the
correct value for the `prev_ids` key.
Args:
transaction_id (str)
destination (str)
ts (int)
pdu_list (list)
Returns:
list: A list of previous transaction ids.
"""
return self._db_pool.runInteraction(
self._prep_send_transaction,
transaction_id, destination, ts, pdu_list
)
def _prep_send_transaction(self, txn, transaction_id, destination, ts,
pdu_list):
# First we find out what the prev_txs should be.
# Since we know that we are only sending one transaction at a time,
# we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % (
SentTransactions.select_statement("destination = ?"),
)
results = txn.execute(query, (destination,))
results = SentTransactions.decode_results(results)
prev_txns = [r.transaction_id for r in results]
# Actually add the new transaction to the sent_transactions table.
query = SentTransactions.insert_statement()
txn.execute(query, SentTransactions.EntryType(
None,
transaction_id=transaction_id,
destination=destination,
ts=ts,
response_code=0,
response_json=None
))
# Update the tx id -> pdu id mapping
values = [
(transaction_id, destination, pdu[0], pdu[1])
for pdu in pdu_list
]
logger.debug("Inserting: %s", repr(values))
query = TransactionsToPduTable.insert_statement()
txn.executemany(query, values)
return prev_txns
def delivered_txn(self, transaction_id, destination, code, response_dict):
"""Persists the response for an outgoing transaction.
Args:
transaction_id (str)
destination (str)
code (int)
response_json (str)
"""
return self._db_pool.runInteraction(
self._delivered_txn,
transaction_id, destination, code, response_dict
)
def _delivered_txn(cls, txn, transaction_id, destination,
code, response_json):
query = (
"UPDATE %s "
"SET response_code = ?, response_json = ? "
"WHERE transaction_id = ? AND destination = ?"
) % SentTransactions.table_name
txn.execute(query, (code, response_json, transaction_id, destination))
def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id.
Args:
transaction_id (str)
destination (str)
Returns:
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self._db_pool.runInteraction(
self._get_transactions_after, transaction_id, destination
)
def _get_transactions_after(cls, txn, transaction_id, destination):
where = (
"destination = ? AND id > (select id FROM %s WHERE "
"transaction_id = ? AND destination = ?)"
) % (
SentTransactions.table_name
)
query = SentTransactions.select_statement(where)
txn.execute(query, (destination, transaction_id, destination))
return ReceivedTransactionsTable.decode_results(txn.fetchall())
def get_pdus_after_transaction(self, transaction_id, destination):
"""For a given local transaction_id that we sent to a given destination
home server, return a list of PDUs that were sent to that destination
after it.
Args:
txn
transaction_id (str)
destination (str)
Returns
list: A list of PduTuple
"""
return self._db_pool.runInteraction(
self._get_pdus_after_transaction,
transaction_id, destination
)
def _get_pdus_after_transaction(self, txn, transaction_id, destination):
# Query that first get's all transaction_ids with an id greater than
# the one given from the `sent_transactions` table. Then JOIN on this
# from the `tx->pdu` table to get a list of (pdu_id, origin) that
# specify the pdus that were sent in those transactions.
query = (
"SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
"INNER JOIN %(sent_tx)s as st "
"ON tp.transaction_id = st.transaction_id "
"AND tp.destination = st.destination "
"WHERE st.id > ("
"SELECT id FROM %(sent_tx)s "
"WHERE transaction_id = ? AND destination = ?"
) % {
"tx_pdu": TransactionsToPduTable.table_name,
"sent_tx": SentTransactions.table_name,
}
txn.execute(query, (transaction_id, destination))
pdus = PdusTable.decode_results(txn.fetchall())
return self._get_pdu_tuples(txn, pdus)
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"
fields = [
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
]
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
class SentTransactions(Table):
table_name = "sent_transactions"
fields = [
"id",
"transaction_id",
"destination",
"ts",
"response_code",
"response_json",
]
EntryType = namedtuple("SentTransactionsEntry", fields)
class TransactionsToPduTable(Table):
table_name = "transaction_id_to_pdu"
fields = [
"transaction_id",
"destination",
"pdu_id",
"pdu_origin",
]
EntryType = namedtuple("TransactionsToPduEntry", fields)