pantalaimon: Initial search support.

This patch adds support for the Matrix search API endpoint.

Events are stored in the pan sqlite database while the indexing is
handled by tantivy.

A tantivy index is created for each pan user. Currently only ordering
by ranking is supported, and all the search options are ignored for now.
This commit is contained in:
Damir Jelić 2019-06-06 11:14:40 +02:00
parent 9444e540df
commit 725c043e87
7 changed files with 590 additions and 48 deletions

View File

@ -13,32 +13,100 @@
# limitations under the License.
import asyncio
import os
from collections import defaultdict
from functools import partial
from pprint import pformat
from typing import Any, Dict, Optional
from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac,
KeyVerificationStart, LocalProtocolError, MegolmEvent,
RoomEncryptedEvent, SyncResponse)
RoomEncryptedEvent, RoomMessage, SyncResponse)
from nio.crypto import Sas
from nio.store import SqliteStore
from pantalaimon.index import Index
from pantalaimon.log import logger
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
SasDoneSignal, ShowSasSignal,
UpdateDevicesMessage)
# Event content keys that a search request is allowed to name in "keys".
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]

# JSON schema describing the body of a search request. Optional members
# carry a "default" that gets inserted into the request while validating.
SEARCH_TERMS_SCHEMA = {
    "type": "object",
    "properties": {
        "search_categories": {
            "type": "object",
            "properties": {
                "room_events": {
                    "type": "object",
                    "properties": {
                        "search_term": {"type": "string"},
                        "keys": {
                            "type": "array",
                            "items": {"type": "string", "enum": SEARCH_KEYS},
                            "default": SEARCH_KEYS
                        },
                        "order_by": {"type": "string", "default": "rank"},
                        "include_state": {"type": "boolean", "default": False},
                        "filter": {"type": "object", "default": {}},
                        "event_context": {"type": "object", "default": {}},
                        "groupings": {"type": "object", "default": {}},
                    },
                    "required": ["search_term"]
                },
            },
            # This keyword has to live on the "search_categories" object.
            # Placed one level up, inside the outer "properties" map, it is
            # read as the schema of a property called "required" and the
            # presence of "room_events" is never enforced.
            "required": ["room_events"]
        },
    },
    "required": [
        "search_categories",
    ],
}
def extend_with_default(validator_class):
    """Derive a validator class that also fills in schema defaults.

    The returned class replaces the "properties" check of `validator_class`
    with one that first inserts the "default" of every property missing from
    the instance and then runs the original "properties" check.
    """
    original_check = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        # Properties that declare a default, paired with that default.
        with_defaults = (
            (name, subschema["default"])
            for name, subschema in properties.items()
            if "default" in subschema
        )

        # setdefault() leaves values already present in the instance alone.
        for name, default in with_defaults:
            instance.setdefault(name, default)

        # Hand every error of the wrapped check through unchanged.
        yield from original_check(validator, properties, instance, schema)

    return validators.extend(validator_class, {"properties": set_defaults})
# Draft 4 validator class that additionally inserts the "default" values
# declared in a schema into the instance being validated.
Validator = extend_with_default(Draft4Validator)
def validate_json(instance, schema):
    """Check the dictionary `instance` against the json schema `schema`.

    Validation is done with the module's `Validator` class and a fresh
    `FormatChecker`; whatever `Validator.validate()` raises is propagated.
    """
    checker = Validator(schema, format_checker=FormatChecker())
    checker.validate(instance)
class UnknownRoomError(Exception):
    """Error signaling that a room is not known.

    NOTE(review): the proxy's search handler reacts to this error by
    forwarding the request upstream; where it gets raised is not visible
    here.
    """
