index: Rewrite the search logic.

This patch moves all the indexing, event storing and searching into a
separate class.

The index and the message store are now represented by a single class, and
messages are now indexed and stored atomically, which should minimize the
chance of store/index inconsistencies.

Messages are now loaded from the store with a single SQL query, and the
context for the messages is now also loaded from the store instead of being
fetched from the server.

The room state and start/end tokens for the context aren't currently
loaded.
This commit is contained in:
Damir Jelić 2019-06-14 14:53:25 +02:00
parent a489031962
commit 3a1b001244
5 changed files with 423 additions and 270 deletions

View File

@ -15,7 +15,6 @@
import asyncio
import os
from collections import defaultdict
from functools import partial
from pprint import pformat
from typing import Any, Dict, Optional
@ -30,7 +29,7 @@ from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
from nio.crypto import Sas
from nio.store import SqliteStore
from pantalaimon.index import Index
from pantalaimon.index import IndexStore
from pantalaimon.log import logger
from pantalaimon.store import FetchTask
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
@ -137,7 +136,7 @@ class PanClient(AsyncClient):
self.server_name = server_name
self.pan_store = pan_store
self.index = Index(index_dir)
self.index = IndexStore(self.user_id, index_dir)
self.task = None
self.queue = queue
@ -179,20 +178,7 @@ class PanClient(AsyncClient):
display_name = room.user_name(event.sender)
avatar_url = room.avatar_url(event.sender)
column_id = self.pan_store.save_event(
self.server_name,
self.user_id,
event,
room.room_id,
display_name,
avatar_url
)
if column_id:
self.index.add_event(column_id, event, room.room_id)
return True
return False
self.index.add_event(event, room.room_id, display_name, avatar_url)
@property
def unable_to_decrypt(self):
@ -247,7 +233,8 @@ class PanClient(AsyncClient):
room.display_name
))
response = await self.room_messages(fetch_task.room_id,
fetch_task.token)
fetch_task.token,
limit=100)
except ClientConnectionError:
self.history_fetch_queue.put(fetch_task)
@ -266,10 +253,17 @@ class PanClient(AsyncClient):
)):
continue
if not self.store_message_cb(room, event):
# The event was already in our store, we catched up.
break
else:
display_name = room.user_name(event.sender)
avatar_url = room.avatar_url(event.sender)
self.index.add_event(event, room.room_id, display_name,
avatar_url)
last_event = response.chunk[-1]
if not self.index.event_in_store(
last_event.event_id,
room.room_id
):
# There may be even more events to fetch, add a new task to
# the queue.
task = FetchTask(room.room_id, response.end)
@ -277,6 +271,7 @@ class PanClient(AsyncClient):
self.user_id, task)
await self.history_fetch_queue.put(task)
await self.index.commit_events()
self.delete_fetcher_task(fetch_task)
except (asyncio.CancelledError, KeyboardInterrupt):
return
@ -292,8 +287,7 @@ class PanClient(AsyncClient):
self.key_verificatins_tasks = []
self.key_request_tasks = []
self.index.commit()
await self.index.commit_events()
self.pan_store.save_token(
self.server_name,
self.user_id,
@ -688,13 +682,11 @@ class PanClient(AsyncClient):
async def search(self, search_terms):
# type: (Dict[Any, Any]) -> Dict[Any, Any]
loop = asyncio.get_event_loop()
state_cache = dict()
async def add_context(room_id, event_id, before, after, include_state):
async def add_context(event_dict, room_id, event_id, include_state):
try:
context = await self.room_context(room_id, event_id,
limit=before+after)
context = await self.room_context(room_id, event_id, limit=0)
except ClientConnectionError:
return
@ -704,16 +696,8 @@ class PanClient(AsyncClient):
if include_state:
state_cache[room_id] = [e.source for e in context.state]
event_context = event_dict["context"]
event_context["events_before"] = [
e.source for e in context.events_before[:before]
]
event_context["events_after"] = [
e.source for e in context.events_after[:after]
]
event_context["start"] = context.start
event_context["end"] = context.end
event_dict["context"]["start"] = context.start
event_dict["context"]["end"] = context.end
validate_json(search_terms, SEARCH_TERMS_SCHEMA)
search_terms = search_terms["search_categories"]["room_events"]
@ -723,7 +707,7 @@ class PanClient(AsyncClient):
limit = search_filter.get("limit", 10)
if limit <= 0:
raise InvalidLimit(f"The limit must be strictly greater than 0.")
raise InvalidLimit("The limit must be strictly greater than 0.")
rooms = search_filter.get("rooms", [])
@ -734,7 +718,7 @@ class PanClient(AsyncClient):
if order_by not in ["rank", "recent"]:
raise InvalidOrderByError(f"Invalid order by: {order_by}")
order_by_date = order_by == "recent"
order_by_recent = order_by == "recent"
before_limit = 0
after_limit = 0
@ -746,51 +730,37 @@ class PanClient(AsyncClient):
if event_context:
before_limit = event_context.get("before_limit", 5)
after_limit = event_context.get("before_limit", 5)
include_profile = event_context.get("include_profile", False)
searcher = self.index.searcher()
search_func = partial(searcher.search, term, room=room_id,
max_results=limit, order_by_date=order_by_date)
if before_limit < 0 or after_limit < 0:
raise InvalidLimit("Invalid context limit, the limit must be a "
"positive number")
result = await loop.run_in_executor(None, search_func)
response_dict = await self.index.search(
term,
room=room_id,
max_results=limit,
order_by_recent=order_by_recent,
include_profile=include_profile,
before_limit=before_limit,
after_limit=after_limit
)
result_dict = {
"results": []
}
for score, column_id in result:
event_dict = self.pan_store.load_event_by_columns(
self.server_name,
self.user_id,
column_id, include_profile)
if not event_dict:
continue
if include_state or before_limit or after_limit:
await add_context(
event_dict["result"]["room_id"],
event_dict["result"]["event_id"],
before_limit,
after_limit,
include_state
)
if order_by_date:
event_dict["rank"] = 1.0
else:
event_dict["rank"] = score
result_dict["results"].append(event_dict)
result_dict["count"] = len(result_dict["results"])
result_dict["highlights"] = []
# TODO add the state and start/end tokens
# if event_context or include_state:
# for event_dict in response_dict:
# await add_context(
# event_dict["result"]["room_id"],
# event_dict["result"]["event_id"],
# 0,
# 0,
# include_state
# )
if include_state:
result_dict["state"] = state_cache
response_dict["state"] = state_cache
return {
"search_categories": {
"room_events": result_dict
"room_events": response_dict
}
}

View File

@ -28,8 +28,8 @@ from multidict import CIMultiDict
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
SendRetryError)
from pantalaimon.client import (InvalidOrderByError, PanClient,
UnknownRoomError, InvalidLimit)
from pantalaimon.client import (InvalidLimit, InvalidOrderByError, PanClient,
UnknownRoomError)
from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,

