mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-20 20:31:28 -05:00
pantalaimon: Initial search support.
This patch adds support for the Matrix search API endpoint. Events are stored in the pan sqlite database while the indexing is handled by tanvity. An tantivy index is created for each pan user. Currently only ordering by ranking is supported, and all the search options are ignored for now.
This commit is contained in:
parent
9444e540df
commit
725c043e87
@ -13,32 +13,100 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from pprint import pformat
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from aiohttp.client_exceptions import ClientConnectionError
|
||||
from jsonschema import Draft4Validator, FormatChecker, validators
|
||||
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
|
||||
KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac,
|
||||
KeyVerificationStart, LocalProtocolError, MegolmEvent,
|
||||
RoomEncryptedEvent, SyncResponse)
|
||||
RoomEncryptedEvent, RoomMessage, SyncResponse)
|
||||
from nio.crypto import Sas
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.index import Index
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
|
||||
SasDoneSignal, ShowSasSignal,
|
||||
UpdateDevicesMessage)
|
||||
|
||||
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]
|
||||
|
||||
SEARCH_TERMS_SCHEMA = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_categories": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"room_events": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_term": {"type": "string"},
|
||||
"keys": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": SEARCH_KEYS},
|
||||
"default": SEARCH_KEYS
|
||||
},
|
||||
"order_by": {"type": "string", "default": "rank"},
|
||||
"include_state": {"type": "boolean", "default": False},
|
||||
"filter": {"type": "object", "default": {}},
|
||||
"event_context": {"type": "object", "default": {}},
|
||||
"groupings": {"type": "object", "default": {}},
|
||||
},
|
||||
"required": ["search_term"]
|
||||
},
|
||||
}
|
||||
},
|
||||
"required": ["room_events"]
|
||||
},
|
||||
"required": [
|
||||
"search_categories",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def extend_with_default(validator_class):
|
||||
validate_properties = validator_class.VALIDATORS["properties"]
|
||||
|
||||
def set_defaults(validator, properties, instance, schema):
|
||||
for prop, subschema in properties.items():
|
||||
if "default" in subschema:
|
||||
instance.setdefault(prop, subschema["default"])
|
||||
|
||||
for error in validate_properties(
|
||||
validator, properties, instance, schema
|
||||
):
|
||||
yield error
|
||||
|
||||
return validators.extend(validator_class, {"properties": set_defaults})
|
||||
|
||||
|
||||
Validator = extend_with_default(Draft4Validator)
|
||||
|
||||
|
||||
def validate_json(instance, schema):
|
||||
"""Validate a dictionary using the provided json schema."""
|
||||
Validator(schema, format_checker=FormatChecker()).validate(instance)
|
||||
|
||||
|
||||
class UnknownRoomError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PanClient(AsyncClient):
|
||||
"""A wrapper class around a nio AsyncClient extending its functionality."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_name,
|
||||
pan_store,
|
||||
homeserver,
|
||||
queue=None,
|
||||
user="",
|
||||
user_id="",
|
||||
device_id="",
|
||||
store_path="",
|
||||
config=None,
|
||||
@ -46,9 +114,19 @@ class PanClient(AsyncClient):
|
||||
proxy=None
|
||||
):
|
||||
config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
|
||||
super().__init__(homeserver, user, device_id, store_path, config,
|
||||
super().__init__(homeserver, user_id, device_id, store_path, config,
|
||||
ssl, proxy)
|
||||
|
||||
index_dir = os.path.join(store_path, server_name, user_id)
|
||||
|
||||
try:
|
||||
os.makedirs(index_dir)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
self.server_name = server_name
|
||||
self.pan_store = pan_store
|
||||
self.index = Index(index_dir)
|
||||
self.task = None
|
||||
self.queue = queue
|
||||
|
||||
@ -63,6 +141,10 @@ class PanClient(AsyncClient):
|
||||
self.undecrypted_event_cb,
|
||||
MegolmEvent
|
||||
)
|
||||
self.add_event_callback(
|
||||
self.store_message_cb,
|
||||
RoomMessage
|
||||
)
|
||||
self.key_verificatins_tasks = []
|
||||
self.key_request_tasks = []
|
||||
|
||||
@ -76,6 +158,21 @@ class PanClient(AsyncClient):
|
||||
SyncResponse
|
||||
)
|
||||
|
||||
def store_message_cb(self, room, event):
|
||||
display_name = room.user_name(event.sender)
|
||||
avatar_url = room.avatar_url(event.sender)
|
||||
|
||||
column_id = self.pan_store.save_event(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
event,
|
||||
room.room_id,
|
||||
display_name,
|
||||
avatar_url
|
||||
)
|
||||
if column_id:
|
||||
self.index.add_event(column_id, event, room.room_id)
|
||||
|
||||
@property
|
||||
def unable_to_decrypt(self):
|
||||
"""Room event signaling that the message couldn't be decrypted."""
|
||||
@ -107,6 +204,8 @@ class PanClient(AsyncClient):
|
||||
self.key_verificatins_tasks = []
|
||||
self.key_request_tasks = []
|
||||
|
||||
self.index.commit()
|
||||
|
||||
async def keys_query_cb(self, response):
|
||||
await self.send_update_devcies()
|
||||
|
||||
@ -367,10 +466,10 @@ class PanClient(AsyncClient):
|
||||
await self.task
|
||||
|
||||
def pan_decrypt_event(
|
||||
self,
|
||||
event_dict,
|
||||
room_id=None,
|
||||
ignore_failures=True
|
||||
self,
|
||||
event_dict,
|
||||
room_id=None,
|
||||
ignore_failures=True
|
||||
):
|
||||
# type: (Dict[Any, Any], Optional[str], bool) -> (bool)
|
||||
event = RoomEncryptedEvent.parse_event(event_dict)
|
||||
@ -461,3 +560,50 @@ class PanClient(AsyncClient):
|
||||
self.pan_decrypt_event(event, room_id, ignore_failures)
|
||||
|
||||
return body
|
||||
|
||||
async def search(self, search_terms):
|
||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
validate_json(search_terms, SEARCH_TERMS_SCHEMA)
|
||||
search_terms = search_terms["search_categories"]["room_events"]
|
||||
|
||||
term = search_terms["search_term"]
|
||||
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(searcher.search, term)
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
|
||||
result_dict = {
|
||||
"results": []
|
||||
}
|
||||
|
||||
for score, column_id in result:
|
||||
event = self.pan_store.load_event_by_columns(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
column_id)
|
||||
|
||||
if not event:
|
||||
continue
|
||||
|
||||
event_dict = {
|
||||
"rank": score,
|
||||
"result": event,
|
||||
}
|
||||
|
||||
if False:
|
||||
# TODO load the context from the server
|
||||
event_dict["context"] = {}
|
||||
|
||||
if False:
|
||||
# TODO add profile info
|
||||
pass
|
||||
|
||||
result_dict["results"].append(event_dict)
|
||||
|
||||
result_dict["count"] = len(result_dict["results"])
|
||||
result_dict["highlight"] = []
|
||||
|
||||
return result_dict
|
||||
|
@ -23,11 +23,12 @@ import attr
|
||||
import keyring
|
||||
from aiohttp import ClientSession, web
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
|
||||
from jsonschema import ValidationError
|
||||
from multidict import CIMultiDict
|
||||
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
|
||||
SendRetryError)
|
||||
|
||||
from pantalaimon.client import PanClient
|
||||
from pantalaimon.client import PanClient, UnknownRoomError
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
|
||||
@ -62,6 +63,7 @@ class ProxyDaemon:
|
||||
|
||||
store = attr.ib(type=PanStore, init=False)
|
||||
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
|
||||
hostname = attr.ib(init=False, default=attr.Factory(dict))
|
||||
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
|
||||
client_info = attr.ib(
|
||||
init=False,
|
||||
@ -94,6 +96,8 @@ class ProxyDaemon:
|
||||
logger.info(f"Restoring client for {user_id} {device_id}")
|
||||
|
||||
pan_client = PanClient(
|
||||
self.name,
|
||||
self.store,
|
||||
self.homeserver_url,
|
||||
self.send_queue,
|
||||
user_id,
|
||||
@ -442,12 +446,12 @@ class ProxyDaemon:
|
||||
)
|
||||
|
||||
async def forward_to_web(
|
||||
self,
|
||||
request,
|
||||
params=None,
|
||||
data=None,
|
||||
session=None,
|
||||
token=None
|
||||
self,
|
||||
request,
|
||||
params=None,
|
||||
data=None,
|
||||
session=None,
|
||||
token=None
|
||||
):
|
||||
"""Forward the given request and convert the response to a Response.
|
||||
|
||||
@ -510,9 +514,11 @@ class ProxyDaemon:
|
||||
return
|
||||
|
||||
pan_client = PanClient(
|
||||
self.name,
|
||||
self.store,
|
||||
self.homeserver_url,
|
||||
self.send_queue,
|
||||
user,
|
||||
user_id,
|
||||
store_path=self.data_dir,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
@ -888,6 +894,37 @@ class ProxyDaemon:
|
||||
data=json.dumps(sanitized_content)
|
||||
)
|
||||
|
||||
async def search(self, request):
|
||||
access_token = self.get_access_token(request)
|
||||
|
||||
if not access_token:
|
||||
return self._missing_token
|
||||
|
||||
client = await self._find_client(access_token)
|
||||
|
||||
if not client:
|
||||
return self._unknown_token
|
||||
|
||||
try:
|
||||
search_categories = await request.json()
|
||||
except (JSONDecodeError, ContentTypeError):
|
||||
return self._not_json
|
||||
|
||||
try:
|
||||
result = await client.search(search_categories)
|
||||
except ValidationError:
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_BAD_JSON",
|
||||
"error": "Invalid search query"
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
except UnknownRoomError:
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
async def shutdown(self, app):
|
||||
"""Shut the daemon down closing all the client sessions it has.
|
||||
|
||||
|
143
pantalaimon/index.py
Normal file
143
pantalaimon/index.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2019 The Matrix.org Foundation CIC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
|
||||
import tantivy
|
||||
from nio import RoomMessageText, RoomNameEvent, RoomTopicEvent
|
||||
|
||||
|
||||
def sanitize_room_id(room_id):
|
||||
return room_id.replace(":", "/").replace("!", "")
|
||||
|
||||
|
||||
class Searcher:
|
||||
def __init__(self, index, body_field, name_field, topic_field,
|
||||
column_field, room_field, searcher):
|
||||
self._index = index
|
||||
self._searcher = searcher
|
||||
|
||||
self.body_field = body_field
|
||||
self.name_field = topic_field
|
||||
self.topic_field = name_field
|
||||
self.column_field = column_field
|
||||
self.room_field = room_field
|
||||
|
||||
def search(self, search_term, room=None, max_results=10):
|
||||
# type (str, str) -> List[int, int]
|
||||
"""Search for events in the index.
|
||||
|
||||
Returns the score and the column id for the event.
|
||||
"""
|
||||
|
||||
queryparser = tantivy.QueryParser.for_index(
|
||||
self._index,
|
||||
[
|
||||
self.body_field,
|
||||
self.name_field,
|
||||
self.topic_field,
|
||||
self.room_field
|
||||
]
|
||||
)
|
||||
|
||||
# This currently supports only a single room since the query parser
|
||||
# doesn't seem to work with multiple room fields here.
|
||||
if room:
|
||||
search_term = "{} AND room:{}".format(
|
||||
search_term,
|
||||
sanitize_room_id(room)
|
||||
)
|
||||
|
||||
query = queryparser.parse_query(search_term)
|
||||
|
||||
collector = tantivy.TopDocs(max_results)
|
||||
|
||||
result = self._searcher.search(query, collector)
|
||||
|
||||
retrieved_result = []
|
||||
|
||||
for score, doc_address in result:
|
||||
doc = self._searcher.doc(doc_address)
|
||||
column = doc.get_first(self.column_field)
|
||||
retrieved_result.append((score, column))
|
||||
|
||||
return retrieved_result
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, path=None):
|
||||
schema_builder = tantivy.SchemaBuilder()
|
||||
|
||||
self.body_field = schema_builder.add_text_field("body")
|
||||
self.name_field = schema_builder.add_text_field("name")
|
||||
self.topic_field = schema_builder.add_text_field("topic")
|
||||
|
||||
self.timestamp_field = schema_builder.add_date_field(
|
||||
"server_timestamp"
|
||||
)
|
||||
self.room_field = schema_builder.add_facet_field("room")
|
||||
|
||||
self.column_field = schema_builder.add_integer_field(
|
||||
"database_column",
|
||||
indexed=True,
|
||||
stored=True,
|
||||
fast="single"
|
||||
)
|
||||
|
||||
schema = schema_builder.build()
|
||||
|
||||
self.index = tantivy.Index(schema, path)
|
||||
|
||||
self.reader = self.index.reader()
|
||||
self.writer = self.index.writer()
|
||||
|
||||
def add_event(self, column_id, event, room_id):
|
||||
doc = tantivy.Document()
|
||||
|
||||
room_path = "/{}".format(sanitize_room_id(room_id))
|
||||
|
||||
room_facet = tantivy.Facet.from_string(room_path)
|
||||
|
||||
doc.add_integer(self.column_field, column_id)
|
||||
doc.add_facet(self.room_field, room_facet)
|
||||
doc.add_date(
|
||||
self.timestamp_field,
|
||||
datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
|
||||
)
|
||||
|
||||
if isinstance(event, RoomMessageText):
|
||||
doc.add_text(self.body_field, event.body)
|
||||
elif isinstance(event, RoomNameEvent):
|
||||
doc.add_text(self.name_field, event.name)
|
||||
elif isinstance(event, RoomTopicEvent):
|
||||
doc.add_text(self.topic_field, event.topic)
|
||||
else:
|
||||
raise ValueError("Invalid event passed.")
|
||||
|
||||
self.writer.add_document(doc)
|
||||
|
||||
def commit(self):
|
||||
self.writer.commit()
|
||||
|
||||
def searcher(self):
|
||||
self.reader.reload()
|
||||
return Searcher(
|
||||
self.index,
|
||||
self.body_field,
|
||||
self.name_field,
|
||||
self.topic_field,
|
||||
self.column_field,
|
||||
self.room_field,
|
||||
self.reader.searcher()
|
||||
)
|
@ -67,6 +67,7 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
|
||||
proxy.send_message
|
||||
),
|
||||
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
|
||||
web.post("/_matrix/client/r0/search", proxy.search),
|
||||
])
|
||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||
app.on_shutdown.append(proxy.shutdown)
|
||||
@ -174,7 +175,7 @@ def main(
|
||||
data_dir,
|
||||
server_conf,
|
||||
pan_queue.async_q,
|
||||
ui_queue.async_q
|
||||
ui_queue.async_q,
|
||||
)
|
||||
)
|
||||
servers.append((proxy, runner, site))
|
||||
|
@ -12,15 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
|
||||
use_database)
|
||||
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model,
|
||||
SqliteDatabase, TextField)
|
||||
|
||||
|
||||
class DictField(TextField):
|
||||
def python_value(self, value): # pragma: no cover
|
||||
return json.loads(value)
|
||||
|
||||
def db_value(self, value): # pragma: no cover
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
class AccessTokens(Model):
|
||||
@ -53,21 +63,47 @@ class ServerUsers(Model):
|
||||
constraints = [SQL("UNIQUE(user_id,server_id)")]
|
||||
|
||||
|
||||
class Profile(Model):
|
||||
user_id = TextField()
|
||||
avatar_url = TextField(null=True)
|
||||
display_name = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
|
||||
|
||||
|
||||
class Event(Model):
|
||||
event_id = TextField()
|
||||
sender = TextField()
|
||||
date = DateTimeField()
|
||||
room_id = TextField()
|
||||
|
||||
source = DictField()
|
||||
|
||||
profile = ForeignKeyField(
|
||||
model=Profile,
|
||||
column_name="profile_id",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
|
||||
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(
|
||||
model=ServerUsers,
|
||||
column_name="user_id")
|
||||
event = ForeignKeyField(
|
||||
model=Event,
|
||||
column_name="event_id")
|
||||
|
||||
|
||||
@attr.s
|
||||
class ClientInfo:
|
||||
user_id = attr.ib(type=str)
|
||||
access_token = attr.ib(type=str)
|
||||
|
||||
|
||||
@attr.s
|
||||
class OlmDevice:
|
||||
user_id = attr.ib()
|
||||
id = attr.ib()
|
||||
fp_key = attr.ib()
|
||||
sender_key = attr.ib()
|
||||
trust_state = attr.ib()
|
||||
|
||||
|
||||
@attr.s
|
||||
class PanStore:
|
||||
store_path = attr.ib(type=str)
|
||||
@ -81,6 +117,9 @@ class PanStore:
|
||||
ServerUsers,
|
||||
DeviceKeys,
|
||||
DeviceTrustState,
|
||||
Profile,
|
||||
Event,
|
||||
UserMessages,
|
||||
]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
@ -108,29 +147,107 @@ class PanStore:
|
||||
def _get_account(self, user_id, device_id):
|
||||
try:
|
||||
return Accounts.get(
|
||||
Accounts.user_id == user_id,
|
||||
Accounts.device_id == device_id,
|
||||
Accounts.user_id == user_id,
|
||||
Accounts.device_id == device_id,
|
||||
)
|
||||
except DoesNotExist:
|
||||
return None
|
||||
|
||||
@use_database
|
||||
def save_event(self, server, pan_user, event, room_id, display_name,
|
||||
avatar_url):
|
||||
# type (str, str, str, RoomMessage, str, str, str) -> Optional[int]
|
||||
"""Save an event to the store.
|
||||
|
||||
Returns the database id of the event.
|
||||
"""
|
||||
server = Servers.get(name=server)
|
||||
user = ServerUsers.get(server=server, user_id=pan_user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(
|
||||
event.server_timestamp / 1000
|
||||
),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id
|
||||
).on_conflict_ignore().execute()
|
||||
|
||||
# TODO why do we get a 0 on conflict here, the test show that we get the
|
||||
# existing event id.
|
||||
if event_id <= 0:
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(
|
||||
user=user,
|
||||
event=event_id,
|
||||
)
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
@use_database
|
||||
def load_event_by_columns(
|
||||
self,
|
||||
server, # type: str
|
||||
pan_user, # type: str
|
||||
column, # type: List[int]
|
||||
include_profile=False # type: bool
|
||||
):
|
||||
# type: (...) -> Union[List[Dict[Any, Any]], List[Tuple[Dict, Dict]]]
|
||||
server = Servers.get(name=server)
|
||||
user = ServerUsers.get(server=server, user_id=pan_user)
|
||||
|
||||
message = UserMessages.get_or_none(user=user, event=column)
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
event = message.event
|
||||
event_profile = event.profile
|
||||
|
||||
profile = {
|
||||
event_profile.user_id: {
|
||||
"display_name": event_profile.display_name,
|
||||
"avatar_url": event_profile.avatar_url,
|
||||
}
|
||||
}
|
||||
|
||||
if include_profile:
|
||||
return event.source, profile
|
||||
|
||||
return event.source
|
||||
|
||||
@use_database
|
||||
def save_server_user(self, server_name, user_id):
|
||||
# type: (ClientInfo) -> None
|
||||
# type: (str, str) -> None
|
||||
server, _ = Servers.get_or_create(name=server_name)
|
||||
|
||||
ServerUsers.replace(
|
||||
ServerUsers.insert(
|
||||
user_id=user_id,
|
||||
server=server
|
||||
).execute()
|
||||
).on_conflict_ignore().execute()
|
||||
|
||||
@use_database
|
||||
def load_all_users(self):
|
||||
users = []
|
||||
|
||||
query = Accounts.select(
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
)
|
||||
|
||||
for account in query:
|
||||
@ -140,7 +257,7 @@ class PanStore:
|
||||
|
||||
@use_database
|
||||
def load_users(self, server_name):
|
||||
# type: () -> List[Tuple[str, str]]
|
||||
# type: (str) -> List[Tuple[str, str]]
|
||||
users = []
|
||||
|
||||
server = Servers.get_or_none(Servers.name == server_name)
|
||||
@ -154,8 +271,8 @@ class PanStore:
|
||||
server_users.append(u.user_id)
|
||||
|
||||
query = Accounts.select(
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
).where(Accounts.user_id.in_(server_users))
|
||||
|
||||
for account in query:
|
||||
|
@ -63,3 +63,18 @@ def panstore(tempdir):
|
||||
|
||||
store = PanStore(tempdir, "pan.db")
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def panstore_with_users(panstore):
|
||||
accounts = panstore.load_all_users()
|
||||
user_id, device_id = accounts[0]
|
||||
server = "example"
|
||||
|
||||
panstore.save_server_user(server, user_id)
|
||||
|
||||
server2 = "localhost"
|
||||
user_id2, device_id2 = accounts[1]
|
||||
panstore.save_server_user(server2, user_id2)
|
||||
|
||||
return panstore
|
||||
|
@ -1,9 +1,47 @@
|
||||
import pdb
|
||||
|
||||
from nio import RoomMessage
|
||||
|
||||
from conftest import faker
|
||||
from pantalaimon.index import Index
|
||||
|
||||
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
|
||||
TEST_ROOM2 = "!testroom:localhost"
|
||||
|
||||
|
||||
class TestClass(object):
|
||||
@property
|
||||
def test_event(self):
|
||||
return RoomMessage.parse_event(
|
||||
{
|
||||
"content": {"body": "Test message", "msgtype": "m.text"},
|
||||
"event_id": "$15163622445EBvZJ:localhost",
|
||||
"origin_server_ts": 1516362244026,
|
||||
"room_id": "!SVkFJHzfwvuaIEawgC:localhost",
|
||||
"sender": "@example2:localhost",
|
||||
"type": "m.room.message",
|
||||
"unsigned": {"age": 43289803095},
|
||||
"user_id": "@example2:localhost",
|
||||
"age": 43289803095
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def another_event(self):
|
||||
return RoomMessage.parse_event(
|
||||
{
|
||||
"content": {"body": "Another message", "msgtype": "m.text"},
|
||||
"event_id": "$15163622445EBvZJ:localhost",
|
||||
"origin_server_ts": 1516362244026,
|
||||
"room_id": "!SVkFJHzfwvuaIEawgC:localhost",
|
||||
"sender": "@example2:localhost",
|
||||
"type": "m.room.message",
|
||||
"unsigned": {"age": 43289803095},
|
||||
"user_id": "@example2:localhost",
|
||||
"age": 43289803095
|
||||
}
|
||||
)
|
||||
|
||||
def test_account_loading(self, panstore):
|
||||
accounts = panstore.load_all_users()
|
||||
# pdb.set_trace()
|
||||
@ -19,18 +57,63 @@ class TestClass(object):
|
||||
token = panstore.load_access_token(user_id, device_id)
|
||||
access_token == token
|
||||
|
||||
def test_server_account_storing(self, panstore):
|
||||
def test_event_storing(self, panstore_with_users):
|
||||
panstore = panstore_with_users
|
||||
accounts = panstore.load_all_users()
|
||||
user, _ = accounts[0]
|
||||
|
||||
user_id, device_id = accounts[0]
|
||||
server = faker.hostname()
|
||||
event = self.test_event
|
||||
|
||||
panstore.save_server_user(server, user_id)
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
|
||||
server2 = faker.hostname()
|
||||
user_id2, device_id2 = accounts[1]
|
||||
panstore.save_server_user(server2, user_id2)
|
||||
assert event_id == 1
|
||||
|
||||
server_users = panstore.load_users(server)
|
||||
assert (user_id, device_id) in server_users
|
||||
assert (user_id2, device_id2) not in server_users
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
assert event_id is None
|
||||
assert False
|
||||
|
||||
event_dict = panstore.load_event_by_columns("example", user, 1)
|
||||
assert event.source == event_dict
|
||||
|
||||
_, profile = panstore.load_event_by_columns("example", user, 1, True)
|
||||
|
||||
assert profile == {
|
||||
"@example2:localhost": {
|
||||
"display_name": "Example2",
|
||||
"avatar_url": None
|
||||
}
|
||||
}
|
||||
|
||||
def test_index(self, panstore_with_users):
|
||||
panstore = panstore_with_users
|
||||
accounts = panstore.load_all_users()
|
||||
user, _ = accounts[0]
|
||||
|
||||
event = self.test_event
|
||||
another_event = self.another_event
|
||||
|
||||
index = Index(panstore.store_path)
|
||||
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
assert event_id == 1
|
||||
index.add_event(event_id, event, TEST_ROOM)
|
||||
|
||||
event_id = panstore.save_event("example", user, another_event,
|
||||
TEST_ROOM2, "Example2", None)
|
||||
assert event_id == 2
|
||||
index.add_event(event_id, another_event, TEST_ROOM2)
|
||||
|
||||
index.commit()
|
||||
|
||||
searcher = index.searcher()
|
||||
|
||||
searched_events = searcher.search("message", TEST_ROOM)
|
||||
|
||||
_, found_id = searched_events[0]
|
||||
|
||||
event_dict = panstore.load_event_by_columns("example", user, found_id)
|
||||
|
||||
assert event_dict == event.source
|
||||
|
Loading…
Reference in New Issue
Block a user