class PanClient(AsyncClient):
"""A wrapper class around a nio AsyncClient extending its functionality."""
def __init__(
self,
server_name,
pan_store,
homeserver,
queue=None,
user="",
user_id="",
device_id="",
store_path="",
config=None,
@ -46,9 +114,19 @@ class PanClient(AsyncClient):
proxy=None
):
config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
super().__init__(homeserver, user, device_id, store_path, config,
super().__init__(homeserver, user_id, device_id, store_path, config,
ssl, proxy)
index_dir = os.path.join(store_path, server_name, user_id)
try:
os.makedirs(index_dir)
except OSError:
pass
self.server_name = server_name
self.pan_store = pan_store
self.index = Index(index_dir)
self.task = None
self.queue = queue
@ -63,6 +141,10 @@ class PanClient(AsyncClient):
self.undecrypted_event_cb,
MegolmEvent
)
self.add_event_callback(
self.store_message_cb,
RoomMessage
)
self.key_verificatins_tasks = []
self.key_request_tasks = []
@ -76,6 +158,21 @@ class PanClient(AsyncClient):
SyncResponse
)
def store_message_cb(self, room, event):
    """Store a received room message and add it to the search index.

    The sender's display name and avatar URL are looked up in `room` and
    saved together with the event in the pan store. Only events that the
    store reports as newly saved are handed to the index.

    Args:
        room: The room the event was received in.
        event: The received message event.
    """
    display_name = room.user_name(event.sender)
    avatar_url = room.avatar_url(event.sender)

    column_id = self.pan_store.save_event(
        self.server_name,
        self.user_id,
        event,
        room.room_id,
        display_name,
        avatar_url
    )

    # A falsy id means the store did not create a new entry for this user,
    # so there is nothing new to index.
    if not column_id:
        return

    try:
        self.index.add_event(column_id, event, room.room_id)
    except ValueError:
        # This callback fires for every RoomMessage, while the index only
        # accepts text, room name and room topic events and raises
        # ValueError for everything else. Such events stay in the store but
        # are left out of the index instead of breaking event handling.
        logger.debug(
            f"Event {event.event_id} of type {type(event).__name__} "
            f"can't be indexed, skipping"
        )
@property
def unable_to_decrypt(self):
"""Room event signaling that the message couldn't be decrypted."""
@ -107,6 +204,8 @@ class PanClient(AsyncClient):
self.key_verificatins_tasks = []
self.key_request_tasks = []
self.index.commit()
async def keys_query_cb(self, response):
await self.send_update_devcies()
@ -367,10 +466,10 @@ class PanClient(AsyncClient):
await self.task
def pan_decrypt_event(
self,
event_dict,
room_id=None,
ignore_failures=True
self,
event_dict,
room_id=None,
ignore_failures=True
):
# type: (Dict[Any, Any], Optional[str], bool) -> (bool)
event = RoomEncryptedEvent.parse_event(event_dict)
@ -461,3 +560,50 @@ class PanClient(AsyncClient):
self.pan_decrypt_event(event, room_id, ignore_failures)
return body
async def search(self, search_terms):
    # type: (Dict[Any, Any]) -> Dict[Any, Any]
    """Search the local index for events matching a search request.

    The request is validated against SEARCH_TERMS_SCHEMA (which also fills
    in defaults), the index lookup runs in the default executor so it does
    not block the event loop, and every hit is loaded from the pan store.
    Hits that can't be loaded for this user are skipped.

    Args:
        search_terms: The body of the search request.

    Returns:
        A dictionary with the keys "results" (a list of
        {"rank": score, "result": event} dictionaries), "count" (the number
        of results) and "highlight" (currently always empty).

    Raises:
        Whatever validate_json() raises for a request that doesn't match
        the schema.
    """
    validate_json(search_terms, SEARCH_TERMS_SCHEMA)
    room_events = search_terms["search_categories"]["room_events"]
    term = room_events["search_term"]

    loop = asyncio.get_event_loop()
    searcher = self.index.searcher()
    hits = await loop.run_in_executor(None, partial(searcher.search, term))

    results = []

    for score, column_id in hits:
        event = self.pan_store.load_event_by_columns(
            self.server_name,
            self.user_id,
            column_id
        )

        if not event:
            continue

        # TODO load the context from the server
        # TODO add profile info
        results.append({"rank": score, "result": event})

    # NOTE(review): the key is spelled "highlight" here; confirm against the
    # Matrix search API whether clients expect a different spelling.
    return {
        "results": results,
        "count": len(results),
        "highlight": [],
    }

View File

