From 53d1a2945df1912abc5d4a143424f68f49d48790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Wed, 3 Jul 2019 17:39:57 +0200 Subject: [PATCH] index: Make the index optional. This patch disables the message indexer and the message store if tantivy is not installed. In that case the search API endpoint forwards all search requests to the homeserver. --- pantalaimon/client.py | 50 ++- pantalaimon/daemon.py | 5 +- pantalaimon/index.py | 866 +++++++++++++++++++++--------------------- 3 files changed, 478 insertions(+), 443 deletions(-) diff --git a/pantalaimon/client.py b/pantalaimon/client.py index c9994ab..0a4d242 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -43,7 +43,7 @@ from nio import ( from nio.crypto import Sas from nio.store import SqliteStore -from pantalaimon.index import IndexStore +from pantalaimon.index import INDEXING_ENABLED from pantalaimon.log import logger from pantalaimon.store import FetchTask from pantalaimon.thread_messages import ( @@ -151,7 +151,16 @@ class PanClient(AsyncClient): self.server_name = server_name self.pan_store = pan_store self.pan_conf = pan_conf - self.index = IndexStore(self.user_id, index_dir) + + if INDEXING_ENABLED: + logger.info("Indexing enabled.") + from pantalaimon.index import IndexStore + + self.index = IndexStore(self.user_id, index_dir) + else: + logger.info("Indexing disabled.") + self.index = None + self.task = None self.queue = queue @@ -159,26 +168,32 @@ class PanClient(AsyncClient): self.send_semaphores = defaultdict(asyncio.Semaphore) self.send_decision_queues = dict() # type: asyncio.Queue + self.last_sync_token = None self.history_fetcher_task = None self.history_fetch_queue = asyncio.Queue() self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent) self.add_event_callback(self.undecrypted_event_cb, MegolmEvent) - self.add_event_callback( - self.store_message_cb, - ( - RoomMessageText, - RoomMessageMedia, - RoomEncryptedMedia, - RoomTopicEvent, - RoomNameEvent, - ), - 
) + + if INDEXING_ENABLED: + self.add_event_callback( + self.store_message_cb, + ( + RoomMessageText, + RoomMessageMedia, + RoomEncryptedMedia, + RoomTopicEvent, + RoomNameEvent, + ), + ) + self.add_response_callback(self.keys_query_cb, KeysQueryResponse) self.add_response_callback(self.sync_tasks, SyncResponse) def store_message_cb(self, room, event): + assert INDEXING_ENABLED + display_name = room.user_name(event.sender) avatar_url = room.avatar_url(event.sender) @@ -233,6 +248,8 @@ class PanClient(AsyncClient): self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task) async def fetcher_loop(self): + assert INDEXING_ENABLED + for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id): await self.history_fetch_queue.put(t) @@ -300,7 +317,9 @@ class PanClient(AsyncClient): return async def sync_tasks(self, response): - await self.index.commit_events() + if self.index: + await self.index.commit_events() + self.pan_store.save_token(self.server_name, self.user_id, self.next_batch) for room_id, room_info in response.rooms.join.items(): @@ -392,7 +411,8 @@ class PanClient(AsyncClient): loop = asyncio.get_event_loop() - self.history_fetcher_task = loop.create_task(self.fetcher_loop()) + if INDEXING_ENABLED: + self.history_fetcher_task = loop.create_task(self.fetcher_loop()) timeout = 30000 sync_filter = {"room": {"state": {"lazy_load_members": True}}} @@ -671,6 +691,8 @@ class PanClient(AsyncClient): async def search(self, search_terms): # type: (Dict[Any, Any]) -> Dict[Any, Any] + assert INDEXING_ENABLED + state_cache = dict() async def add_context(event_dict, room_id, event_id, include_state): diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 4c0f656..4069c35 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -35,7 +35,7 @@ from pantalaimon.client import ( UnknownRoomError, validate_json, ) -from pantalaimon.index import InvalidQueryError +from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError from 
pantalaimon.log import logger from pantalaimon.store import ClientInfo, PanStore from pantalaimon.thread_messages import ( @@ -905,6 +905,9 @@ class ProxyDaemon: if not access_token: return self._missing_token + if not INDEXING_ENABLED: + return await self.forward_to_web(request) + client = await self._find_client(access_token) if not client: diff --git a/pantalaimon/index.py b/pantalaimon/index.py index 87cc189..8192d00 100644 --- a/pantalaimon/index.py +++ b/pantalaimon/index.py @@ -12,487 +12,497 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import datetime -import json -import os -from functools import partial -from typing import Any, Dict, List, Optional, Tuple - -import attr -import tantivy -from nio import ( - RoomEncryptedMedia, - RoomMessageMedia, - RoomMessageText, - RoomNameEvent, - RoomTopicEvent, -) -from peewee import SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, TextField - -from pantalaimon.store import use_database +from importlib import util class InvalidQueryError(Exception): pass -class DictField(TextField): - def python_value(self, value): # pragma: no cover - return json.loads(value) +if util.find_spec("tantivy"): + import asyncio + import datetime + import json + import os + from functools import partial + from typing import Any, Dict, List, Optional, Tuple - def db_value(self, value): # pragma: no cover - return json.dumps(value) + import attr + import tantivy + from nio import ( + RoomEncryptedMedia, + RoomMessageMedia, + RoomMessageText, + RoomNameEvent, + RoomTopicEvent, + ) + from peewee import ( + SQL, + DateTimeField, + ForeignKeyField, + Model, + SqliteDatabase, + TextField, + ) + from pantalaimon.store import use_database -class StoreUser(Model): - user_id = TextField() + INDEXING_ENABLED = True - class Meta: - constraints = [SQL("UNIQUE(user_id)")] + class DictField(TextField): + def python_value(self, value): # pragma: no cover + return 
json.loads(value) + def db_value(self, value): # pragma: no cover + return json.dumps(value) -class Profile(Model): - user_id = TextField() - avatar_url = TextField(null=True) - display_name = TextField(null=True) + class StoreUser(Model): + user_id = TextField() - class Meta: - constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")] + class Meta: + constraints = [SQL("UNIQUE(user_id)")] + class Profile(Model): + user_id = TextField() + avatar_url = TextField(null=True) + display_name = TextField(null=True) -class Event(Model): - event_id = TextField() - sender = TextField() - date = DateTimeField() - room_id = TextField() + class Meta: + constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")] - source = DictField() + class Event(Model): + event_id = TextField() + sender = TextField() + date = DateTimeField() + room_id = TextField() - profile = ForeignKeyField(model=Profile, column_name="profile_id") + source = DictField() - class Meta: - constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")] + profile = ForeignKeyField(model=Profile, column_name="profile_id") + class Meta: + constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")] -class UserMessages(Model): - user = ForeignKeyField(model=StoreUser, column_name="user_id") - event = ForeignKeyField(model=Event, column_name="event_id") + class UserMessages(Model): + user = ForeignKeyField(model=StoreUser, column_name="user_id") + event = ForeignKeyField(model=Event, column_name="event_id") + @attr.s + class MessageStore: + user = attr.ib(type=str) + store_path = attr.ib(type=str) + database_name = attr.ib(type=str) + database = attr.ib(type=SqliteDatabase, init=False) + database_path = attr.ib(type=str, init=False) -@attr.s -class MessageStore: - user = attr.ib(type=str) - store_path = attr.ib(type=str) - database_name = attr.ib(type=str) - database = attr.ib(type=SqliteDatabase, init=False) - database_path = attr.ib(type=str, init=False) + models = [StoreUser, Event, Profile, 
UserMessages] - models = [StoreUser, Event, Profile, UserMessages] - - def __attrs_post_init__(self): - self.database_path = os.path.join( - os.path.abspath(self.store_path), self.database_name - ) - - self.database = self._create_database() - self.database.connect() - - with self.database.bind_ctx(self.models): - self.database.create_tables(self.models) - - def _create_database(self): - return SqliteDatabase( - self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1} - ) - - @use_database - def event_in_store(self, event_id, room_id): - user, _ = StoreUser.get_or_create(user_id=self.user) - query = ( - Event.select() - .join(UserMessages) - .where( - (Event.room_id == room_id) - & (Event.event_id == event_id) - & (UserMessages.user == user) + def __attrs_post_init__(self): + self.database_path = os.path.join( + os.path.abspath(self.store_path), self.database_name ) - .execute() - ) - for _ in query: - return True + self.database = self._create_database() + self.database.connect() - return False + with self.database.bind_ctx(self.models): + self.database.create_tables(self.models) - def save_event(self, event, room_id, display_name=None, avatar_url=None): - user, _ = StoreUser.get_or_create(user_id=self.user) - - profile_id, _ = Profile.get_or_create( - user_id=event.sender, display_name=display_name, avatar_url=avatar_url - ) - - event_source = event.source - event_source["room_id"] = room_id - - event_id = ( - Event.insert( - event_id=event.event_id, - sender=event.sender, - date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000), - room_id=room_id, - source=event_source, - profile=profile_id, + def _create_database(self): + return SqliteDatabase( + self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1} ) - .on_conflict_ignore() - .execute() - ) - if event_id <= 0: + @use_database + def event_in_store(self, event_id, room_id): + user, _ = StoreUser.get_or_create(user_id=self.user) + query = ( + Event.select() + 
.join(UserMessages) + .where( + (Event.room_id == room_id) + & (Event.event_id == event_id) + & (UserMessages.user == user) + ) + .execute() + ) + + for _ in query: + return True + + return False + + def save_event(self, event, room_id, display_name=None, avatar_url=None): + user, _ = StoreUser.get_or_create(user_id=self.user) + + profile_id, _ = Profile.get_or_create( + user_id=event.sender, display_name=display_name, avatar_url=avatar_url + ) + + event_source = event.source + event_source["room_id"] = room_id + + event_id = ( + Event.insert( + event_id=event.event_id, + sender=event.sender, + date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000), + room_id=room_id, + source=event_source, + profile=profile_id, + ) + .on_conflict_ignore() + .execute() + ) + + if event_id <= 0: + return None + + _, created = UserMessages.get_or_create(user=user, event=event_id) + + if created: + return event_id + return None - _, created = UserMessages.get_or_create(user=user, event=event_id) + def _load_context(self, user, event, before, after): + context = {} - if created: - return event_id - - return None - - def _load_context(self, user, event, before, after): - context = {} - - if before > 0: - query = ( - Event.select() - .join(UserMessages) - .where( - (Event.date <= event.date) - & (Event.room_id == event.room_id) - & (Event.id != event.id) - & (UserMessages.user == user) + if before > 0: + query = ( + Event.select() + .join(UserMessages) + .where( + (Event.date <= event.date) + & (Event.room_id == event.room_id) + & (Event.id != event.id) + & (UserMessages.user == user) + ) + .order_by(Event.date.desc()) + .limit(before) ) - .order_by(Event.date.desc()) - .limit(before) + + context["events_before"] = [e.source for e in query] + else: + context["events_before"] = [] + + if after > 0: + query = ( + Event.select() + .join(UserMessages) + .where( + (Event.date >= event.date) + & (Event.room_id == event.room_id) + & (Event.id != event.id) + & (UserMessages.user == 
user) + ) + .order_by(Event.date) + .limit(after) + ) + + context["events_after"] = [e.source for e in query] + else: + context["events_after"] = [] + + return context + + @use_database + def load_events( + self, + search_result, # type: List[Tuple[int, int]] + include_profile=False, # type: bool + order_by_recent=False, # type: bool + before=0, # type: int + after=0, # type: int + ): + # type: (...) -> Dict[Any, Any] + user, _ = StoreUser.get_or_create(user_id=self.user) + + search_dict = {r[1]: r[0] for r in search_result} + columns = list(search_dict.keys()) + + result_dict = {"results": []} + + query = ( + UserMessages.select() + .where( + (UserMessages.user_id == user) & (UserMessages.event.in_(columns)) + ) + .execute() ) - context["events_before"] = [e.source for e in query] - else: - context["events_before"] = [] + for message in query: - if after > 0: - query = ( - Event.select() - .join(UserMessages) - .where( - (Event.date >= event.date) - & (Event.room_id == event.room_id) - & (Event.id != event.id) - & (UserMessages.user == user) - ) - .order_by(Event.date) - .limit(after) - ) + event = message.event - context["events_after"] = [e.source for e in query] - else: - context["events_after"] = [] - - return context - - @use_database - def load_events( - self, - search_result, # type: List[Tuple[int, int]] - include_profile=False, # type: bool - order_by_recent=False, # type: bool - before=0, # type: int - after=0, # type: int - ): - # type: (...) 
-> Dict[Any, Any] - user, _ = StoreUser.get_or_create(user_id=self.user) - - search_dict = {r[1]: r[0] for r in search_result} - columns = list(search_dict.keys()) - - result_dict = {"results": []} - - query = ( - UserMessages.select() - .where((UserMessages.user_id == user) & (UserMessages.event.in_(columns))) - .execute() - ) - - for message in query: - - event = message.event - - event_dict = { - "rank": 1 if order_by_recent else search_dict[event.id], - "result": event.source, - "context": {}, - } - - if include_profile: - event_profile = event.profile - - event_dict["context"]["profile_info"] = { - event_profile.user_id: { - "display_name": event_profile.display_name, - "avatar_url": event_profile.avatar_url, - } + event_dict = { + "rank": 1 if order_by_recent else search_dict[event.id], + "result": event.source, + "context": {}, } - context = self._load_context(user, event, before, after) + if include_profile: + event_profile = event.profile - event_dict["context"]["events_before"] = context["events_before"] - event_dict["context"]["events_after"] = context["events_after"] + event_dict["context"]["profile_info"] = { + event_profile.user_id: { + "display_name": event_profile.display_name, + "avatar_url": event_profile.avatar_url, + } + } - result_dict["results"].append(event_dict) + context = self._load_context(user, event, before, after) - return result_dict + event_dict["context"]["events_before"] = context["events_before"] + event_dict["context"]["events_after"] = context["events_after"] + result_dict["results"].append(event_dict) -def sanitize_room_id(room_id): - return room_id.replace(":", "/").replace("!", "") + return result_dict + def sanitize_room_id(room_id): + return room_id.replace(":", "/").replace("!", "") -class Searcher: - def __init__( - self, - index, - body_field, - name_field, - topic_field, - column_field, - room_field, - timestamp_field, - searcher, - ): - self._index = index - self._searcher = searcher + class Searcher: + def __init__( + 
self, + index, + body_field, + name_field, + topic_field, + column_field, + room_field, + timestamp_field, + searcher, + ): + self._index = index + self._searcher = searcher - self.body_field = body_field - self.name_field = topic_field - self.topic_field = name_field - self.column_field = column_field - self.room_field = room_field - self.timestamp_field = timestamp_field + self.body_field = body_field + self.name_field = topic_field + self.topic_field = name_field + self.column_field = column_field + self.room_field = room_field + self.timestamp_field = timestamp_field - def search(self, search_term, room=None, max_results=10, order_by_recent=False): - # type (str, str, int, bool) -> List[int, int] - """Search for events in the index. + def search(self, search_term, room=None, max_results=10, order_by_recent=False): + # type (str, str, int, bool) -> List[int, int] + """Search for events in the index. - Returns the score and the column id for the event. - """ - queryparser = tantivy.QueryParser.for_index( - self._index, - [self.body_field, self.name_field, self.topic_field, self.room_field], - ) - - # This currently supports only a single room since the query parser - # doesn't seem to work with multiple room fields here. 
- if room: - query_string = "{} AND room:{}".format(search_term, sanitize_room_id(room)) - else: - query_string = search_term - - try: - query = queryparser.parse_query(query_string) - except ValueError: - raise InvalidQueryError(f"Invalid search term: {search_term}") - - if order_by_recent: - collector = tantivy.TopDocs( - max_results, order_by_field=self.timestamp_field - ) - else: - collector = tantivy.TopDocs(max_results) - - result = self._searcher.search(query, collector) - - retrieved_result = [] - - for score, doc_address in result: - doc = self._searcher.doc(doc_address) - column = doc.get_first(self.column_field) - retrieved_result.append((score, column)) - - return retrieved_result - - -class Index: - def __init__(self, path=None, num_searchers=None): - schema_builder = tantivy.SchemaBuilder() - - self.body_field = schema_builder.add_text_field("body") - self.name_field = schema_builder.add_text_field("name") - self.topic_field = schema_builder.add_text_field("topic") - - self.timestamp_field = schema_builder.add_unsigned_field( - "server_timestamp", fast="single" - ) - self.date_field = schema_builder.add_date_field("message_date") - self.room_field = schema_builder.add_facet_field("room") - - self.column_field = schema_builder.add_unsigned_field( - "database_column", indexed=True, stored=True, fast="single" - ) - - schema = schema_builder.build() - - self.index = tantivy.Index(schema, path) - - self.reader = self.index.reader(num_searchers=num_searchers) - self.writer = self.index.writer() - - def add_event(self, column_id, event, room_id): - doc = tantivy.Document() - - room_path = "/{}".format(sanitize_room_id(room_id)) - - room_facet = tantivy.Facet.from_string(room_path) - - doc.add_unsigned(self.column_field, column_id) - doc.add_facet(self.room_field, room_facet) - doc.add_date( - self.date_field, - datetime.datetime.fromtimestamp(event.server_timestamp / 1000), - ) - doc.add_unsigned(self.timestamp_field, event.server_timestamp) - - if 
isinstance(event, RoomMessageText): - doc.add_text(self.body_field, event.body) - elif isinstance(event, (RoomMessageMedia, RoomEncryptedMedia)): - doc.add_text(self.body_field, event.body) - elif isinstance(event, RoomNameEvent): - doc.add_text(self.name_field, event.name) - elif isinstance(event, RoomTopicEvent): - doc.add_text(self.topic_field, event.topic) - else: - raise ValueError("Invalid event passed.") - - self.writer.add_document(doc) - - def commit(self): - self.writer.commit() - - def searcher(self): - self.reader.reload() - return Searcher( - self.index, - self.body_field, - self.name_field, - self.topic_field, - self.column_field, - self.room_field, - self.timestamp_field, - self.reader.searcher(), - ) - - -@attr.s -class StoreItem: - event = attr.ib() - room_id = attr.ib() - display_name = attr.ib(default=None) - avatar_url = attr.ib(default=None) - - -@attr.s -class IndexStore: - user = attr.ib(type=str) - index_path = attr.ib(type=str) - store_path = attr.ib(type=str, default=None) - store_name = attr.ib(default="events.db") - - index = attr.ib(type=Index, init=False) - store = attr.ib(type=MessageStore, init=False) - event_queue = attr.ib(factory=list) - write_lock = attr.ib(factory=asyncio.Lock) - read_semaphore = attr.ib(type=asyncio.Semaphore, init=False) - - def __attrs_post_init__(self): - self.store_path = self.store_path or self.index_path - num_searchers = os.cpu_count() - self.index = Index(self.index_path, num_searchers) - self.read_semaphore = asyncio.Semaphore(num_searchers or 1) - self.store = MessageStore(self.user, self.store_path, self.store_name) - - def add_event(self, event, room_id, display_name, avatar_url): - item = StoreItem(event, room_id, display_name, avatar_url) - self.event_queue.append(item) - - @staticmethod - def write_events(store, index, event_queue): - with store.database.bind_ctx(store.models): - with store.database.atomic(): - for item in event_queue: - column_id = store.save_event(item.event, item.room_id) - - 
if column_id: - index.add_event(column_id, item.event, item.room_id) - index.commit() - - async def commit_events(self): - loop = asyncio.get_event_loop() - - event_queue = self.event_queue - - if not event_queue: - return - - self.event_queue = [] - - async with self.write_lock: - write_func = partial( - IndexStore.write_events, self.store, self.index, event_queue - ) - await loop.run_in_executor(None, write_func) - - def event_in_store(self, event_id, room_id): - return self.store.event_in_store(event_id, room_id) - - async def search( - self, - search_term, # type: str - room=None, # type: Optional[str] - max_results=10, # type: int - order_by_recent=False, # type: bool - include_profile=False, # type: bool - before_limit=0, # type: int - after_limit=0, # type: int - ): - # type: (...) -> Dict[Any, Any] - """Search the indexstore for an event.""" - loop = asyncio.get_event_loop() - - # Getting a searcher from tantivy may block if there is no searcher - # available. To avoid blocking we set up the number of searchers to be - # the number of CPUs and the semaphore has the same counter value. - async with self.read_semaphore: - searcher = self.index.searcher() - search_func = partial( - searcher.search, - search_term, - room=room, - max_results=max_results, - order_by_recent=order_by_recent, + Returns the score and the column id for the event. + """ + queryparser = tantivy.QueryParser.for_index( + self._index, + [self.body_field, self.name_field, self.topic_field, self.room_field], ) - result = await loop.run_in_executor(None, search_func) + # This currently supports only a single room since the query parser + # doesn't seem to work with multiple room fields here. 
+ if room: + query_string = "{} AND room:{}".format( + search_term, sanitize_room_id(room) + ) + else: + query_string = search_term - load_event_func = partial( - self.store.load_events, - result, - include_profile, - order_by_recent, - before_limit, - after_limit, + try: + query = queryparser.parse_query(query_string) + except ValueError: + raise InvalidQueryError(f"Invalid search term: {search_term}") + + if order_by_recent: + collector = tantivy.TopDocs( + max_results, order_by_field=self.timestamp_field + ) + else: + collector = tantivy.TopDocs(max_results) + + result = self._searcher.search(query, collector) + + retrieved_result = [] + + for score, doc_address in result: + doc = self._searcher.doc(doc_address) + column = doc.get_first(self.column_field) + retrieved_result.append((score, column)) + + return retrieved_result + + class Index: + def __init__(self, path=None, num_searchers=None): + schema_builder = tantivy.SchemaBuilder() + + self.body_field = schema_builder.add_text_field("body") + self.name_field = schema_builder.add_text_field("name") + self.topic_field = schema_builder.add_text_field("topic") + + self.timestamp_field = schema_builder.add_unsigned_field( + "server_timestamp", fast="single" + ) + self.date_field = schema_builder.add_date_field("message_date") + self.room_field = schema_builder.add_facet_field("room") + + self.column_field = schema_builder.add_unsigned_field( + "database_column", indexed=True, stored=True, fast="single" ) - search_result = await loop.run_in_executor(None, load_event_func) + schema = schema_builder.build() - search_result["count"] = len(search_result["results"]) - search_result["highlights"] = [] + self.index = tantivy.Index(schema, path) - return search_result + self.reader = self.index.reader(num_searchers=num_searchers) + self.writer = self.index.writer() + + def add_event(self, column_id, event, room_id): + doc = tantivy.Document() + + room_path = "/{}".format(sanitize_room_id(room_id)) + + room_facet = 
tantivy.Facet.from_string(room_path) + + doc.add_unsigned(self.column_field, column_id) + doc.add_facet(self.room_field, room_facet) + doc.add_date( + self.date_field, + datetime.datetime.fromtimestamp(event.server_timestamp / 1000), + ) + doc.add_unsigned(self.timestamp_field, event.server_timestamp) + + if isinstance(event, RoomMessageText): + doc.add_text(self.body_field, event.body) + elif isinstance(event, (RoomMessageMedia, RoomEncryptedMedia)): + doc.add_text(self.body_field, event.body) + elif isinstance(event, RoomNameEvent): + doc.add_text(self.name_field, event.name) + elif isinstance(event, RoomTopicEvent): + doc.add_text(self.topic_field, event.topic) + else: + raise ValueError("Invalid event passed.") + + self.writer.add_document(doc) + + def commit(self): + self.writer.commit() + + def searcher(self): + self.reader.reload() + return Searcher( + self.index, + self.body_field, + self.name_field, + self.topic_field, + self.column_field, + self.room_field, + self.timestamp_field, + self.reader.searcher(), + ) + + @attr.s + class StoreItem: + event = attr.ib() + room_id = attr.ib() + display_name = attr.ib(default=None) + avatar_url = attr.ib(default=None) + + @attr.s + class IndexStore: + user = attr.ib(type=str) + index_path = attr.ib(type=str) + store_path = attr.ib(type=str, default=None) + store_name = attr.ib(default="events.db") + + index = attr.ib(type=Index, init=False) + store = attr.ib(type=MessageStore, init=False) + event_queue = attr.ib(factory=list) + write_lock = attr.ib(factory=asyncio.Lock) + read_semaphore = attr.ib(type=asyncio.Semaphore, init=False) + + def __attrs_post_init__(self): + self.store_path = self.store_path or self.index_path + num_searchers = os.cpu_count() + self.index = Index(self.index_path, num_searchers) + self.read_semaphore = asyncio.Semaphore(num_searchers or 1) + self.store = MessageStore(self.user, self.store_path, self.store_name) + + def add_event(self, event, room_id, display_name, avatar_url): + item = 
StoreItem(event, room_id, display_name, avatar_url) + self.event_queue.append(item) + + @staticmethod + def write_events(store, index, event_queue): + with store.database.bind_ctx(store.models): + with store.database.atomic(): + for item in event_queue: + column_id = store.save_event(item.event, item.room_id) + + if column_id: + index.add_event(column_id, item.event, item.room_id) + index.commit() + + async def commit_events(self): + loop = asyncio.get_event_loop() + + event_queue = self.event_queue + + if not event_queue: + return + + self.event_queue = [] + + async with self.write_lock: + write_func = partial( + IndexStore.write_events, self.store, self.index, event_queue + ) + await loop.run_in_executor(None, write_func) + + def event_in_store(self, event_id, room_id): + return self.store.event_in_store(event_id, room_id) + + async def search( + self, + search_term, # type: str + room=None, # type: Optional[str] + max_results=10, # type: int + order_by_recent=False, # type: bool + include_profile=False, # type: bool + before_limit=0, # type: int + after_limit=0, # type: int + ): + # type: (...) -> Dict[Any, Any] + """Search the indexstore for an event.""" + loop = asyncio.get_event_loop() + + # Getting a searcher from tantivy may block if there is no searcher + # available. To avoid blocking we set up the number of searchers to be + # the number of CPUs and the semaphore has the same counter value. 
+ async with self.read_semaphore: + searcher = self.index.searcher() + search_func = partial( + searcher.search, + search_term, + room=room, + max_results=max_results, + order_by_recent=order_by_recent, + ) + + result = await loop.run_in_executor(None, search_func) + + load_event_func = partial( + self.store.load_events, + result, + include_profile, + order_by_recent, + before_limit, + after_limit, + ) + + search_result = await loop.run_in_executor(None, load_event_func) + + search_result["count"] = len(search_result["results"]) + search_result["highlights"] = [] + + return search_result + + +else: + INDEXING_ENABLED = False