View File

@ -12,11 +12,239 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import datetime
import json
import os
from functools import partial
from typing import Any, Dict, List, Optional, Tuple
import attr
import tantivy
from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText,
RoomNameEvent, RoomTopicEvent)
from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase,
TextField)
from pantalaimon.store import use_database
class DictField(TextField):
    """Text column that holds a JSON-serializable value.

    The value is encoded with ``json.dumps`` when written to the database
    and decoded with ``json.loads`` when read back.
    """

    def db_value(self, value):  # pragma: no cover
        # Turn the Python object into its JSON text form for storage.
        encoded = json.dumps(value)
        return encoded

    def python_value(self, value):  # pragma: no cover
        # Parse the stored JSON text back into a Python object.
        decoded = json.loads(value)
        return decoded
class StoreUser(Model):
    """Table holding the ids of the users that own stored messages."""

    user_id = TextField()

    class Meta:
        # Each user id may appear only once.
        constraints = [SQL("UNIQUE(user_id)")]
class Profile(Model):
    """Display name and avatar URL recorded for a user id.

    The uniqueness constraint spans all three columns, so a user id can have
    one row per distinct display name/avatar combination.
    """

    user_id = TextField()
    # Both profile columns are nullable.
    avatar_url = TextField(null=True)
    display_name = TextField(null=True)

    class Meta:
        constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