@ -23,11 +23,12 @@ import attr
import keyring
from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from jsonschema import ValidationError
from multidict import CIMultiDict
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
SendRetryError)
from pantalaimon.client import PanClient
from pantalaimon.client import PanClient, UnknownRoomError
from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
@ -62,6 +63,7 @@ class ProxyDaemon:
store = attr.ib(type=PanStore, init=False)
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
hostname = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
client_info = attr.ib(
init=False,
@ -94,6 +96,8 @@ class ProxyDaemon:
logger.info(f"Restoring client for {user_id} {device_id}")
pan_client = PanClient(
self.name,
self.store,
self.homeserver_url,
self.send_queue,
user_id,
@ -442,12 +446,12 @@ class ProxyDaemon:
)
async def forward_to_web(
self,
request,
params=None,
data=None,
session=None,
token=None
self,
request,
params=None,
data=None,
session=None,
token=None
):
"""Forward the given request and convert the response to a Response.
@ -510,9 +514,11 @@ class ProxyDaemon:
return
pan_client = PanClient(
self.name,
self.store,
self.homeserver_url,
self.send_queue,
user,
user_id,
store_path=self.data_dir,
ssl=self.ssl,
proxy=self.proxy
@ -888,6 +894,37 @@ class ProxyDaemon:
data=json.dumps(sanitized_content)
)
async def search(self, request):
    """Answer a search request from the local index of the pan client.

    Responds with the missing-token, unknown-token or not-json response
    when the request can't be tied to a client or has no JSON body, with a
    400 M_BAD_JSON error when the client rejects the query as invalid, and
    forwards the request when the client raises UnknownRoomError.
    Otherwise the search result is returned as JSON with status 200.
    """
    token = self.get_access_token(request)
    if not token:
        return self._missing_token

    pan_client = await self._find_client(token)
    if not pan_client:
        return self._unknown_token

    try:
        query = await request.json()
    except (JSONDecodeError, ContentTypeError):
        return self._not_json

    try:
        found = await pan_client.search(query)
    except ValidationError:
        invalid_query = {
            "errcode": "M_BAD_JSON",
            "error": "Invalid search query"
        }
        return web.json_response(invalid_query, status=400)
    except UnknownRoomError:
        return await self.forward_to_web(request)
    else:
        return web.json_response(found, status=200)
async def shutdown(self, app):
"""Shut the daemon down closing all the client sessions it has.

143
pantalaimon/index.py Normal file
View File

@ -0,0 +1,143 @@
# Copyright 2019 The Matrix.org Foundation CIC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import tantivy
from nio import RoomMessageText, RoomNameEvent, RoomTopicEvent
def sanitize_room_id(room_id):
    """Return `room_id` with every ":" turned into "/" and every "!" removed."""
    return room_id.translate(str.maketrans(":", "/", "!"))
class Searcher:
    """Runs text queries against a tantivy index.

    Holds the index, the schema fields needed to build a query and read the
    result, and the tantivy searcher that executes the query.
    """

    def __init__(self, index, body_field, name_field, topic_field,
                 column_field, room_field, searcher):
        self._index = index
        self._searcher = searcher

        self.body_field = body_field
        # Each field is stored under its own name; the name and topic
        # fields used to be assigned crosswise.
        self.name_field = name_field
        self.topic_field = topic_field
        self.column_field = column_field
        self.room_field = room_field

    def search(self, search_term, room=None, max_results=10):
        """Search for events in the index.

        Args:
            search_term (str): The query passed to the tantivy query parser.
            room (str, optional): Room id to restrict the search to.
            max_results (int): Upper bound on the number of hits collected.

        Returns:
            A list of (score, column id) tuples, one per hit, where the
            column id is the first value stored in the hit's column field.
        """
        queryparser = tantivy.QueryParser.for_index(
            self._index,
            [
                self.body_field,
                self.name_field,
                self.topic_field,
                self.room_field
            ]
        )

        # This currently supports only a single room since the query parser
        # doesn't seem to work with multiple room fields here.
        if room:
            search_term = "{} AND room:{}".format(
                search_term,
                sanitize_room_id(room)
            )

        query = queryparser.parse_query(search_term)
        collector = tantivy.TopDocs(max_results)
        hits = self._searcher.search(query, collector)

        # Resolve every hit to its document and pull out the stored id.
        return [
            (score, self._searcher.doc(address).get_first(self.column_field))
            for score, address in hits
        ]
class Index:
    """A tantivy index over the text of stored events.

    Every document holds the database id of an event, the room it belongs
    to, its timestamp and one text field (body, name or topic).
    """

    def __init__(self, path=None):
        builder = tantivy.SchemaBuilder()

        # The fields are registered in the same order as before.
        self.body_field = builder.add_text_field("body")
        self.name_field = builder.add_text_field("name")
        self.topic_field = builder.add_text_field("topic")
        self.timestamp_field = builder.add_date_field("server_timestamp")
        self.room_field = builder.add_facet_field("room")
        self.column_field = builder.add_integer_field(
            "database_column",
            indexed=True,
            stored=True,
            fast="single"
        )

        self.index = tantivy.Index(builder.build(), path)
        self.reader = self.index.reader()
        self.writer = self.index.writer()

    def add_event(self, column_id, event, room_id):
        """Queue a document for `event` on the index writer.

        Raises ValueError if the event is neither a RoomMessageText, a
        RoomNameEvent nor a RoomTopicEvent. The document only becomes part
        of the index once commit() is called.
        """
        document = tantivy.Document()

        facet = tantivy.Facet.from_string(
            "/{}".format(sanitize_room_id(room_id))
        )

        document.add_integer(self.column_field, column_id)
        document.add_facet(self.room_field, facet)
        # server_timestamp is divided by 1000 before the conversion.
        document.add_date(
            self.timestamp_field,
            datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
        )

        # Event class -> (index field, event attribute holding the text).
        text_sources = (
            (RoomMessageText, self.body_field, "body"),
            (RoomNameEvent, self.name_field, "name"),
            (RoomTopicEvent, self.topic_field, "topic"),
        )

        for event_class, field, attribute in text_sources:
            if isinstance(event, event_class):
                document.add_text(field, getattr(event, attribute))
                break
        else:
            raise ValueError("Invalid event passed.")

        self.writer.add_document(document)

    def commit(self):
        """Commit the documents queued on the writer."""
        self.writer.commit()

    def searcher(self):
        """Reload the reader and return a Searcher bound to this index."""
        self.reader.reload()

        return Searcher(
            index=self.index,
            body_field=self.body_field,
            name_field=self.name_field,
            topic_field=self.topic_field,
            column_field=self.column_field,
            room_field=self.room_field,
            searcher=self.reader.searcher()
        )

View File

@ -67,6 +67,7 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
proxy.send_message
),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
web.post("/_matrix/client/r0/search", proxy.search),
])
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
app.on_shutdown.append(proxy.shutdown)
@ -174,7 +175,7 @@ def main(
data_dir,
server_conf,
pan_queue.async_q,
ui_queue.async_q
ui_queue.async_q,
)
)
servers.append((proxy, runner, site))

