Merge pull request #324 from matrix-org/erikj/search

Add filters to search.
This commit is contained in:
Erik Johnston 2015-10-22 17:14:12 +01:00
commit 53c679b59b
5 changed files with 63 additions and 18 deletions

View File

@ -202,6 +202,26 @@ class Filter(object):
return True return True
def filter_rooms(self, room_ids):
"""Apply the 'rooms' filter to a given list of rooms.
Args:
room_ids (list): A list of room_ids.
Returns:
list: A list of room_ids that match the filter
"""
room_ids = set(room_ids)
disallowed_rooms = set(self.filter_json.get("not_rooms", []))
room_ids -= disallowed_rooms
allowed_rooms = self.filter_json.get("rooms", None)
if allowed_rooms is not None:
room_ids &= set(allowed_rooms)
return room_ids
def filter(self, events): def filter(self, events):
return filter(self.check, events) return filter(self.check, events)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
@ -49,9 +50,12 @@ class SearchHandler(BaseHandler):
keys = content["search_categories"]["room_events"].get("keys", [ keys = content["search_categories"]["room_events"].get("keys", [
"content.body", "content.name", "content.topic", "content.body", "content.name", "content.topic",
]) ])
filter_dict = content["search_categories"]["room_events"].get("filter", {})
except KeyError: except KeyError:
raise SynapseError(400, "Invalid search query") raise SynapseError(400, "Invalid search query")
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
rooms = yield self.store.get_rooms_for_user_where_membership_is( rooms = yield self.store.get_rooms_for_user_where_membership_is(
user.to_string(), user.to_string(),
@ -60,15 +64,18 @@ class SearchHandler(BaseHandler):
) )
room_ids = set(r.room_id for r in rooms) room_ids = set(r.room_id for r in rooms)
# TODO: Apply room filter to rooms list room_ids = search_filter.filter_rooms(room_ids)
rank_map, event_map = yield self.store.search_msgs(room_ids, search_term, keys) rank_map, event_map, _ = yield self.store.search_msgs(
room_ids, search_term, keys
allowed_events = yield self._filter_events_for_client( )
user.to_string(), event_map.values()
filtered_events = search_filter.filter(event_map.values())
allowed_events = yield self._filter_events_for_client(
user.to_string(), filtered_events
) )
# TODO: Filter allowed_events
# TODO: Add a limit # TODO: Add a limit
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 24 SCHEMA_VERSION = 25
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
POSTGRES_SQL = """ POSTGRES_SQL = """
CREATE TABLE event_search ( CREATE TABLE IF NOT EXISTS event_search (
event_id TEXT, event_id TEXT,
room_id TEXT, room_id TEXT,
key TEXT, key TEXT,
@ -53,7 +53,8 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id);
SQLITE_TABLE = ( SQLITE_TABLE = (
"CREATE VIRTUAL TABLE event_search USING fts3 ( event_id, room_id, key, value)" "CREATE VIRTUAL TABLE IF NOT EXISTS event_search"
" USING fts3 ( event_id, room_id, key, value)"
) )

View File

@ -18,6 +18,17 @@ from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from collections import namedtuple
"""The result of a search.
Fields:
rank_map (dict): Mapping event_id -> rank
event_map (dict): Mapping event_id -> event
pagination_token (str): Pagination token
"""
SearchResult = namedtuple("SearchResult", ("rank_map", "event_map", "pagination_token"))
class SearchStore(SQLBaseStore): class SearchStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@ -31,15 +42,18 @@ class SearchStore(SQLBaseStore):
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
Returns: Returns:
2-tuple of (dict event_id -> rank, dict event_id -> event) SearchResult
""" """
clauses = [] clauses = []
args = [] args = []
clauses.append( # Make sure we don't explode because the person is in too many rooms.
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) # We filter the results below regardless.
) if len(room_ids) < 500:
args.extend(room_ids) clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = [] local_clauses = []
for key in keys: for key in keys:
@ -52,13 +66,13 @@ class SearchStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"SELECT ts_rank_cd(vector, query) AS rank, event_id" "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search" " FROM plainto_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query" " WHERE vector @@ query"
) )
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
sql = ( sql = (
"SELECT 0 as rank, event_id FROM event_search" "SELECT 0 as rank, room_id, event_id FROM event_search"
" WHERE value MATCH ?" " WHERE value MATCH ?"
) )
else: else:
@ -76,6 +90,8 @@ class SearchStore(SQLBaseStore):
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args) "search_msgs", self.cursor_to_dict, sql, *([search_term] + args)
) )
results = filter(lambda row: row["room_id"] in room_ids, results)
events = yield self._get_events([r["event_id"] for r in results]) events = yield self._get_events([r["event_id"] for r in results])
event_map = { event_map = {
@ -83,11 +99,12 @@ class SearchStore(SQLBaseStore):
for ev in events for ev in events
} }
defer.returnValue(( defer.returnValue(SearchResult(
{ {
r["event_id"]: r["rank"] r["event_id"]: r["rank"]
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
}, },
event_map event_map,
None
)) ))