diff --git a/pantalaimon/client.py b/pantalaimon/client.py index e8f73c5..ae670d1 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -24,7 +24,8 @@ from jsonschema import Draft4Validator, FormatChecker, validators from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse, KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac, KeyVerificationStart, LocalProtocolError, MegolmEvent, - RoomEncryptedEvent, RoomMessage, SyncResponse) + RoomEncryptedEvent, RoomMessage, SyncResponse, + RoomContextError) from nio.crypto import Sas from nio.store import SqliteStore @@ -54,7 +55,7 @@ SEARCH_TERMS_SCHEMA = { "order_by": {"type": "string", "default": "rank"}, "include_state": {"type": "boolean", "default": False}, "filter": {"type": "object", "default": {}}, - "event_context": {"type": "object", "default": {}}, + "event_context": {"type": "object"}, "groupings": {"type": "object", "default": {}}, }, "required": ["search_term"] @@ -564,14 +565,53 @@ class PanClient(AsyncClient): async def search(self, search_terms): # type: (Dict[Any, Any]) -> Dict[Any, Any] loop = asyncio.get_event_loop() + state_cache = dict() + + async def add_context(room_id, event_id, before, after, include_state): + try: + context = await self.room_context(room_id, event_id, + limit=before+after) + except ClientConnectionError: + return + + if isinstance(context, RoomContextError): + return + + if include_state: + state_cache[room_id] = [e.source for e in context.state] + + event_context = event_dict["context"] + + event_context["events_before"] = [ + e.source for e in context.events_before[:before] + ] + event_context["events_after"] = [ + e.source for e in context.events_after[:after] + ] + event_context["start"] = context.start + event_context["end"] = context.end validate_json(search_terms, SEARCH_TERMS_SCHEMA) search_terms = search_terms["search_categories"]["room_events"] term = search_terms["search_term"] + search_filter = search_terms["filter"] + limit = search_filter.get("limit", 10) + + before_limit = 0 + after_limit = 0 + include_profile = False + + event_context = search_terms.get("event_context") + include_state = search_terms.get("include_state") + + if event_context: + before_limit = event_context.get("before_limit", 5) + after_limit = event_context.get("before_limit", 5) + include_profile = event_context.get("include_profile", False) searcher = self.index.searcher() - search_func = partial(searcher.search, term) + search_func = partial(searcher.search, term, max_results=limit) result = await loop.run_in_executor(None, search_func) @@ -580,30 +620,29 @@ class PanClient(AsyncClient): } for score, column_id in result: - event = self.pan_store.load_event_by_columns( + event_dict = self.pan_store.load_event_by_columns( self.server_name, self.user_id, - column_id) + column_id, include_profile) - if not event: + if not event_dict: continue - event_dict = { - "rank": score, - "result": event, - } + if include_state or before_limit or after_limit: + await add_context( + event_dict["result"]["room_id"], + event_dict["result"]["event_id"], + before_limit, + after_limit, + include_state + ) - if False: - # TODO load the context from the server - event_dict["context"] = {} - - if False: - # TODO add profile info - pass + event_dict["rank"] = score result_dict["results"].append(event_dict) result_dict["count"] = len(result_dict["results"]) result_dict["highlight"] = [] + result_dict["state"] = state_cache return result_dict diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 4878458..1d725ea 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -217,19 +217,23 @@ class PanStore: return None event = message.event - event_profile = event.profile - profile = { - event_profile.user_id: { - "display_name": event_profile.display_name, - "avatar_url": event_profile.avatar_url, - } + event_dict = { + "result": event.source, + "context": {} } if include_profile: - return event.source, profile + event_profile = event.profile - return event.source + event_dict["context"]["profile_info"] = { + event_profile.user_id: { + "display_name": event_profile.display_name, + "avatar_url": event_profile.avatar_url, + } + } + + return event_dict @use_database def save_server_user(self, server_name, user_id):