View File

@ -12,15 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import os
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import attr
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
use_database)
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField)
from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model,
SqliteDatabase, TextField)
class DictField(TextField):
    """Text field that stores its value as JSON."""

    def python_value(self, value):  # pragma: no cover
        """Decode the JSON text coming from the database."""
        decoded = json.loads(value)
        return decoded

    def db_value(self, value):  # pragma: no cover
        """Encode `value` as JSON text for the database."""
        encoded = json.dumps(value)
        return encoded
class AccessTokens(Model):
@ -53,21 +63,47 @@ class ServerUsers(Model):
constraints = [SQL("UNIQUE(user_id,server_id)")]
# Database model for one combination of user id, avatar URL and display name.
class Profile(Model):
user_id = TextField()
# The avatar URL and the display name may both be NULL.
avatar_url = TextField(null=True)
display_name = TextField(null=True)
class Meta:
# Every (user_id, avatar_url, display_name) combination is declared unique.
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
# Database model for a stored room event.
class Event(Model):
event_id = TextField()
sender = TextField()
date = DateTimeField()
room_id = TextField()
# Kept in a DictField, i.e. serialized to JSON text in the database.
source = DictField()
# Reference to a Profile row, stored in the "profile_id" column.
profile = ForeignKeyField(
model=Profile,
column_name="profile_id",
)
class Meta:
# The combination of event id, room id, sender and profile is declared
# unique.
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
# Database model linking a ServerUsers row to an Event row.
class UserMessages(Model):
# Reference to the ServerUsers row, stored in the "user_id" column.
user = ForeignKeyField(
model=ServerUsers,
column_name="user_id")
# Reference to the Event row, stored in the "event_id" column.
event = ForeignKeyField(
model=Event,
column_name="event_id")
# attrs class pairing a user id with an access token; both are typed as str.
@attr.s
class ClientInfo:
user_id = attr.ib(type=str)
access_token = attr.ib(type=str)
# attrs class bundling the attributes of a device.
# NOTE(review): fp_key and sender_key are presumably the device's fingerprint
# and sender keys, and trust_state its verification state - confirm against
# the code that creates these objects.
@attr.s
class OlmDevice:
user_id = attr.ib()
id = attr.ib()
fp_key = attr.ib()
sender_key = attr.ib()
trust_state = attr.ib()
@attr.s
class PanStore:
store_path = attr.ib(type=str)
@ -81,6 +117,9 @@ class PanStore:
ServerUsers,
DeviceKeys,
DeviceTrustState,
Profile,
Event,
UserMessages,
]
def __attrs_post_init__(self):
@ -108,29 +147,107 @@ class PanStore:
def _get_account(self, user_id, device_id):
try:
return Accounts.get(
Accounts.user_id == user_id,
Accounts.device_id == device_id,
Accounts.user_id == user_id,
Accounts.device_id == device_id,
)
except DoesNotExist:
return None
@use_database
def save_event(self, server, pan_user, event, room_id, display_name,
               avatar_url):
    """Save an event to the store.

    Args:
        server (str): Name of the server the pan user belongs to.
        pan_user (str): User id of the pan user the event is saved for.
        event: The event to save; its event_id, sender, server_timestamp
            and source attributes are used.
        room_id (str): Id of the room the event belongs to.
        display_name: Display name of the sender, may be None.
        avatar_url: Avatar URL of the sender, may be None.

    Returns:
        The database id of the event if it was newly stored for the given
        user, None otherwise.
    """
    server = Servers.get(name=server)
    user = ServerUsers.get(server=server, user_id=pan_user)

    # get_or_create() hands back the profile row itself, not just its id.
    profile, _ = Profile.get_or_create(
        user_id=event.sender,
        display_name=display_name,
        avatar_url=avatar_url
    )

    # Add the room id to a copy of the source so the dictionary owned by
    # the caller's event object is not modified as a side effect.
    event_source = dict(event.source)
    event_source["room_id"] = room_id

    event_id = Event.insert(
        event_id=event.event_id,
        sender=event.sender,
        # server_timestamp is divided by 1000 before the conversion.
        date=datetime.datetime.fromtimestamp(
            event.server_timestamp / 1000
        ),
        room_id=room_id,
        source=event_source,
        profile=profile
    ).on_conflict_ignore().execute()

    # TODO why do we get a 0 on conflict here, the test show that we get the
    # existing event id.
    if event_id <= 0:
        return None

    _, created = UserMessages.get_or_create(
        user=user,
        event=event_id,
    )

    return event_id if created else None
