diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index efa63031b..7c5631d01 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -61,6 +61,7 @@ SCHEMAS = [ "event_edges", "event_signatures", "media_repository", + "filtering", ] diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index 18e0e7c29..e98eaf803 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -17,6 +17,8 @@ from twisted.internet import defer from ._base import SQLBaseStore +import json + # TODO(paul) _filters_for_user = {} @@ -25,22 +27,41 @@ _filters_for_user = {} class FilteringStore(SQLBaseStore): @defer.inlineCallbacks def get_user_filter(self, user_localpart, filter_id): - filters = _filters_for_user.get(user_localpart, None) + def_json = yield self._simple_select_one_onecol( + table="user_filters", + keyvalues={ + "user_id": user_localpart, + "filter_id": filter_id, + }, + retcol="definition", + allow_none=False, + ) - if not filters or filter_id >= len(filters): - raise KeyError() + defer.returnValue(json.loads(def_json)) - # trivial yield to make it a generator so d.iC works - yield - defer.returnValue(filters[filter_id]) - - @defer.inlineCallbacks def add_user_filter(self, user_localpart, definition): - filters = _filters_for_user.setdefault(user_localpart, []) + def_json = json.dumps(definition) - filter_id = len(filters) - filters.append(definition) + # Need an atomic transaction to SELECT the maximal ID so far then + # INSERT a new one + def _do_txn(txn): + sql = ( + "SELECT MAX(filter_id) FROM user_filters " + "WHERE user_id = ?" + ) + txn.execute(sql, (user_localpart,)) + max_id = txn.fetchone()[0] + if max_id is None: + filter_id = 0 + else: + filter_id = max_id + 1 - # trivial yield, see above - yield - defer.returnValue(filter_id) + sql = ( + "INSERT INTO user_filters (user_id, filter_id, definition)" + "VALUES(?, ?, ?)" + ) + txn.execute(sql, (user_localpart, filter_id, def_json)) + + return filter_id + + return self.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql new file mode 100644 index 000000000..795aca4af --- /dev/null +++ b/synapse/storage/schema/filtering.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id INTEGER, + definition TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index fecadd105..149948374 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -53,16 +53,33 @@ class FilteringTestCase(unittest.TestCase): self.filtering = hs.get_filtering() + self.datastore = hs.get_datastore() + @defer.inlineCallbacks - def test_filter(self): + def test_add_filter(self): filter_id = yield self.filtering.add_user_filter( user_localpart=user_localpart, definition={"type": ["m.*"]}, ) + self.assertEquals(filter_id, 0) + self.assertEquals({"type": ["m.*"]}, + (yield self.datastore.get_user_filter( + user_localpart=user_localpart, + filter_id=0, + )) + ) + + @defer.inlineCallbacks + def test_get_filter(self): + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + definition={"type": ["m.*"]}, + ) filter = yield self.filtering.get_user_filter( user_localpart=user_localpart, filter_id=filter_id, ) + self.assertEquals(filter, {"type": ["m.*"]})