mirror of https://github.com/matrix-org/pantalaimon.git
index: Rewrite the search logic.
This patch moves all of the indexing, event storing, and searching into a separate class. The index and the message store are now represented by a single class, and messages are indexed and stored atomically, which should minimize the chance of store/index inconsistencies. Messages are now loaded from the store with a single SQL query, and the context for the messages is loaded from the store as well instead of being fetched from the server. The room state and the start/end tokens for the context aren't currently loaded.
This commit is contained in:
parent a489031962
commit 3a1b001244
5 changed files with 423 additions and 270 deletions
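As a quick orientation, the sketch below exercises the new IndexStore API that this patch introduces. The class name, method names and argument order are taken from the diff that follows; the user id, the event payload and the temporary directory are made-up example values.

import asyncio
import tempfile

from nio import RoomMessage

from pantalaimon.index import IndexStore

TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"


async def main():
    # The tantivy index and the sqlite message store now live behind one class.
    store = IndexStore("@example:localhost", tempfile.mkdtemp())

    event = RoomMessage.parse_event({
        "content": {"body": "Test message", "msgtype": "m.text"},
        "event_id": "$15163622445EBvZJ:localhost",
        "origin_server_ts": 1516362244026,
        "room_id": TEST_ROOM,
        "sender": "@example2:localhost",
        "type": "m.room.message",
    })

    # Events are only queued in memory here...
    store.add_event(event, TEST_ROOM, "Example2", None)

    # ...and written out in one step: a single sqlite transaction plus a
    # tantivy commit, run in an executor thread.
    await store.commit_events()

    assert store.event_in_store(event.event_id, TEST_ROOM)

    # Searching hits the index first, then loads the matching events and
    # their context from the store with SQL queries instead of asking the
    # homeserver.
    result = await store.search("test", room=TEST_ROOM,
                                before_limit=5, after_limit=5)
    print(result["count"], result["highlights"], result["results"])


asyncio.get_event_loop().run_until_complete(main())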
@@ -15,7 +15,6 @@
 import asyncio
 import os
 from collections import defaultdict
-from functools import partial
 from pprint import pformat
 from typing import Any, Dict, Optional

@@ -30,7 +29,7 @@ from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
 from nio.crypto import Sas
 from nio.store import SqliteStore

-from pantalaimon.index import Index
+from pantalaimon.index import IndexStore
 from pantalaimon.log import logger
 from pantalaimon.store import FetchTask
 from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
@@ -137,7 +136,7 @@ class PanClient(AsyncClient):

         self.server_name = server_name
         self.pan_store = pan_store
-        self.index = Index(index_dir)
+        self.index = IndexStore(self.user_id, index_dir)
         self.task = None
         self.queue = queue

@@ -179,20 +178,7 @@ class PanClient(AsyncClient):
         display_name = room.user_name(event.sender)
         avatar_url = room.avatar_url(event.sender)

-        column_id = self.pan_store.save_event(
-            self.server_name,
-            self.user_id,
-            event,
-            room.room_id,
-            display_name,
-            avatar_url
-        )
-
-        if column_id:
-            self.index.add_event(column_id, event, room.room_id)
-            return True
-
-        return False
+        self.index.add_event(event, room.room_id, display_name, avatar_url)

     @property
     def unable_to_decrypt(self):
@@ -247,7 +233,8 @@ class PanClient(AsyncClient):
                             room.display_name
                         ))
                         response = await self.room_messages(fetch_task.room_id,
-                                                            fetch_task.token)
+                                                            fetch_task.token,
+                                                            limit=100)
                     except ClientConnectionError:
                         self.history_fetch_queue.put(fetch_task)

@@ -266,10 +253,17 @@ class PanClient(AsyncClient):
                         )):
                             continue

-                        if not self.store_message_cb(room, event):
-                            # The event was already in our store, we catched up.
-                            break
-                    else:
+                        display_name = room.user_name(event.sender)
+                        avatar_url = room.avatar_url(event.sender)
+                        self.index.add_event(event, room.room_id, display_name,
+                                             avatar_url)
+
+                    last_event = response.chunk[-1]
+
+                    if not self.index.event_in_store(
+                        last_event.event_id,
+                        room.room_id
+                    ):
                         # There may be even more events to fetch, add a new task to
                         # the queue.
                         task = FetchTask(room.room_id, response.end)
@@ -277,6 +271,7 @@ class PanClient(AsyncClient):
                                                          self.user_id, task)
                         await self.history_fetch_queue.put(task)

+                    await self.index.commit_events()
                     self.delete_fetcher_task(fetch_task)
             except (asyncio.CancelledError, KeyboardInterrupt):
                 return
@@ -292,8 +287,7 @@ class PanClient(AsyncClient):
         self.key_verificatins_tasks = []
         self.key_request_tasks = []

-        self.index.commit()
-
+        await self.index.commit_events()
         self.pan_store.save_token(
             self.server_name,
             self.user_id,
@@ -688,13 +682,11 @@ class PanClient(AsyncClient):

     async def search(self, search_terms):
         # type: (Dict[Any, Any]) -> Dict[Any, Any]
-        loop = asyncio.get_event_loop()
         state_cache = dict()

-        async def add_context(room_id, event_id, before, after, include_state):
+        async def add_context(event_dict, room_id, event_id, include_state):
             try:
-                context = await self.room_context(room_id, event_id,
-                                                  limit=before+after)
+                context = await self.room_context(room_id, event_id, limit=0)
             except ClientConnectionError:
                 return

@@ -704,16 +696,8 @@ class PanClient(AsyncClient):
             if include_state:
                 state_cache[room_id] = [e.source for e in context.state]

-            event_context = event_dict["context"]
-
-            event_context["events_before"] = [
-                e.source for e in context.events_before[:before]
-            ]
-            event_context["events_after"] = [
-                e.source for e in context.events_after[:after]
-            ]
-            event_context["start"] = context.start
-            event_context["end"] = context.end
+            event_dict["context"]["start"] = context.start
+            event_dict["context"]["end"] = context.end

         validate_json(search_terms, SEARCH_TERMS_SCHEMA)
         search_terms = search_terms["search_categories"]["room_events"]
@@ -723,7 +707,7 @@ class PanClient(AsyncClient):
         limit = search_filter.get("limit", 10)

         if limit <= 0:
-            raise InvalidLimit(f"The limit must be strictly greater than 0.")
+            raise InvalidLimit("The limit must be strictly greater than 0.")

         rooms = search_filter.get("rooms", [])

@@ -734,7 +718,7 @@ class PanClient(AsyncClient):
         if order_by not in ["rank", "recent"]:
             raise InvalidOrderByError(f"Invalid order by: {order_by}")

-        order_by_date = order_by == "recent"
+        order_by_recent = order_by == "recent"

         before_limit = 0
         after_limit = 0
@@ -746,51 +730,37 @@ class PanClient(AsyncClient):
         if event_context:
             before_limit = event_context.get("before_limit", 5)
             after_limit = event_context.get("before_limit", 5)
-            include_profile = event_context.get("include_profile", False)

-        searcher = self.index.searcher()
-        search_func = partial(searcher.search, term, room=room_id,
-                              max_results=limit, order_by_date=order_by_date)
-
-        result = await loop.run_in_executor(None, search_func)
-
-        result_dict = {
-            "results": []
-        }
-
-        for score, column_id in result:
-            event_dict = self.pan_store.load_event_by_columns(
-                self.server_name,
-                self.user_id,
-                column_id, include_profile)
-
-            if not event_dict:
-                continue
-
-            if include_state or before_limit or after_limit:
-                await add_context(
-                    event_dict["result"]["room_id"],
-                    event_dict["result"]["event_id"],
-                    before_limit,
-                    after_limit,
-                    include_state
-                )
-
-            if order_by_date:
-                event_dict["rank"] = 1.0
-            else:
-                event_dict["rank"] = score
-
-            result_dict["results"].append(event_dict)
-
-        result_dict["count"] = len(result_dict["results"])
-        result_dict["highlights"] = []
+        if before_limit < 0 or after_limit < 0:
+            raise InvalidLimit("Invalid context limit, the limit must be a "
+                               "positive number")
+
+        response_dict = await self.index.search(
+            term,
+            room=room_id,
+            max_results=limit,
+            order_by_recent=order_by_recent,
+            include_profile=include_profile,
+            before_limit=before_limit,
+            after_limit=after_limit
+        )
+
+        # TODO add the state and start/end tokens
+        # if event_context or include_state:
+        #     for event_dict in response_dict:
+        #         await add_context(
+        #             event_dict["result"]["room_id"],
+        #             event_dict["result"]["event_id"],
+        #             0,
+        #             0,
+        #             include_state
+        #         )

         if include_state:
-            result_dict["state"] = state_cache
+            response_dict["state"] = state_cache

         return {
             "search_categories": {
-                "room_events": result_dict
+                "room_events": response_dict
             }
         }

@@ -28,8 +28,8 @@ from multidict import CIMultiDict
 from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
                  SendRetryError)

-from pantalaimon.client import (InvalidOrderByError, PanClient,
-                                UnknownRoomError, InvalidLimit)
+from pantalaimon.client import (InvalidLimit, InvalidOrderByError, PanClient,
+                                UnknownRoomError)
 from pantalaimon.log import logger
 from pantalaimon.store import ClientInfo, PanStore
 from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,

@@ -12,11 +12,239 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import asyncio
 import datetime
+import json
+import os
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple

+import attr
 import tantivy
 from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText,
                  RoomNameEvent, RoomTopicEvent)
+from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase,
+                    TextField)
+
+from pantalaimon.store import use_database
+
+
+class DictField(TextField):
+    def python_value(self, value):  # pragma: no cover
+        return json.loads(value)
+
+    def db_value(self, value):  # pragma: no cover
+        return json.dumps(value)
+
+
+class StoreUser(Model):
+    user_id = TextField()
+
+    class Meta:
+        constraints = [SQL("UNIQUE(user_id)")]
+
+
+class Profile(Model):
+    user_id = TextField()
+    avatar_url = TextField(null=True)
+    display_name = TextField(null=True)
+
+    class Meta:
+        constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
+
+
+class Event(Model):
+    event_id = TextField()
+    sender = TextField()
+    date = DateTimeField()
+    room_id = TextField()
+
+    source = DictField()
+
+    profile = ForeignKeyField(
+        model=Profile,
+        column_name="profile_id",
+    )
+
+    class Meta:
+        constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
+
+
+class UserMessages(Model):
+    user = ForeignKeyField(
+        model=StoreUser,
+        column_name="user_id")
+    event = ForeignKeyField(
+        model=Event,
+        column_name="event_id")
+
+
+@attr.s
+class MessageStore:
+    user = attr.ib(type=str)
+    store_path = attr.ib(type=str)
+    database_name = attr.ib(type=str)
+    database = attr.ib(type=SqliteDatabase, init=False)
+    database_path = attr.ib(type=str, init=False)
+
+    models = [
+        StoreUser,
+        Event,
+        Profile,
+        UserMessages
+    ]
+
+    def __attrs_post_init__(self):
+        self.database_path = os.path.join(
+            os.path.abspath(self.store_path),
+            self.database_name
+        )
+
+        self.database = self._create_database()
+        self.database.connect()
+
+        with self.database.bind_ctx(self.models):
+            self.database.create_tables(self.models)
+
+    def _create_database(self):
+        return SqliteDatabase(
+            self.database_path,
+            pragmas={
+                "foreign_keys": 1,
+                "secure_delete": 1,
+            }
+        )
+
+    @use_database
+    def event_in_store(self, event_id, room_id):
+        user, _ = StoreUser.get_or_create(user_id=self.user)
+        query = Event.select().join(UserMessages).where(
+            (Event.room_id == room_id) &
+            (Event.event_id == event_id) &
+            (UserMessages.user == user)
+        ).execute()
+
+        for _ in query:
+            return True
+
+        return False
+
+    def save_event(self, event, room_id, display_name=None, avatar_url=None):
+        user, _ = StoreUser.get_or_create(user_id=self.user)
+
+        profile_id, _ = Profile.get_or_create(
+            user_id=event.sender,
+            display_name=display_name,
+            avatar_url=avatar_url
+        )
+
+        event_source = event.source
+        event_source["room_id"] = room_id
+
+        event_id = Event.insert(
+            event_id=event.event_id,
+            sender=event.sender,
+            date=datetime.datetime.fromtimestamp(
+                event.server_timestamp / 1000
+            ),
+            room_id=room_id,
+            source=event_source,
+            profile=profile_id
+        ).on_conflict_ignore().execute()
+
+        if event_id <= 0:
+            return None
+
+        _, created = UserMessages.get_or_create(
+            user=user,
+            event=event_id,
+        )
+
+        if created:
+            return event_id
+
+        return None
+
+    def _load_context(self, user, event, before, after):
+        context = {}
+
+        if before > 0:
+            query = Event.select().join(UserMessages).where(
+                (Event.date <= event.date) &
+                (Event.room_id == event.room_id) &
+                (Event.id != event.id) &
+                (UserMessages.user == user)
+            ).order_by(Event.date.desc()).limit(before)
+
+            context["events_before"] = [e.source for e in query]
+        else:
+            context["events_before"] = []
+
+        if after > 0:
+            query = Event.select().join(UserMessages).where(
+                (Event.date >= event.date) &
+                (Event.room_id == event.room_id) &
+                (Event.id != event.id) &
+                (UserMessages.user == user)
+            ).order_by(Event.date).limit(after)
+
+            context["events_after"] = [e.source for e in query]
+        else:
+            context["events_after"] = []
+
+        return context
+
+    @use_database
+    def load_events(
+        self,
+        search_result,  # type: List[Tuple[int, int]]
+        include_profile=False,  # type: bool
+        order_by_recent=False,  # type: bool
+        before=0,  # type: int
+        after=0  # type: int
+    ):
+        # type: (...) -> Dict[Any, Any]
+        user, _ = StoreUser.get_or_create(user_id=self.user)

+        search_dict = {r[1]: r[0] for r in search_result}
+        columns = list(search_dict.keys())
+
+        result_dict = {
+            "results": []
+        }
+
+        query = UserMessages.select().where(
+            (UserMessages.user_id == user) & (UserMessages.event.in_(columns))
+        ).execute()
+
+        for message in query:
+
+            event = message.event
+
+            event_dict = {
+                "rank": 1 if order_by_recent else search_dict[event.id],
+                "result": event.source,
+                "context": {}
+            }
+
+            if include_profile:
+                event_profile = event.profile
+
+                event_dict["context"]["profile_info"] = {
+                    event_profile.user_id: {
+                        "display_name": event_profile.display_name,
+                        "avatar_url": event_profile.avatar_url,
+                    }
+                }
+
+            context = self._load_context(user, event, before, after)
+
+            event_dict["context"]["events_before"] = context["events_before"]
+            event_dict["context"]["events_after"] = context["events_after"]
+
+            result_dict["results"].append(event_dict)
+
+        return result_dict
+
+
 def sanitize_room_id(room_id):
@@ -37,7 +265,7 @@ class Searcher:
         self.timestamp_field = timestamp_field

     def search(self, search_term, room=None, max_results=10,
-               order_by_date=False):
+               order_by_recent=False):
         # type (str, str, int, bool) -> List[int, int]
         """Search for events in the index.

@@ -63,7 +291,7 @@ class Searcher:

         query = queryparser.parse_query(search_term)

-        if order_by_date:
+        if order_by_recent:
             collector = tantivy.TopDocs(max_results,
                                         order_by_field=self.timestamp_field)
         else:
@@ -82,7 +310,7 @@ class Searcher:


 class Index:
-    def __init__(self, path=None):
+    def __init__(self, path=None, num_searchers=None):
         schema_builder = tantivy.SchemaBuilder()

         self.body_field = schema_builder.add_text_field("body")
@@ -108,7 +336,7 @@ class Index:

         self.index = tantivy.Index(schema, path)

-        self.reader = self.index.reader()
+        self.reader = self.index.reader(num_searchers=num_searchers)
         self.writer = self.index.writer()

     def add_event(self, column_id, event, room_id):
@@ -154,3 +382,113 @@ class Index:
             self.timestamp_field,
             self.reader.searcher()
         )
+
+
+@attr.s
+class StoreItem:
+    event = attr.ib()
+    room_id = attr.ib()
+    display_name = attr.ib(default=None)
+    avatar_url = attr.ib(default=None)
+
+
+@attr.s
+class IndexStore:
+    user = attr.ib(type=str)
+    index_path = attr.ib(type=str)
+    store_path = attr.ib(type=str, default=None)
+    store_name = attr.ib(default="events.db")
+
+    index = attr.ib(type=Index, init=False)
+    store = attr.ib(type=MessageStore, init=False)
+    event_queue = attr.ib(factory=list)
+    write_lock = attr.ib(factory=asyncio.Lock)
+    read_semaphore = attr.ib(type=asyncio.Semaphore, init=False)
+
+    def __attrs_post_init__(self):
+        self.store_path = self.store_path or self.index_path
+        num_searchers = os.cpu_count()
+        self.index = Index(self.index_path, num_searchers)
+        self.read_semaphore = asyncio.Semaphore(num_searchers or 1)
+        self.store = MessageStore(self.user, self.store_path, self.store_name)
+
+    def add_event(self, event, room_id, display_name, avatar_url):
+        item = StoreItem(event, room_id, display_name, avatar_url)
+        self.event_queue.append(item)
+
+    @staticmethod
+    def write_events(store, index, event_queue):
+        with store.database.bind_ctx(store.models):
+            with store.database.atomic():
+                for item in event_queue:
+                    column_id = store.save_event(
+                        item.event,
+                        item.room_id,
+                    )
+
+                    if column_id:
+                        index.add_event(column_id, item.event, item.room_id)
+                index.commit()
+
+    async def commit_events(self):
+        loop = asyncio.get_event_loop()
+
+        event_queue = self.event_queue
+
+        if not event_queue:
+            return
+
+        self.event_queue = []
+
+        async with self.write_lock:
+            write_func = partial(
+                IndexStore.write_events,
+                self.store,
+                self.index,
+                event_queue
+            )
+            await loop.run_in_executor(None, write_func)
+
+    def event_in_store(self, event_id, room_id):
+        return self.store.event_in_store(event_id, room_id)
+
+    async def search(
+        self,
+        search_term,  # type: str
+        room=None,  # type: Optional[str]
+        max_results=10,  # type: int
+        order_by_recent=False,  # type: bool
+        include_profile=False,  # type: bool
+        before_limit=0,  # type: int
+        after_limit=0  # type: int
+    ):
+        # type: (...) -> Dict[Any, Any]
+        """Search the indexstore for an event."""
+        loop = asyncio.get_event_loop()
+
+        # Getting a searcher from tantivy may block if there is no searcher
+        # available. To avoid blocking we set up the number of searchers to be
+        # the number of CPUs and the semaphore has the same counter value.
+        async with self.read_semaphore:
+            searcher = self.index.searcher()
+            search_func = partial(searcher.search, search_term, room=room,
+                                  max_results=max_results,
+                                  order_by_recent=order_by_recent)
+
+            result = await loop.run_in_executor(None, search_func)
+
+            load_event_func = partial(
+                self.store.load_events,
+                result,
+                include_profile,
+                order_by_recent,
+                before_limit,
+                after_limit
+            )
+
+            search_result = await loop.run_in_executor(None, load_event_func)
+
+            search_result["count"] = len(search_result["results"])
+            search_result["highlights"] = []
+
+            return search_result

@@ -12,18 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import datetime
 import json
 import os
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import attr
-from nio import RoomMessage
 from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
                        use_database)
-from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model,
-                    SqliteDatabase, TextField)
+from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
+                    TextField)


 @attr.s
@@ -70,41 +68,6 @@ class ServerUsers(Model):
         constraints = [SQL("UNIQUE(user_id,server_id)")]


-class Profile(Model):
-    user_id = TextField()
-    avatar_url = TextField(null=True)
-    display_name = TextField(null=True)
-
-    class Meta:
-        constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
-
-
-class Event(Model):
-    event_id = TextField()
-    sender = TextField()
-    date = DateTimeField()
-    room_id = TextField()
-
-    source = DictField()
-
-    profile = ForeignKeyField(
-        model=Profile,
-        column_name="profile_id",
-    )
-
-    class Meta:
-        constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
-
-
-class UserMessages(Model):
-    user = ForeignKeyField(
-        model=ServerUsers,
-        column_name="user_id")
-    event = ForeignKeyField(
-        model=Event,
-        column_name="event_id")
-
-
 class PanSyncTokens(Model):
     token = TextField()
     user = ForeignKeyField(
@@ -146,9 +109,6 @@ class PanStore:
         ServerUsers,
         DeviceKeys,
         DeviceTrustState,
-        Profile,
-        Event,
-        UserMessages,
         PanSyncTokens,
         PanFetcherTasks
     ]
@@ -242,86 +202,6 @@ class PanStore:

         return None

-    @use_database
-    def save_event(self, server, pan_user, event, room_id, display_name,
-                   avatar_url):
-        # type: (str, str, str, RoomMessage, str, str, str) -> Optional[int]
-        """Save an event to the store.
-
-        Returns the database id of the event.
-        """
-        server = Servers.get(name=server)
-        user = ServerUsers.get(server=server, user_id=pan_user)
-
-        profile_id, _ = Profile.get_or_create(
-            user_id=event.sender,
-            display_name=display_name,
-            avatar_url=avatar_url
-        )
-
-        event_source = event.source
-        event_source["room_id"] = room_id
-
-        event_id = Event.insert(
-            event_id=event.event_id,
-            sender=event.sender,
-            date=datetime.datetime.fromtimestamp(
-                event.server_timestamp / 1000
-            ),
-            room_id=room_id,
-            source=event_source,
-            profile=profile_id
-        ).on_conflict_ignore().execute()
-
-        if event_id <= 0:
-            return None
-
-        _, created = UserMessages.get_or_create(
-            user=user,
-            event=event_id,
-        )
-
-        if created:
-            return event_id
-
-        return None
-
-    @use_database
-    def load_event_by_columns(
-        self,
-        server,  # type: str
-        pan_user,  # type: str
-        column,  # type: List[int]
-        include_profile=False  # type: bool
-    ):
-        # type: (...) -> Optional[Dict]
-        server = Servers.get(name=server)
-        user = ServerUsers.get(server=server, user_id=pan_user)
-
-        message = UserMessages.get_or_none(user=user, event=column)
-
-        if not message:
-            return None
-
-        event = message.event
-
-        event_dict = {
-            "result": event.source,
-            "context": {}
-        }
-
-        if include_profile:
-            event_profile = event.profile
-
-            event_dict["context"]["profile_info"] = {
-                event_profile.user_id: {
-                    "display_name": event_profile.display_name,
-                    "avatar_url": event_profile.avatar_url,
-                }
-            }
-
-        return event_dict
-
     @use_database
     def save_server_user(self, server_name, user_id):
         # type: (str, str) -> None

@@ -1,9 +1,11 @@
+import asyncio
 import pdb
+import pprint

 from nio import RoomMessage

 from conftest import faker
-from pantalaimon.index import Index
+from pantalaimon.index import Index, IndexStore
 from pantalaimon.store import FetchTask

 TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
@@ -32,8 +34,8 @@ class TestClass(object):
         return RoomMessage.parse_event(
             {
                 "content": {"body": "Another message", "msgtype": "m.text"},
-                "event_id": "$15163622445EBvZJ:localhost",
-                "origin_server_ts": 1516362244026,
+                "event_id": "$15163622445EBvZK:localhost",
+                "origin_server_ts": 1516362244030,
                 "room_id": "!SVkFJHzfwvuaIEawgC:localhost",
                 "sender": "@example2:localhost",
                 "type": "m.room.message",
@@ -58,66 +60,6 @@ class TestClass(object):
         token = panstore.load_access_token(user_id, device_id)
         access_token == token

-    def test_event_storing(self, panstore_with_users):
-        panstore = panstore_with_users
-        accounts = panstore.load_all_users()
-        user, _ = accounts[0]
-
-        event = self.test_event
-
-        event_id = panstore.save_event("example", user, event, TEST_ROOM,
-                                       "Example2", None)
-
-        assert event_id == 1
-
-        event_id = panstore.save_event("example", user, event, TEST_ROOM,
-                                       "Example2", None)
-        assert event_id is None
-
-        event_dict = panstore.load_event_by_columns("example", user, 1)
-        assert event.source == event_dict["result"]
-
-        event_source = panstore.load_event_by_columns("example", user, 1, True)
-
-        assert event_source["context"]["profile_info"] == {
-            "@example2:localhost": {
-                "display_name": "Example2",
-                "avatar_url": None
-            }
-        }
-
-    def test_index(self, panstore_with_users):
-        panstore = panstore_with_users
-        accounts = panstore.load_all_users()
-        user, _ = accounts[0]
-
-        event = self.test_event
-        another_event = self.another_event
-
-        index = Index(panstore.store_path)
-
-        event_id = panstore.save_event("example", user, event, TEST_ROOM,
-                                       "Example2", None)
-        assert event_id == 1
-        index.add_event(event_id, event, TEST_ROOM)
-
-        event_id = panstore.save_event("example", user, another_event,
-                                       TEST_ROOM2, "Example2", None)
-        assert event_id == 2
-        index.add_event(event_id, another_event, TEST_ROOM2)
-
-        index.commit()
-
-        searcher = index.searcher()
-
-        searched_events = searcher.search("message", TEST_ROOM)
-
-        _, found_id = searched_events[0]
-
-        event_dict = panstore.load_event_by_columns("example", user, found_id)
-
-        assert event_dict["result"] == event.source
-
     def test_token_storing(self, panstore_with_users):
         panstore = panstore_with_users
         accounts = panstore.load_all_users()
@@ -151,3 +93,26 @@ class TestClass(object):

         assert task not in tasks
         assert task2 in tasks
+
+    def test_new_indexstore(self, tempdir):
+        loop = asyncio.get_event_loop()
+
+        store = IndexStore("example", tempdir)
+
+        store.add_event(self.test_event, TEST_ROOM, None, None)
+        store.add_event(self.another_event, TEST_ROOM, None, None)
+        loop.run_until_complete(store.commit_events())
+
+        assert store.event_in_store(self.test_event.event_id, TEST_ROOM)
+        assert not store.event_in_store("FAKE", TEST_ROOM)
+
+        result = loop.run_until_complete(
+            store.search("test", TEST_ROOM, after_limit=10, before_limit=10)
+        )
+        pprint.pprint(result)
+
+        assert len(result["results"]) == 1
+        assert result["count"] == 1
+        assert result["results"][0]["result"] == self.test_event.source
+        assert (result["results"][0]["context"]["events_after"][0]
+                == self.another_event.source)
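For reference on the daemon side, PanClient.search validates the incoming body against SEARCH_TERMS_SCHEMA and then reads the search_categories.room_events fields shown above (filter.limit, filter.rooms, order_by, event_context.before_limit/after_limit/include_profile). A request body shaped like the Matrix client-server search API drives the new code path end to end; the sketch below uses purely illustrative values, and the pan_client variable name is hypothetical.

search_terms = {
    "search_categories": {
        "room_events": {
            "search_term": "message",
            "order_by": "recent",
            "filter": {
                "limit": 10,
                "rooms": ["!SVkFJHzfwvuaIEawgC:localhost"],
            },
            "event_context": {
                "before_limit": 5,
                "after_limit": 5,
                "include_profile": True,
            },
        }
    }
}

# Handed to the pan client roughly as:
#     response = await pan_client.search(search_terms)
# and answered with a {"search_categories": {"room_events": {...}}} dict
# built from the index and the sqlite store, as in the diff above.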