mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2024-10-01 03:35:38 -04:00
index: Rewrite the search logic.
This patch moves all the indexing, event storing and searching into a separate class. The index and message store are now represented as a single class and messages are indexed and stored atomically now which should minimize the chance of store/index inconsistencies. Messages are now loaded from the store as a single SQL query and the context for the messages is as well loaded from the store instead of fetched from the server. The room state and start/end tokens for the context aren't currently loaded.
This commit is contained in:
parent
a489031962
commit
3a1b001244
@ -15,7 +15,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from pprint import pformat
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@ -30,7 +29,7 @@ from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
|
||||
from nio.crypto import Sas
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.index import Index
|
||||
from pantalaimon.index import IndexStore
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import FetchTask
|
||||
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
|
||||
@ -137,7 +136,7 @@ class PanClient(AsyncClient):
|
||||
|
||||
self.server_name = server_name
|
||||
self.pan_store = pan_store
|
||||
self.index = Index(index_dir)
|
||||
self.index = IndexStore(self.user_id, index_dir)
|
||||
self.task = None
|
||||
self.queue = queue
|
||||
|
||||
@ -179,20 +178,7 @@ class PanClient(AsyncClient):
|
||||
display_name = room.user_name(event.sender)
|
||||
avatar_url = room.avatar_url(event.sender)
|
||||
|
||||
column_id = self.pan_store.save_event(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
event,
|
||||
room.room_id,
|
||||
display_name,
|
||||
avatar_url
|
||||
)
|
||||
|
||||
if column_id:
|
||||
self.index.add_event(column_id, event, room.room_id)
|
||||
return True
|
||||
|
||||
return False
|
||||
self.index.add_event(event, room.room_id, display_name, avatar_url)
|
||||
|
||||
@property
|
||||
def unable_to_decrypt(self):
|
||||
@ -247,7 +233,8 @@ class PanClient(AsyncClient):
|
||||
room.display_name
|
||||
))
|
||||
response = await self.room_messages(fetch_task.room_id,
|
||||
fetch_task.token)
|
||||
fetch_task.token,
|
||||
limit=100)
|
||||
except ClientConnectionError:
|
||||
self.history_fetch_queue.put(fetch_task)
|
||||
|
||||
@ -266,10 +253,17 @@ class PanClient(AsyncClient):
|
||||
)):
|
||||
continue
|
||||
|
||||
if not self.store_message_cb(room, event):
|
||||
# The event was already in our store, we catched up.
|
||||
break
|
||||
else:
|
||||
display_name = room.user_name(event.sender)
|
||||
avatar_url = room.avatar_url(event.sender)
|
||||
self.index.add_event(event, room.room_id, display_name,
|
||||
avatar_url)
|
||||
|
||||
last_event = response.chunk[-1]
|
||||
|
||||
if not self.index.event_in_store(
|
||||
last_event.event_id,
|
||||
room.room_id
|
||||
):
|
||||
# There may be even more events to fetch, add a new task to
|
||||
# the queue.
|
||||
task = FetchTask(room.room_id, response.end)
|
||||
@ -277,6 +271,7 @@ class PanClient(AsyncClient):
|
||||
self.user_id, task)
|
||||
await self.history_fetch_queue.put(task)
|
||||
|
||||
await self.index.commit_events()
|
||||
self.delete_fetcher_task(fetch_task)
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
return
|
||||
@ -292,8 +287,7 @@ class PanClient(AsyncClient):
|
||||
self.key_verificatins_tasks = []
|
||||
self.key_request_tasks = []
|
||||
|
||||
self.index.commit()
|
||||
|
||||
await self.index.commit_events()
|
||||
self.pan_store.save_token(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
@ -688,13 +682,11 @@ class PanClient(AsyncClient):
|
||||
|
||||
async def search(self, search_terms):
|
||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||
loop = asyncio.get_event_loop()
|
||||
state_cache = dict()
|
||||
|
||||
async def add_context(room_id, event_id, before, after, include_state):
|
||||
async def add_context(event_dict, room_id, event_id, include_state):
|
||||
try:
|
||||
context = await self.room_context(room_id, event_id,
|
||||
limit=before+after)
|
||||
context = await self.room_context(room_id, event_id, limit=0)
|
||||
except ClientConnectionError:
|
||||
return
|
||||
|
||||
@ -704,16 +696,8 @@ class PanClient(AsyncClient):
|
||||
if include_state:
|
||||
state_cache[room_id] = [e.source for e in context.state]
|
||||
|
||||
event_context = event_dict["context"]
|
||||
|
||||
event_context["events_before"] = [
|
||||
e.source for e in context.events_before[:before]
|
||||
]
|
||||
event_context["events_after"] = [
|
||||
e.source for e in context.events_after[:after]
|
||||
]
|
||||
event_context["start"] = context.start
|
||||
event_context["end"] = context.end
|
||||
event_dict["context"]["start"] = context.start
|
||||
event_dict["context"]["end"] = context.end
|
||||
|
||||
validate_json(search_terms, SEARCH_TERMS_SCHEMA)
|
||||
search_terms = search_terms["search_categories"]["room_events"]
|
||||
@ -723,7 +707,7 @@ class PanClient(AsyncClient):
|
||||
limit = search_filter.get("limit", 10)
|
||||
|
||||
if limit <= 0:
|
||||
raise InvalidLimit(f"The limit must be strictly greater than 0.")
|
||||
raise InvalidLimit("The limit must be strictly greater than 0.")
|
||||
|
||||
rooms = search_filter.get("rooms", [])
|
||||
|
||||
@ -734,7 +718,7 @@ class PanClient(AsyncClient):
|
||||
if order_by not in ["rank", "recent"]:
|
||||
raise InvalidOrderByError(f"Invalid order by: {order_by}")
|
||||
|
||||
order_by_date = order_by == "recent"
|
||||
order_by_recent = order_by == "recent"
|
||||
|
||||
before_limit = 0
|
||||
after_limit = 0
|
||||
@ -746,51 +730,37 @@ class PanClient(AsyncClient):
|
||||
if event_context:
|
||||
before_limit = event_context.get("before_limit", 5)
|
||||
after_limit = event_context.get("before_limit", 5)
|
||||
include_profile = event_context.get("include_profile", False)
|
||||
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(searcher.search, term, room=room_id,
|
||||
max_results=limit, order_by_date=order_by_date)
|
||||
if before_limit < 0 or after_limit < 0:
|
||||
raise InvalidLimit("Invalid context limit, the limit must be a "
|
||||
"positive number")
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
response_dict = await self.index.search(
|
||||
term,
|
||||
room=room_id,
|
||||
max_results=limit,
|
||||
order_by_recent=order_by_recent,
|
||||
include_profile=include_profile,
|
||||
before_limit=before_limit,
|
||||
after_limit=after_limit
|
||||
)
|
||||
|
||||
result_dict = {
|
||||
"results": []
|
||||
}
|
||||
|
||||
for score, column_id in result:
|
||||
event_dict = self.pan_store.load_event_by_columns(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
column_id, include_profile)
|
||||
|
||||
if not event_dict:
|
||||
continue
|
||||
|
||||
if include_state or before_limit or after_limit:
|
||||
await add_context(
|
||||
event_dict["result"]["room_id"],
|
||||
event_dict["result"]["event_id"],
|
||||
before_limit,
|
||||
after_limit,
|
||||
include_state
|
||||
)
|
||||
|
||||
if order_by_date:
|
||||
event_dict["rank"] = 1.0
|
||||
else:
|
||||
event_dict["rank"] = score
|
||||
|
||||
result_dict["results"].append(event_dict)
|
||||
|
||||
result_dict["count"] = len(result_dict["results"])
|
||||
result_dict["highlights"] = []
|
||||
# TODO add the state and start/end tokens
|
||||
# if event_context or include_state:
|
||||
# for event_dict in response_dict:
|
||||
# await add_context(
|
||||
# event_dict["result"]["room_id"],
|
||||
# event_dict["result"]["event_id"],
|
||||
# 0,
|
||||
# 0,
|
||||
# include_state
|
||||
# )
|
||||
|
||||
if include_state:
|
||||
result_dict["state"] = state_cache
|
||||
response_dict["state"] = state_cache
|
||||
|
||||
return {
|
||||
"search_categories": {
|
||||
"room_events": result_dict
|
||||
"room_events": response_dict
|
||||
}
|
||||
}
|
||||
|
@ -28,8 +28,8 @@ from multidict import CIMultiDict
|
||||
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
|
||||
SendRetryError)
|
||||
|
||||
from pantalaimon.client import (InvalidOrderByError, PanClient,
|
||||
UnknownRoomError, InvalidLimit)
|
||||
from pantalaimon.client import (InvalidLimit, InvalidOrderByError, PanClient,
|
||||
UnknownRoomError)
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
|
||||
|
@ -12,11 +12,239 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
import tantivy
|
||||
from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText,
|
||||
RoomNameEvent, RoomTopicEvent)
|
||||
from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
|
||||
from pantalaimon.store import use_database
|
||||
|
||||
|
||||
class DictField(TextField):
|
||||
def python_value(self, value): # pragma: no cover
|
||||
return json.loads(value)
|
||||
|
||||
def db_value(self, value): # pragma: no cover
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
class StoreUser(Model):
|
||||
user_id = TextField()
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id)")]
|
||||
|
||||
|
||||
class Profile(Model):
|
||||
user_id = TextField()
|
||||
avatar_url = TextField(null=True)
|
||||
display_name = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
|
||||
|
||||
|
||||
class Event(Model):
|
||||
event_id = TextField()
|
||||
sender = TextField()
|
||||
date = DateTimeField()
|
||||
room_id = TextField()
|
||||
|
||||
source = DictField()
|
||||
|
||||
profile = ForeignKeyField(
|
||||
model=Profile,
|
||||
column_name="profile_id",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
|
||||
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(
|
||||
model=StoreUser,
|
||||
column_name="user_id")
|
||||
event = ForeignKeyField(
|
||||
model=Event,
|
||||
column_name="event_id")
|
||||
|
||||
|
||||
@attr.s
|
||||
class MessageStore:
|
||||
user = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str)
|
||||
database_name = attr.ib(type=str)
|
||||
database = attr.ib(type=SqliteDatabase, init=False)
|
||||
database_path = attr.ib(type=str, init=False)
|
||||
|
||||
models = [
|
||||
StoreUser,
|
||||
Event,
|
||||
Profile,
|
||||
UserMessages
|
||||
]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
os.path.abspath(self.store_path),
|
||||
self.database_name
|
||||
)
|
||||
|
||||
self.database = self._create_database()
|
||||
self.database.connect()
|
||||
|
||||
with self.database.bind_ctx(self.models):
|
||||
self.database.create_tables(self.models)
|
||||
|
||||
def _create_database(self):
|
||||
return SqliteDatabase(
|
||||
self.database_path,
|
||||
pragmas={
|
||||
"foreign_keys": 1,
|
||||
"secure_delete": 1,
|
||||
}
|
||||
)
|
||||
|
||||
@use_database
|
||||
def event_in_store(self, event_id, room_id):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.room_id == room_id) &
|
||||
(Event.event_id == event_id) &
|
||||
(UserMessages.user == user)
|
||||
).execute()
|
||||
|
||||
for _ in query:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def save_event(self, event, room_id, display_name=None, avatar_url=None):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(
|
||||
event.server_timestamp / 1000
|
||||
),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id
|
||||
).on_conflict_ignore().execute()
|
||||
|
||||
if event_id <= 0:
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(
|
||||
user=user,
|
||||
event=event_id,
|
||||
)
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
def _load_context(self, user, event, before, after):
|
||||
context = {}
|
||||
|
||||
if before > 0:
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.date <= event.date) &
|
||||
(Event.room_id == event.room_id) &
|
||||
(Event.id != event.id) &
|
||||
(UserMessages.user == user)
|
||||
).order_by(Event.date.desc()).limit(before)
|
||||
|
||||
context["events_before"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_before"] = []
|
||||
|
||||
if after > 0:
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.date >= event.date) &
|
||||
(Event.room_id == event.room_id) &
|
||||
(Event.id != event.id) &
|
||||
(UserMessages.user == user)
|
||||
).order_by(Event.date).limit(after)
|
||||
|
||||
context["events_after"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_after"] = []
|
||||
|
||||
return context
|
||||
|
||||
@use_database
|
||||
def load_events(
|
||||
self,
|
||||
search_result, # type: List[Tuple[int, int]]
|
||||
include_profile=False, # type: bool
|
||||
order_by_recent=False, # type: bool
|
||||
before=0, # type: int
|
||||
after=0 # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
search_dict = {r[1]: r[0] for r in search_result}
|
||||
columns = list(search_dict.keys())
|
||||
|
||||
result_dict = {
|
||||
"results": []
|
||||
}
|
||||
|
||||
query = UserMessages.select().where(
|
||||
(UserMessages.user_id == user) & (UserMessages.event.in_(columns))
|
||||
).execute()
|
||||
|
||||
for message in query:
|
||||
|
||||
event = message.event
|
||||
|
||||
event_dict = {
|
||||
"rank": 1 if order_by_recent else search_dict[event.id],
|
||||
"result": event.source,
|
||||
"context": {}
|
||||
}
|
||||
|
||||
if include_profile:
|
||||
event_profile = event.profile
|
||||
|
||||
event_dict["context"]["profile_info"] = {
|
||||
event_profile.user_id: {
|
||||
"display_name": event_profile.display_name,
|
||||
"avatar_url": event_profile.avatar_url,
|
||||
}
|
||||
}
|
||||
|
||||
context = self._load_context(user, event, before, after)
|
||||
|
||||
event_dict["context"]["events_before"] = context["events_before"]
|
||||
event_dict["context"]["events_after"] = context["events_after"]
|
||||
|
||||
result_dict["results"].append(event_dict)
|
||||
|
||||
return result_dict
|
||||
|
||||
|
||||
def sanitize_room_id(room_id):
|
||||
@ -37,7 +265,7 @@ class Searcher:
|
||||
self.timestamp_field = timestamp_field
|
||||
|
||||
def search(self, search_term, room=None, max_results=10,
|
||||
order_by_date=False):
|
||||
order_by_recent=False):
|
||||
# type (str, str, int, bool) -> List[int, int]
|
||||
"""Search for events in the index.
|
||||
|
||||
@ -63,7 +291,7 @@ class Searcher:
|
||||
|
||||
query = queryparser.parse_query(search_term)
|
||||
|
||||
if order_by_date:
|
||||
if order_by_recent:
|
||||
collector = tantivy.TopDocs(max_results,
|
||||
order_by_field=self.timestamp_field)
|
||||
else:
|
||||
@ -82,7 +310,7 @@ class Searcher:
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, path=None):
|
||||
def __init__(self, path=None, num_searchers=None):
|
||||
schema_builder = tantivy.SchemaBuilder()
|
||||
|
||||
self.body_field = schema_builder.add_text_field("body")
|
||||
@ -108,7 +336,7 @@ class Index:
|
||||
|
||||
self.index = tantivy.Index(schema, path)
|
||||
|
||||
self.reader = self.index.reader()
|
||||
self.reader = self.index.reader(num_searchers=num_searchers)
|
||||
self.writer = self.index.writer()
|
||||
|
||||
def add_event(self, column_id, event, room_id):
|
||||
@ -154,3 +382,113 @@ class Index:
|
||||
self.timestamp_field,
|
||||
self.reader.searcher()
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class StoreItem:
|
||||
event = attr.ib()
|
||||
room_id = attr.ib()
|
||||
display_name = attr.ib(default=None)
|
||||
avatar_url = attr.ib(default=None)
|
||||
|
||||
|
||||
@attr.s
|
||||
class IndexStore:
|
||||
user = attr.ib(type=str)
|
||||
index_path = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str, default=None)
|
||||
store_name = attr.ib(default="events.db")
|
||||
|
||||
index = attr.ib(type=Index, init=False)
|
||||
store = attr.ib(type=MessageStore, init=False)
|
||||
event_queue = attr.ib(factory=list)
|
||||
write_lock = attr.ib(factory=asyncio.Lock)
|
||||
read_semaphore = attr.ib(type=asyncio.Semaphore, init=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.store_path = self.store_path or self.index_path
|
||||
num_searchers = os.cpu_count()
|
||||
self.index = Index(self.index_path, num_searchers)
|
||||
self.read_semaphore = asyncio.Semaphore(num_searchers or 1)
|
||||
self.store = MessageStore(self.user, self.store_path, self.store_name)
|
||||
|
||||
def add_event(self, event, room_id, display_name, avatar_url):
|
||||
item = StoreItem(event, room_id, display_name, avatar_url)
|
||||
self.event_queue.append(item)
|
||||
|
||||
@staticmethod
|
||||
def write_events(store, index, event_queue):
|
||||
with store.database.bind_ctx(store.models):
|
||||
with store.database.atomic():
|
||||
for item in event_queue:
|
||||
column_id = store.save_event(
|
||||
item.event,
|
||||
item.room_id,
|
||||
)
|
||||
|
||||
if column_id:
|
||||
index.add_event(column_id, item.event, item.room_id)
|
||||
index.commit()
|
||||
|
||||
async def commit_events(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
event_queue = self.event_queue
|
||||
|
||||
if not event_queue:
|
||||
return
|
||||
|
||||
self.event_queue = []
|
||||
|
||||
async with self.write_lock:
|
||||
write_func = partial(
|
||||
IndexStore.write_events,
|
||||
self.store,
|
||||
self.index,
|
||||
event_queue
|
||||
)
|
||||
await loop.run_in_executor(None, write_func)
|
||||
|
||||
def event_in_store(self, event_id, room_id):
|
||||
return self.store.event_in_store(event_id, room_id)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
search_term, # type: str
|
||||
room=None, # type: Optional[str]
|
||||
max_results=10, # type: int
|
||||
order_by_recent=False, # type: bool
|
||||
include_profile=False, # type: bool
|
||||
before_limit=0, # type: int
|
||||
after_limit=0 # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
"""Search the indexstore for an event."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Getting a searcher from tantivy may block if there is no searcher
|
||||
# available. To avoid blocking we set up the number of searchers to be
|
||||
# the number of CPUs and the semaphore has the same counter value.
|
||||
async with self.read_semaphore:
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(searcher.search, search_term, room=room,
|
||||
max_results=max_results,
|
||||
order_by_recent=order_by_recent)
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
|
||||
load_event_func = partial(
|
||||
self.store.load_events,
|
||||
result,
|
||||
include_profile,
|
||||
order_by_recent,
|
||||
before_limit,
|
||||
after_limit
|
||||
)
|
||||
|
||||
search_result = await loop.run_in_executor(None, load_event_func)
|
||||
|
||||
search_result["count"] = len(search_result["results"])
|
||||
search_result["highlights"] = []
|
||||
|
||||
return search_result
|
||||
|
@ -12,18 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from nio import RoomMessage
|
||||
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
|
||||
use_database)
|
||||
from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model,
|
||||
SqliteDatabase, TextField)
|
||||
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
|
||||
|
||||
@attr.s
|
||||
@ -70,41 +68,6 @@ class ServerUsers(Model):
|
||||
constraints = [SQL("UNIQUE(user_id,server_id)")]
|
||||
|
||||
|
||||
class Profile(Model):
|
||||
user_id = TextField()
|
||||
avatar_url = TextField(null=True)
|
||||
display_name = TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
|
||||
|
||||
|
||||
class Event(Model):
|
||||
event_id = TextField()
|
||||
sender = TextField()
|
||||
date = DateTimeField()
|
||||
room_id = TextField()
|
||||
|
||||
source = DictField()
|
||||
|
||||
profile = ForeignKeyField(
|
||||
model=Profile,
|
||||
column_name="profile_id",
|
||||
)
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
|
||||
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(
|
||||
model=ServerUsers,
|
||||
column_name="user_id")
|
||||
event = ForeignKeyField(
|
||||
model=Event,
|
||||
column_name="event_id")
|
||||
|
||||
|
||||
class PanSyncTokens(Model):
|
||||
token = TextField()
|
||||
user = ForeignKeyField(
|
||||
@ -146,9 +109,6 @@ class PanStore:
|
||||
ServerUsers,
|
||||
DeviceKeys,
|
||||
DeviceTrustState,
|
||||
Profile,
|
||||
Event,
|
||||
UserMessages,
|
||||
PanSyncTokens,
|
||||
PanFetcherTasks
|
||||
]
|
||||
@ -242,86 +202,6 @@ class PanStore:
|
||||
|
||||
return None
|
||||
|
||||
@use_database
|
||||
def save_event(self, server, pan_user, event, room_id, display_name,
|
||||
avatar_url):
|
||||
# type: (str, str, str, RoomMessage, str, str, str) -> Optional[int]
|
||||
"""Save an event to the store.
|
||||
|
||||
Returns the database id of the event.
|
||||
"""
|
||||
server = Servers.get(name=server)
|
||||
user = ServerUsers.get(server=server, user_id=pan_user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(
|
||||
event.server_timestamp / 1000
|
||||
),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id
|
||||
).on_conflict_ignore().execute()
|
||||
|
||||
if event_id <= 0:
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(
|
||||
user=user,
|
||||
event=event_id,
|
||||
)
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
@use_database
|
||||
def load_event_by_columns(
|
||||
self,
|
||||
server, # type: str
|
||||
pan_user, # type: str
|
||||
column, # type: List[int]
|
||||
include_profile=False # type: bool
|
||||
):
|
||||
# type: (...) -> Optional[Dict]
|
||||
server = Servers.get(name=server)
|
||||
user = ServerUsers.get(server=server, user_id=pan_user)
|
||||
|
||||
message = UserMessages.get_or_none(user=user, event=column)
|
||||
|
||||
if not message:
|
||||
return None
|
||||
|
||||
event = message.event
|
||||
|
||||
event_dict = {
|
||||
"result": event.source,
|
||||
"context": {}
|
||||
}
|
||||
|
||||
if include_profile:
|
||||
event_profile = event.profile
|
||||
|
||||
event_dict["context"]["profile_info"] = {
|
||||
event_profile.user_id: {
|
||||
"display_name": event_profile.display_name,
|
||||
"avatar_url": event_profile.avatar_url,
|
||||
}
|
||||
}
|
||||
|
||||
return event_dict
|
||||
|
||||
@use_database
|
||||
def save_server_user(self, server_name, user_id):
|
||||
# type: (str, str) -> None
|
||||
|
@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import pdb
|
||||
import pprint
|
||||
|
||||
from nio import RoomMessage
|
||||
|
||||
from conftest import faker
|
||||
from pantalaimon.index import Index
|
||||
from pantalaimon.index import Index, IndexStore
|
||||
from pantalaimon.store import FetchTask
|
||||
|
||||
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
|
||||
@ -32,8 +34,8 @@ class TestClass(object):
|
||||
return RoomMessage.parse_event(
|
||||
{
|
||||
"content": {"body": "Another message", "msgtype": "m.text"},
|
||||
"event_id": "$15163622445EBvZJ:localhost",
|
||||
"origin_server_ts": 1516362244026,
|
||||
"event_id": "$15163622445EBvZK:localhost",
|
||||
"origin_server_ts": 1516362244030,
|
||||
"room_id": "!SVkFJHzfwvuaIEawgC:localhost",
|
||||
"sender": "@example2:localhost",
|
||||
"type": "m.room.message",
|
||||
@ -58,66 +60,6 @@ class TestClass(object):
|
||||
token = panstore.load_access_token(user_id, device_id)
|
||||
access_token == token
|
||||
|
||||
def test_event_storing(self, panstore_with_users):
|
||||
panstore = panstore_with_users
|
||||
accounts = panstore.load_all_users()
|
||||
user, _ = accounts[0]
|
||||
|
||||
event = self.test_event
|
||||
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
|
||||
assert event_id == 1
|
||||
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
assert event_id is None
|
||||
|
||||
event_dict = panstore.load_event_by_columns("example", user, 1)
|
||||
assert event.source == event_dict["result"]
|
||||
|
||||
event_source = panstore.load_event_by_columns("example", user, 1, True)
|
||||
|
||||
assert event_source["context"]["profile_info"] == {
|
||||
"@example2:localhost": {
|
||||
"display_name": "Example2",
|
||||
"avatar_url": None
|
||||
}
|
||||
}
|
||||
|
||||
def test_index(self, panstore_with_users):
|
||||
panstore = panstore_with_users
|
||||
accounts = panstore.load_all_users()
|
||||
user, _ = accounts[0]
|
||||
|
||||
event = self.test_event
|
||||
another_event = self.another_event
|
||||
|
||||
index = Index(panstore.store_path)
|
||||
|
||||
event_id = panstore.save_event("example", user, event, TEST_ROOM,
|
||||
"Example2", None)
|
||||
assert event_id == 1
|
||||
index.add_event(event_id, event, TEST_ROOM)
|
||||
|
||||
event_id = panstore.save_event("example", user, another_event,
|
||||
TEST_ROOM2, "Example2", None)
|
||||
assert event_id == 2
|
||||
index.add_event(event_id, another_event, TEST_ROOM2)
|
||||
|
||||
index.commit()
|
||||
|
||||
searcher = index.searcher()
|
||||
|
||||
searched_events = searcher.search("message", TEST_ROOM)
|
||||
|
||||
_, found_id = searched_events[0]
|
||||
|
||||
event_dict = panstore.load_event_by_columns("example", user, found_id)
|
||||
|
||||
assert event_dict["result"] == event.source
|
||||
|
||||
def test_token_storing(self, panstore_with_users):
|
||||
panstore = panstore_with_users
|
||||
accounts = panstore.load_all_users()
|
||||
@ -151,3 +93,26 @@ class TestClass(object):
|
||||
|
||||
assert task not in tasks
|
||||
assert task2 in tasks
|
||||
|
||||
def test_new_indexstore(self, tempdir):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
store = IndexStore("example", tempdir)
|
||||
|
||||
store.add_event(self.test_event, TEST_ROOM, None, None)
|
||||
store.add_event(self.another_event, TEST_ROOM, None, None)
|
||||
loop.run_until_complete(store.commit_events())
|
||||
|
||||
assert store.event_in_store(self.test_event.event_id, TEST_ROOM)
|
||||
assert not store.event_in_store("FAKE", TEST_ROOM)
|
||||
|
||||
result = loop.run_until_complete(
|
||||
store.search("test", TEST_ROOM, after_limit=10, before_limit=10)
|
||||
)
|
||||
pprint.pprint(result)
|
||||
|
||||
assert len(result["results"]) == 1
|
||||
assert result["count"] == 1
|
||||
assert result["results"][0]["result"] == self.test_event.source
|
||||
assert (result["results"][0]["context"]["events_after"][0]
|
||||
== self.another_event.source)
|
||||
|
Loading…
Reference in New Issue
Block a user