@use_database
def load_event_by_columns(
    self,
    server,  # type: str
    pan_user,  # type: str
    column,  # type: int
    include_profile=False  # type: bool
):
    # type: (...) -> Optional[Union[Dict[Any, Any], Tuple[Dict, Dict]]]
    """Load the source of a stored event by its database id.

    Args:
        server: Name of the server the pan user belongs to.
        pan_user: User id of the pan user the event was stored for.
        column: Database id of the event.
        include_profile: Whether the sender's profile should be returned
            as well.

    Returns:
        None if no such event is stored for the user. Otherwise the event
        source, or, if `include_profile` is set, a tuple of the event
        source and a dictionary mapping the profile's user id to its
        "display_name" and "avatar_url".
    """
    server = Servers.get(name=server)
    user = ServerUsers.get(server=server, user_id=pan_user)

    message = UserMessages.get_or_none(user=user, event=column)

    if not message:
        return None

    event = message.event

    # Only touch the profile row when the caller asked for it.
    if not include_profile:
        return event.source

    event_profile = event.profile
    profile = {
        event_profile.user_id: {
            "display_name": event_profile.display_name,
            "avatar_url": event_profile.avatar_url,
        }
    }

    return event.source, profile
@use_database
def save_server_user(self, server_name, user_id):
# type: (ClientInfo) -> None
# type: (str, str) -> None
server, _ = Servers.get_or_create(name=server_name)
ServerUsers.replace(
ServerUsers.insert(
user_id=user_id,
server=server
).execute()
).on_conflict_ignore().execute()
@use_database
def load_all_users(self):
users = []
query = Accounts.select(
Accounts.user_id,
Accounts.device_id,
Accounts.user_id,
Accounts.device_id,
)
for account in query:
@ -140,7 +257,7 @@ class PanStore:
@use_database
def load_users(self, server_name):
# type: () -> List[Tuple[str, str]]
# type: (str) -> List[Tuple[str, str]]
users = []
server = Servers.get_or_none(Servers.name == server_name)
@ -154,8 +271,8 @@ class PanStore:
server_users.append(u.user_id)
query = Accounts.select(
Accounts.user_id,
Accounts.device_id,
Accounts.user_id,
Accounts.device_id,
).where(Accounts.user_id.in_(server_users))
for account in query:

View File

@ -63,3 +63,18 @@ def panstore(tempdir):
store = PanStore(tempdir, "pan.db")
return store
@pytest.fixture
def panstore_with_users(panstore):
    """Provide the pan store with its first two accounts tied to servers.

    The first account is saved as a user of the server "example", the
    second one as a user of the server "localhost".
    """
    accounts = panstore.load_all_users()

    for position, server_name in enumerate(("example", "localhost")):
        user_id, _ = accounts[position]
        panstore.save_server_user(server_name, user_id)

    return panstore

View File

@ -1,9 +1,47 @@
import pdb
from nio import RoomMessage
from conftest import faker
from pantalaimon.index import Index
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
TEST_ROOM2 = "!testroom:localhost"
class TestClass(object):
@property
def test_event(self):
    """A freshly parsed text message event with the body "Test message"."""
    raw_event = {
        "content": {"body": "Test message", "msgtype": "m.text"},
        "event_id": "$15163622445EBvZJ:localhost",
        "origin_server_ts": 1516362244026,
        "room_id": "!SVkFJHzfwvuaIEawgC:localhost",
        "sender": "@example2:localhost",
        "type": "m.room.message",
        "unsigned": {"age": 43289803095},
        "user_id": "@example2:localhost",
        "age": 43289803095
    }
    return RoomMessage.parse_event(raw_event)
@property
def another_event(self):
    """A freshly parsed text message event with the body "Another message"."""
    raw_event = {
        "content": {"body": "Another message", "msgtype": "m.text"},
        "event_id": "$15163622445EBvZJ:localhost",
        "origin_server_ts": 1516362244026,
        "room_id": "!SVkFJHzfwvuaIEawgC:localhost",
        "sender": "@example2:localhost",
        "type": "m.room.message",
        "unsigned": {"age": 43289803095},
        "user_id": "@example2:localhost",
        "age": 43289803095
    }
    return RoomMessage.parse_event(raw_event)
def test_account_loading(self, panstore):
accounts = panstore.load_all_users()
# pdb.set_trace()
@ -19,18 +57,63 @@ class TestClass(object):
token = panstore.load_access_token(user_id, device_id)
access_token == token
def test_server_account_storing(self, panstore):
def test_event_storing(self, panstore_with_users):
panstore = panstore_with_users
accounts = panstore.load_all_users()
user, _ = accounts[0]
user_id, device_id = accounts[0]
server = faker.hostname()
event = self.test_event
panstore.save_server_user(server, user_id)
event_id = panstore.save_event("example", user, event, TEST_ROOM,
"Example2", None)
server2 = faker.hostname()
user_id2, device_id2 = accounts[1]
panstore.save_server_user(server2, user_id2)
assert event_id == 1
server_users = panstore.load_users(server)
assert (user_id, device_id) in server_users
assert (user_id2, device_id2) not in server_users
event_id = panstore.save_event("example", user, event, TEST_ROOM,
"Example2", None)
assert event_id is None
assert False
event_dict = panstore.load_event_by_columns("example", user, 1)
assert event.source == event_dict
_, profile = panstore.load_event_by_columns("example", user, 1, True)
assert profile == {
"@example2:localhost": {
"display_name": "Example2",
"avatar_url": None
}
}
def test_index(self, panstore_with_users):
    """Store two events, index them and find the first one by searching."""
    store = panstore_with_users
    accounts = store.load_all_users()
    user, _ = accounts[0]

    first_event = self.test_event
    second_event = self.another_event

    index = Index(store.store_path)

    # (event, room it is stored in, database id the store should assign)
    cases = (
        (first_event, TEST_ROOM, 1),
        (second_event, TEST_ROOM2, 2),
    )

    for message, room, expected_id in cases:
        stored_id = store.save_event("example", user, message, room,
                                     "Example2", None)
        assert stored_id == expected_id
        index.add_event(stored_id, message, room)

    index.commit()

    # Both bodies contain "message"; the room restricts the hit to the
    # first event.
    hits = index.searcher().search("message", TEST_ROOM)

    _, found_id = hits[0]
    loaded = store.load_event_by_columns("example", user, found_id)

    assert loaded == first_event.source