class Event(Model):
    """A stored room event together with a reference to a sender profile."""

    event_id = TextField()
    sender = TextField()
    date = DateTimeField()
    room_id = TextField()
    # The event source, stored through DictField.
    source = DictField()
    profile = ForeignKeyField(
        model=Profile,
        column_name="profile_id",
    )

    class Meta:
        # The same event id may be stored again if the room, sender or
        # profile differs.
        constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
class UserMessages(Model):
    """Association table linking a store user to an event row."""

    user = ForeignKeyField(
        model=StoreUser,
        column_name="user_id")
    event = ForeignKeyField(
        model=Event,
        column_name="event_id")
@attr.s
class MessageStore:
    """SQLite-backed store for room events and sender profiles.

    The database file lives at ``store_path/database_name``. It is opened,
    and the tables for ``models`` are created, as soon as the object is
    constructed. All rows are scoped to ``user`` through ``UserMessages``.
    """

    # User id that owns the messages handled by this store instance.
    user = attr.ib(type=str)
    # Directory that contains the database file.
    store_path = attr.ib(type=str)
    # File name of the database inside ``store_path``.
    database_name = attr.ib(type=str)
    # Set in __attrs_post_init__.
    database = attr.ib(type=SqliteDatabase, init=False)
    database_path = attr.ib(type=str, init=False)

    # Models that are bound to the database and have their tables created.
    models = [
        StoreUser,
        Event,
        Profile,
        UserMessages
    ]

    def __attrs_post_init__(self):
        """Resolve the database path, connect and create the tables."""
        self.database_path = os.path.join(
            os.path.abspath(self.store_path),
            self.database_name
        )

        self.database = self._create_database()
        self.database.connect()

        with self.database.bind_ctx(self.models):
            self.database.create_tables(self.models)

    def _create_database(self):
        """Build the SqliteDatabase object for ``database_path``."""
        return SqliteDatabase(
            self.database_path,
            pragmas={
                "foreign_keys": 1,
                "secure_delete": 1,
            }
        )

    @use_database
    def event_in_store(self, event_id, room_id):
        """Return True if the event is stored for this store's user.

        The StoreUser row for the user is created if it does not exist yet.
        """
        user, _ = StoreUser.get_or_create(user_id=self.user)
        query = Event.select().join(UserMessages).where(
            (Event.room_id == room_id) &
            (Event.event_id == event_id) &
            (UserMessages.user == user)
        ).execute()

        # A single matching row is enough.
        for _ in query:
            return True

        return False

    def save_event(self, event, room_id, display_name=None, avatar_url=None):
        """Save an event and the profile of its sender.

        This method is not wrapped in ``use_database``; the caller has to
        bind the models to the database before calling it.

        Note that ``event.source`` is modified in place: the room id is
        added under the "room_id" key before the source is stored.

        Returns the database id of the event, or None if the insert was
        ignored because of a conflict or if the event was already linked to
        the user.
        """
        user, _ = StoreUser.get_or_create(user_id=self.user)

        profile_id, _ = Profile.get_or_create(
            user_id=event.sender,
            display_name=display_name,
            avatar_url=avatar_url
        )

        event_source = event.source
        event_source["room_id"] = room_id

        event_id = Event.insert(
            event_id=event.event_id,
            sender=event.sender,
            date=datetime.datetime.fromtimestamp(
                event.server_timestamp / 1000
            ),
            room_id=room_id,
            source=event_source,
            profile=profile_id
        ).on_conflict_ignore().execute()

        # A non-positive id is treated as "nothing was inserted".
        if event_id <= 0:
            return None

        _, created = UserMessages.get_or_create(
            user=user,
            event=event_id,
        )

        if created:
            return event_id

        return None

    def _load_context(self, user, event, before, after):
        """Load the events surrounding ``event`` in its room.

        Returns a dict with the keys "events_before" and "events_after",
        each a list of event sources. At most ``before``/``after`` events
        are loaded; a limit of zero or less yields an empty list.
        "events_before" is ordered newest first, "events_after" oldest
        first.
        """
        context = {}

        if before > 0:
            query = Event.select().join(UserMessages).where(
                (Event.date <= event.date) &
                (Event.room_id == event.room_id) &
                (Event.id != event.id) &
                (UserMessages.user == user)
            ).order_by(Event.date.desc()).limit(before)

            context["events_before"] = [e.source for e in query]
        else:
            context["events_before"] = []

        if after > 0:
            query = Event.select().join(UserMessages).where(
                (Event.date >= event.date) &
                (Event.room_id == event.room_id) &
                (Event.id != event.id) &
                (UserMessages.user == user)
            ).order_by(Event.date).limit(after)

            context["events_after"] = [e.source for e in query]
        else:
            context["events_after"] = []

        return context

    @use_database
    def load_events(
        self,
        search_result,  # type: List[Tuple[int, int]]
        include_profile=False,  # type: bool
        order_by_recent=False,  # type: bool
        before=0,  # type: int
        after=0  # type: int
    ):
        # type: (...) -> Dict[Any, Any]
        """Load the events referenced by a search result.

        ``search_result`` is a list of ``(rank, event row id)`` pairs. The
        returned dict has a single "results" key holding one dict per event
        that is stored for this user, with the keys "rank", "result" and
        "context". The results are returned in the order of
        ``search_result`` so the ordering chosen by the searcher is kept.

        When ``order_by_recent`` is set every rank is reported as 1.
        """
        user, _ = StoreUser.get_or_create(user_id=self.user)

        # Map row id -> rank and remember the order in which the row ids
        # first appear. A repeated row id keeps its last rank.
        search_dict = {}
        columns = []

        for r in search_result:
            if r[1] not in search_dict:
                columns.append(r[1])
            search_dict[r[1]] = r[0]

        result_dict = {
            "results": []
        }

        # Nothing to look up, skip the query.
        if not columns:
            return result_dict

        query = UserMessages.select().where(
            (UserMessages.user_id == user) & (UserMessages.event.in_(columns))
        ).execute()

        loaded = {}

        for message in query:
            event = message.event

            event_dict = {
                "rank": 1 if order_by_recent else search_dict[event.id],
                "result": event.source,
                "context": {}
            }

            if include_profile:
                event_profile = event.profile

                event_dict["context"]["profile_info"] = {
                    event_profile.user_id: {
                        "display_name": event_profile.display_name,
                        "avatar_url": event_profile.avatar_url,
                    }
                }

            context = self._load_context(user, event, before, after)

            event_dict["context"]["events_before"] = context["events_before"]
            event_dict["context"]["events_after"] = context["events_after"]

            loaded[event.id] = event_dict

        # The IN query yields rows in database order; put them back into
        # the order of the search result.
        result_dict["results"] = [
            loaded[column] for column in columns if column in loaded
        ]

        return result_dict
def sanitize_room_id(room_id):
@ -37,7 +265,7 @@ class Searcher:
self.timestamp_field = timestamp_field
def search(self, search_term, room=None, max_results=10,
order_by_date=False):
order_by_recent=False):
# type (str, str, int, bool) -> List[int, int]
"""Search for events in the index.
@ -63,7 +291,7 @@ class Searcher:
query = queryparser.parse_query(search_term)
if order_by_date:
if order_by_recent:
collector = tantivy.TopDocs(max_results,
order_by_field=self.timestamp_field)
else:
@ -82,7 +310,7 @@ class Searcher:
class Index:
def __init__(self, path=None):
def __init__(self, path=None, num_searchers=None):
schema_builder = tantivy.SchemaBuilder()
self.body_field = schema_builder.add_text_field("body")
@ -108,7 +336,7 @@ class Index:
self.index = tantivy.Index(schema, path)
self.reader = self.index.reader()
self.reader = self.index.reader(num_searchers=num_searchers)
self.writer = self.index.writer()
def add_event(self, column_id, event, room_id):
@ -154,3 +382,113 @@ class Index:
self.timestamp_field,
self.reader.searcher()
)
@attr.s
class StoreItem:
    """An event queued for writing, with its room id and sender profile."""

    event = attr.ib()
    room_id = attr.ib()
    # The profile fields are optional and default to None.
    display_name = attr.ib(default=None)
    avatar_url = attr.ib(default=None)
@attr.s
class IndexStore:
    """Combined search index and message store.

    Events are queued with ``add_event`` and written to both the message
    store and the index by ``commit_events``. ``search`` queries the index
    and loads the matching events from the store.
    """

    user = attr.ib(type=str)
    index_path = attr.ib(type=str)
    # Falls back to ``index_path`` when not given.
    store_path = attr.ib(type=str, default=None)
    store_name = attr.ib(default="events.db")

    index = attr.ib(type=Index, init=False)
    store = attr.ib(type=MessageStore, init=False)
    # Events waiting to be written by the next commit_events() call.
    event_queue = attr.ib(factory=list)
    write_lock = attr.ib(factory=asyncio.Lock)
    read_semaphore = attr.ib(type=asyncio.Semaphore, init=False)

    def __attrs_post_init__(self):
        """Create the index, the read semaphore and the message store."""
        self.store_path = self.store_path or self.index_path
        num_searchers = os.cpu_count()
        self.index = Index(self.index_path, num_searchers)
        # os.cpu_count() may return None, use a single slot in that case.
        self.read_semaphore = asyncio.Semaphore(num_searchers or 1)
        self.store = MessageStore(self.user, self.store_path, self.store_name)

    def add_event(self, event, room_id, display_name, avatar_url):
        """Queue an event; nothing is written until commit_events()."""
        item = StoreItem(event, room_id, display_name, avatar_url)
        self.event_queue.append(item)

    @staticmethod
    def write_events(store, index, event_queue):
        """Write the queued items to the store and the index.

        The store writes happen inside one database transaction and the
        index is committed before that transaction is left.
        """
        with store.database.bind_ctx(store.models):
            with store.database.atomic():
                for item in event_queue:
                    # Pass the queued profile data along, otherwise every
                    # stored profile would be empty.
                    column_id = store.save_event(
                        item.event,
                        item.room_id,
                        item.display_name,
                        item.avatar_url,
                    )

                    # Only events that were newly stored get indexed.
                    if column_id:
                        index.add_event(column_id, item.event, item.room_id)

                index.commit()

    async def commit_events(self):
        """Write all queued events in an executor thread.

        The queue is swapped for an empty list before the write starts, so
        events added in the meantime are kept for the next commit.
        """
        loop = asyncio.get_event_loop()

        event_queue = self.event_queue

        if not event_queue:
            return

        self.event_queue = []

        async with self.write_lock:
            write_func = partial(
                IndexStore.write_events,
                self.store,
                self.index,
                event_queue
            )
            await loop.run_in_executor(None, write_func)

    def event_in_store(self, event_id, room_id):
        """Return True if the message store holds the given event."""
        return self.store.event_in_store(event_id, room_id)

    async def search(
        self,
        search_term,  # type: str
        room=None,  # type: Optional[str]
        max_results=10,  # type: int
        order_by_recent=False,  # type: bool
        include_profile=False,  # type: bool
        before_limit=0,  # type: int
        after_limit=0  # type: int
    ):
        # type: (...) -> Dict[Any, Any]
        """Search the indexstore for an event.

        Returns the dict produced by the message store's ``load_events``
        with a "count" key (number of results) and an empty "highlights"
        list added.
        """
        loop = asyncio.get_event_loop()

        # Getting a searcher from tantivy may block if there is no searcher
        # available. To avoid blocking we set up the number of searchers to be
        # the number of CPUs and the semaphore has the same counter value.
        async with self.read_semaphore:
            searcher = self.index.searcher()
            search_func = partial(searcher.search, search_term, room=room,
                                  max_results=max_results,
                                  order_by_recent=order_by_recent)

            result = await loop.run_in_executor(None, search_func)

            load_event_func = partial(
                self.store.load_events,
                result,
                include_profile,
                order_by_recent,
                before_limit,
                after_limit
            )

            search_result = await loop.run_in_executor(None, load_event_func)

            search_result["count"] = len(search_result["results"])
            search_result["highlights"] = []

            return search_result

View File

@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import attr
from nio import RoomMessage
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
use_database)
from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model,
SqliteDatabase, TextField)
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField)
@attr.s
@ -70,41 +68,6 @@ class ServerUsers(Model):
constraints = [SQL("UNIQUE(user_id,server_id)")]
class Profile(Model):
    """Display name and avatar URL recorded for a user id.

    The uniqueness constraint spans all three columns, so a user id can have
    one row per distinct display name/avatar combination.
    """

    user_id = TextField()
    # Both profile columns are nullable.
    avatar_url = TextField(null=True)
    display_name = TextField(null=True)

    class Meta:
        constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
class Event(Model):
    """A stored room event together with a reference to a sender profile."""

    event_id = TextField()
    sender = TextField()
    date = DateTimeField()
    room_id = TextField()
    source = DictField()
    profile = ForeignKeyField(
        model=Profile,
        column_name="profile_id",
    )

    class Meta:
        # The same event id may be stored again if the room, sender or
        # profile differs.
        constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
class UserMessages(Model):
    """Association table linking a server user to an event row."""

    user = ForeignKeyField(
        model=ServerUsers,
        column_name="user_id")
    event = ForeignKeyField(
        model=Event,
        column_name="event_id")
class PanSyncTokens(Model):
token = TextField()
user = ForeignKeyField(
@ -146,9 +109,6 @@ class PanStore:
ServerUsers,
DeviceKeys,
DeviceTrustState,
Profile,
Event,
UserMessages,
PanSyncTokens,
PanFetcherTasks
]
@ -242,86 +202,6 @@ class PanStore:
return None
@use_database
def save_event(self, server, pan_user, event, room_id, display_name,
               avatar_url):
    # type: (str, str, RoomMessage, str, str, str) -> Optional[int]
    """Save an event to the store.

    The server and the user are looked up by name/id first. Note that
    ``event.source`` is modified in place: the room id is added under the
    "room_id" key before the source is stored.

    Returns the database id of the event, or None if the insert was ignored
    because of a conflict or if the event was already linked to the user.
    """
    server = Servers.get(name=server)
    user = ServerUsers.get(server=server, user_id=pan_user)

    profile_id, _ = Profile.get_or_create(
        user_id=event.sender,
        display_name=display_name,
        avatar_url=avatar_url
    )

    event_source = event.source
    event_source["room_id"] = room_id

    event_id = Event.insert(
        event_id=event.event_id,
        sender=event.sender,
        date=datetime.datetime.fromtimestamp(
            event.server_timestamp / 1000
        ),
        room_id=room_id,
        source=event_source,
        profile=profile_id
    ).on_conflict_ignore().execute()

    # A non-positive id is treated as "nothing was inserted".
    if event_id <= 0:
        return None

    _, created = UserMessages.get_or_create(
        user=user,
        event=event_id,
    )

    if created:
        return event_id

    return None
@use_database
def load_event_by_columns(
    self,
    server,  # type: str
    pan_user,  # type: str
    column,  # type: int
    include_profile=False  # type: bool
):
    # type: (...) -> Optional[Dict]
    """Load a single stored event by its database id.

    ``column`` is passed as the event of the UserMessages lookup, i.e. it
    is one event row id.

    Returns a dict with the keys "result" (the event source) and "context",
    or None if the event is not stored for the given user. When
    ``include_profile`` is set the context holds the sender's profile under
    "profile_info".
    """
    server = Servers.get(name=server)
    user = ServerUsers.get(server=server, user_id=pan_user)

    message = UserMessages.get_or_none(user=user, event=column)

    if not message:
        return None

    event = message.event

    event_dict = {
        "result": event.source,
        "context": {}
    }

    if include_profile:
        event_profile = event.profile

        event_dict["context"]["profile_info"] = {
            event_profile.user_id: {
                "display_name": event_profile.display_name,
                "avatar_url": event_profile.avatar_url,
            }
        }

    return event_dict
@use_database
def save_server_user(self, server_name, user_id):
# type: (str, str) -> None

View File

@ -1,9 +1,11 @@
import asyncio
import pdb
import pprint
from nio import RoomMessage
from conftest import faker
from pantalaimon.index import Index
from pantalaimon.index import Index, IndexStore
from pantalaimon.store import FetchTask
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
@ -32,8 +34,8 @@ class TestClass(object):
return RoomMessage.parse_event(
{
"content": {"body": "Another message", "msgtype": "m.text"},
"event_id": "$15163622445EBvZJ:localhost",
"origin_server_ts": 1516362244026,
"event_id": "$15163622445EBvZK:localhost",
"origin_server_ts": 1516362244030,
"room_id": "!SVkFJHzfwvuaIEawgC:localhost",
"sender": "@example2:localhost",
"type": "m.room.message",
@ -58,66 +60,6 @@ class TestClass(object):
token = panstore.load_access_token(user_id, device_id)
access_token == token
def test_event_storing(self, panstore_with_users):
    """Check saving an event, duplicate detection and loading it back."""
    panstore = panstore_with_users
    accounts = panstore.load_all_users()
    user, _ = accounts[0]

    event = self.test_event

    event_id = panstore.save_event("example", user, event, TEST_ROOM,
                                   "Example2", None)
    assert event_id == 1

    # Saving the same event a second time must not return an id.
    event_id = panstore.save_event("example", user, event, TEST_ROOM,
                                   "Example2", None)
    assert event_id is None

    event_dict = panstore.load_event_by_columns("example", user, 1)
    assert event.source == event_dict["result"]

    # With include_profile set the sender profile is part of the context.
    event_source = panstore.load_event_by_columns("example", user, 1, True)

    assert event_source["context"]["profile_info"] == {
        "@example2:localhost": {
            "display_name": "Example2",
            "avatar_url": None
        }
    }
def test_index(self, panstore_with_users):
    """Index two stored events and find one of them by room and term."""
    panstore = panstore_with_users
    accounts = panstore.load_all_users()
    user, _ = accounts[0]

    event = self.test_event
    another_event = self.another_event

    index = Index(panstore.store_path)

    event_id = panstore.save_event("example", user, event, TEST_ROOM,
                                   "Example2", None)
    assert event_id == 1
    index.add_event(event_id, event, TEST_ROOM)

    # The second event goes into a different room.
    event_id = panstore.save_event("example", user, another_event,
                                   TEST_ROOM2, "Example2", None)
    assert event_id == 2
    index.add_event(event_id, another_event, TEST_ROOM2)

    index.commit()

    searcher = index.searcher()
    searched_events = searcher.search("message", TEST_ROOM)

    # Each search hit is a (score, database id) pair.
    _, found_id = searched_events[0]
    event_dict = panstore.load_event_by_columns("example", user, found_id)

    assert event_dict["result"] == event.source
def test_token_storing(self, panstore_with_users):
panstore = panstore_with_users
accounts = panstore.load_all_users()
@ -151,3 +93,26 @@ class TestClass(object):
assert task not in tasks
assert task2 in tasks
def test_new_indexstore(self, tempdir):
    """Store two events, then check lookup, search and search context."""
    event_loop = asyncio.get_event_loop()
    index_store = IndexStore("example", tempdir)

    # Queue both events for the same room and write them in one commit.
    for message in (self.test_event, self.another_event):
        index_store.add_event(message, TEST_ROOM, None, None)
    event_loop.run_until_complete(index_store.commit_events())

    assert index_store.event_in_store(self.test_event.event_id, TEST_ROOM)
    assert not index_store.event_in_store("FAKE", TEST_ROOM)

    pending_search = index_store.search(
        "test",
        TEST_ROOM,
        after_limit=10,
        before_limit=10
    )
    result = event_loop.run_until_complete(pending_search)
    pprint.pprint(result)

    hits = result["results"]
    assert len(hits) == 1
    assert result["count"] == 1

    first_hit = hits[0]
    assert first_hit["result"] == self.test_event.source
    # The second event must show up as context after the hit.
    assert (first_hit["context"]["events_after"][0]
            == self.another_event.source)