mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-02-09 03:18:36 -05:00
index: Make the index optional.
This patch disables the message indexer an message store if tantivy is not installed. The search API endpoint forwards all search requests to the homeserver.
This commit is contained in:
parent
fa1c2bb694
commit
53d1a2945d
@ -43,7 +43,7 @@ from nio import (
|
||||
from nio.crypto import Sas
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.index import IndexStore
|
||||
from pantalaimon.index import INDEXING_ENABLED
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import FetchTask
|
||||
from pantalaimon.thread_messages import (
|
||||
@ -151,7 +151,16 @@ class PanClient(AsyncClient):
|
||||
self.server_name = server_name
|
||||
self.pan_store = pan_store
|
||||
self.pan_conf = pan_conf
|
||||
self.index = IndexStore(self.user_id, index_dir)
|
||||
|
||||
if INDEXING_ENABLED:
|
||||
logger.info("Indexing enabled.")
|
||||
from pantalaimon.index import IndexStore
|
||||
|
||||
self.index = IndexStore(self.user_id, index_dir)
|
||||
else:
|
||||
logger.info("Indexing disabled.")
|
||||
self.index = None
|
||||
|
||||
self.task = None
|
||||
self.queue = queue
|
||||
|
||||
@ -159,26 +168,32 @@ class PanClient(AsyncClient):
|
||||
|
||||
self.send_semaphores = defaultdict(asyncio.Semaphore)
|
||||
self.send_decision_queues = dict() # type: asyncio.Queue
|
||||
self.last_sync_token = None
|
||||
|
||||
self.history_fetcher_task = None
|
||||
self.history_fetch_queue = asyncio.Queue()
|
||||
|
||||
self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
|
||||
self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
|
||||
self.add_event_callback(
|
||||
self.store_message_cb,
|
||||
(
|
||||
RoomMessageText,
|
||||
RoomMessageMedia,
|
||||
RoomEncryptedMedia,
|
||||
RoomTopicEvent,
|
||||
RoomNameEvent,
|
||||
),
|
||||
)
|
||||
|
||||
if INDEXING_ENABLED:
|
||||
self.add_event_callback(
|
||||
self.store_message_cb,
|
||||
(
|
||||
RoomMessageText,
|
||||
RoomMessageMedia,
|
||||
RoomEncryptedMedia,
|
||||
RoomTopicEvent,
|
||||
RoomNameEvent,
|
||||
),
|
||||
)
|
||||
|
||||
self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
|
||||
self.add_response_callback(self.sync_tasks, SyncResponse)
|
||||
|
||||
def store_message_cb(self, room, event):
|
||||
assert INDEXING_ENABLED
|
||||
|
||||
display_name = room.user_name(event.sender)
|
||||
avatar_url = room.avatar_url(event.sender)
|
||||
|
||||
@ -233,6 +248,8 @@ class PanClient(AsyncClient):
|
||||
self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)
|
||||
|
||||
async def fetcher_loop(self):
|
||||
assert INDEXING_ENABLED
|
||||
|
||||
for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
|
||||
await self.history_fetch_queue.put(t)
|
||||
|
||||
@ -300,7 +317,9 @@ class PanClient(AsyncClient):
|
||||
return
|
||||
|
||||
async def sync_tasks(self, response):
|
||||
await self.index.commit_events()
|
||||
if self.index:
|
||||
await self.index.commit_events()
|
||||
|
||||
self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)
|
||||
|
||||
for room_id, room_info in response.rooms.join.items():
|
||||
@ -392,7 +411,8 @@ class PanClient(AsyncClient):
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
self.history_fetcher_task = loop.create_task(self.fetcher_loop())
|
||||
if INDEXING_ENABLED:
|
||||
self.history_fetcher_task = loop.create_task(self.fetcher_loop())
|
||||
|
||||
timeout = 30000
|
||||
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
|
||||
@ -671,6 +691,8 @@ class PanClient(AsyncClient):
|
||||
|
||||
async def search(self, search_terms):
|
||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||
assert INDEXING_ENABLED
|
||||
|
||||
state_cache = dict()
|
||||
|
||||
async def add_context(event_dict, room_id, event_id, include_state):
|
||||
|
@ -35,7 +35,7 @@ from pantalaimon.client import (
|
||||
UnknownRoomError,
|
||||
validate_json,
|
||||
)
|
||||
from pantalaimon.index import InvalidQueryError
|
||||
from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
from pantalaimon.thread_messages import (
|
||||
@ -905,6 +905,9 @@ class ProxyDaemon:
|
||||
if not access_token:
|
||||
return self._missing_token
|
||||
|
||||
if not INDEXING_ENABLED:
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
client = await self._find_client(access_token)
|
||||
|
||||
if not client:
|
||||
|
@ -12,487 +12,497 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
import tantivy
|
||||
from nio import (
|
||||
RoomEncryptedMedia,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomNameEvent,
|
||||
RoomTopicEvent,
|
||||
)
|
||||
from peewee import SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, TextField
|
||||
|
||||
from pantalaimon.store import use_database
|
||||
from importlib import util
|
||||
|
||||
|
||||
class InvalidQueryError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DictField(TextField):
|
||||
def python_value(self, value): # pragma: no cover
|
||||
return json.loads(value)
|
||||
if util.find_spec("tantivy"):
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
def db_value(self, value): # pragma: no cover
|
||||
return json.dumps(value)
|
||||
import attr
|
||||
import tantivy
|
||||
from nio import (
|
||||
RoomEncryptedMedia,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomNameEvent,
|
||||
RoomTopicEvent,
|
||||
)
|
||||
from peewee import (
|
||||
SQL,
|
||||
DateTimeField,
|
||||
ForeignKeyField,
|
||||
Model,
|
||||
SqliteDatabase,
|
||||
TextField,
|
||||
)
|
||||
|
||||
from pantalaimon.store import use_database
|
||||
|
||||
class StoreUser(Model):
|
||||
user_id = TextField()
|
||||
INDEXING_ENABLED = True
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id)")]
|
||||
class DictField(TextField):
|
||||
def python_value(self, value): # pragma: no cover
|
||||
return json.loads(value)
|
||||
|
||||
def db_value(self, value): # pragma: no cover
|
||||
return json.dumps(value)
|
||||
|
||||
class Profile(Model):
|
||||
user_id = TextField()
|
||||
avatar_url = TextField(null=True)
|
||||
display_name = TextField(null=True)
|
||||
class StoreUser(Model):
|
||||
user_id = TextField()
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id)")]
|
||||
|
||||
class Profile(Model):
|
||||
user_id = TextField()
|
||||
avatar_url = TextField(null=True)
|
||||
display_name = TextField(null=True)
|
||||
|
||||
class Event(Model):
|
||||
event_id = TextField()
|
||||
sender = TextField()
|
||||
date = DateTimeField()
|
||||
room_id = TextField()
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,avatar_url,display_name)")]
|
||||
|
||||
source = DictField()
|
||||
class Event(Model):
|
||||
event_id = TextField()
|
||||
sender = TextField()
|
||||
date = DateTimeField()
|
||||
room_id = TextField()
|
||||
|
||||
profile = ForeignKeyField(model=Profile, column_name="profile_id")
|
||||
source = DictField()
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
profile = ForeignKeyField(model=Profile, column_name="profile_id")
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(model=StoreUser, column_name="user_id")
|
||||
event = ForeignKeyField(model=Event, column_name="event_id")
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(model=StoreUser, column_name="user_id")
|
||||
event = ForeignKeyField(model=Event, column_name="event_id")
|
||||
|
||||
@attr.s
|
||||
class MessageStore:
|
||||
user = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str)
|
||||
database_name = attr.ib(type=str)
|
||||
database = attr.ib(type=SqliteDatabase, init=False)
|
||||
database_path = attr.ib(type=str, init=False)
|
||||
|
||||
@attr.s
|
||||
class MessageStore:
|
||||
user = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str)
|
||||
database_name = attr.ib(type=str)
|
||||
database = attr.ib(type=SqliteDatabase, init=False)
|
||||
database_path = attr.ib(type=str, init=False)
|
||||
models = [StoreUser, Event, Profile, UserMessages]
|
||||
|
||||
models = [StoreUser, Event, Profile, UserMessages]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
os.path.abspath(self.store_path), self.database_name
|
||||
)
|
||||
|
||||
self.database = self._create_database()
|
||||
self.database.connect()
|
||||
|
||||
with self.database.bind_ctx(self.models):
|
||||
self.database.create_tables(self.models)
|
||||
|
||||
def _create_database(self):
|
||||
return SqliteDatabase(
|
||||
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
|
||||
)
|
||||
|
||||
@use_database
|
||||
def event_in_store(self, event_id, room_id):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.room_id == room_id)
|
||||
& (Event.event_id == event_id)
|
||||
& (UserMessages.user == user)
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
os.path.abspath(self.store_path), self.database_name
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
for _ in query:
|
||||
return True
|
||||
self.database = self._create_database()
|
||||
self.database.connect()
|
||||
|
||||
return False
|
||||
with self.database.bind_ctx(self.models):
|
||||
self.database.create_tables(self.models)
|
||||
|
||||
def save_event(self, event, room_id, display_name=None, avatar_url=None):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender, display_name=display_name, avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = (
|
||||
Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id,
|
||||
def _create_database(self):
|
||||
return SqliteDatabase(
|
||||
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
|
||||
)
|
||||
.on_conflict_ignore()
|
||||
.execute()
|
||||
)
|
||||
|
||||
if event_id <= 0:
|
||||
@use_database
|
||||
def event_in_store(self, event_id, room_id):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.room_id == room_id)
|
||||
& (Event.event_id == event_id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
for _ in query:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def save_event(self, event, room_id, display_name=None, avatar_url=None):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender, display_name=display_name, avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = (
|
||||
Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id,
|
||||
)
|
||||
.on_conflict_ignore()
|
||||
.execute()
|
||||
)
|
||||
|
||||
if event_id <= 0:
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(user=user, event=event_id)
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(user=user, event=event_id)
|
||||
def _load_context(self, user, event, before, after):
|
||||
context = {}
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
def _load_context(self, user, event, before, after):
|
||||
context = {}
|
||||
|
||||
if before > 0:
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date <= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
if before > 0:
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date <= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.order_by(Event.date.desc())
|
||||
.limit(before)
|
||||
)
|
||||
.order_by(Event.date.desc())
|
||||
.limit(before)
|
||||
|
||||
context["events_before"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_before"] = []
|
||||
|
||||
if after > 0:
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date >= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.order_by(Event.date)
|
||||
.limit(after)
|
||||
)
|
||||
|
||||
context["events_after"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_after"] = []
|
||||
|
||||
return context
|
||||
|
||||
@use_database
|
||||
def load_events(
|
||||
self,
|
||||
search_result, # type: List[Tuple[int, int]]
|
||||
include_profile=False, # type: bool
|
||||
order_by_recent=False, # type: bool
|
||||
before=0, # type: int
|
||||
after=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
search_dict = {r[1]: r[0] for r in search_result}
|
||||
columns = list(search_dict.keys())
|
||||
|
||||
result_dict = {"results": []}
|
||||
|
||||
query = (
|
||||
UserMessages.select()
|
||||
.where(
|
||||
(UserMessages.user_id == user) & (UserMessages.event.in_(columns))
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
context["events_before"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_before"] = []
|
||||
for message in query:
|
||||
|
||||
if after > 0:
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date >= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.order_by(Event.date)
|
||||
.limit(after)
|
||||
)
|
||||
event = message.event
|
||||
|
||||
context["events_after"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_after"] = []
|
||||
|
||||
return context
|
||||
|
||||
@use_database
|
||||
def load_events(
|
||||
self,
|
||||
search_result, # type: List[Tuple[int, int]]
|
||||
include_profile=False, # type: bool
|
||||
order_by_recent=False, # type: bool
|
||||
before=0, # type: int
|
||||
after=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
search_dict = {r[1]: r[0] for r in search_result}
|
||||
columns = list(search_dict.keys())
|
||||
|
||||
result_dict = {"results": []}
|
||||
|
||||
query = (
|
||||
UserMessages.select()
|
||||
.where((UserMessages.user_id == user) & (UserMessages.event.in_(columns)))
|
||||
.execute()
|
||||
)
|
||||
|
||||
for message in query:
|
||||
|
||||
event = message.event
|
||||
|
||||
event_dict = {
|
||||
"rank": 1 if order_by_recent else search_dict[event.id],
|
||||
"result": event.source,
|
||||
"context": {},
|
||||
}
|
||||
|
||||
if include_profile:
|
||||
event_profile = event.profile
|
||||
|
||||
event_dict["context"]["profile_info"] = {
|
||||
event_profile.user_id: {
|
||||
"display_name": event_profile.display_name,
|
||||
"avatar_url": event_profile.avatar_url,
|
||||
}
|
||||
event_dict = {
|
||||
"rank": 1 if order_by_recent else search_dict[event.id],
|
||||
"result": event.source,
|
||||
"context": {},
|
||||
}
|
||||
|
||||
context = self._load_context(user, event, before, after)
|
||||
if include_profile:
|
||||
event_profile = event.profile
|
||||
|
||||
event_dict["context"]["events_before"] = context["events_before"]
|
||||
event_dict["context"]["events_after"] = context["events_after"]
|
||||
event_dict["context"]["profile_info"] = {
|
||||
event_profile.user_id: {
|
||||
"display_name": event_profile.display_name,
|
||||
"avatar_url": event_profile.avatar_url,
|
||||
}
|
||||
}
|
||||
|
||||
result_dict["results"].append(event_dict)
|
||||
context = self._load_context(user, event, before, after)
|
||||
|
||||
return result_dict
|
||||
event_dict["context"]["events_before"] = context["events_before"]
|
||||
event_dict["context"]["events_after"] = context["events_after"]
|
||||
|
||||
result_dict["results"].append(event_dict)
|
||||
|
||||
def sanitize_room_id(room_id):
|
||||
return room_id.replace(":", "/").replace("!", "")
|
||||
return result_dict
|
||||
|
||||
def sanitize_room_id(room_id):
|
||||
return room_id.replace(":", "/").replace("!", "")
|
||||
|
||||
class Searcher:
|
||||
def __init__(
|
||||
self,
|
||||
index,
|
||||
body_field,
|
||||
name_field,
|
||||
topic_field,
|
||||
column_field,
|
||||
room_field,
|
||||
timestamp_field,
|
||||
searcher,
|
||||
):
|
||||
self._index = index
|
||||
self._searcher = searcher
|
||||
class Searcher:
|
||||
def __init__(
|
||||
self,
|
||||
index,
|
||||
body_field,
|
||||
name_field,
|
||||
topic_field,
|
||||
column_field,
|
||||
room_field,
|
||||
timestamp_field,
|
||||
searcher,
|
||||
):
|
||||
self._index = index
|
||||
self._searcher = searcher
|
||||
|
||||
self.body_field = body_field
|
||||
self.name_field = topic_field
|
||||
self.topic_field = name_field
|
||||
self.column_field = column_field
|
||||
self.room_field = room_field
|
||||
self.timestamp_field = timestamp_field
|
||||
self.body_field = body_field
|
||||
self.name_field = topic_field
|
||||
self.topic_field = name_field
|
||||
self.column_field = column_field
|
||||
self.room_field = room_field
|
||||
self.timestamp_field = timestamp_field
|
||||
|
||||
def search(self, search_term, room=None, max_results=10, order_by_recent=False):
|
||||
# type (str, str, int, bool) -> List[int, int]
|
||||
"""Search for events in the index.
|
||||
def search(self, search_term, room=None, max_results=10, order_by_recent=False):
|
||||
# type (str, str, int, bool) -> List[int, int]
|
||||
"""Search for events in the index.
|
||||
|
||||
Returns the score and the column id for the event.
|
||||
"""
|
||||
queryparser = tantivy.QueryParser.for_index(
|
||||
self._index,
|
||||
[self.body_field, self.name_field, self.topic_field, self.room_field],
|
||||
)
|
||||
|
||||
# This currently supports only a single room since the query parser
|
||||
# doesn't seem to work with multiple room fields here.
|
||||
if room:
|
||||
query_string = "{} AND room:{}".format(search_term, sanitize_room_id(room))
|
||||
else:
|
||||
query_string = search_term
|
||||
|
||||
try:
|
||||
query = queryparser.parse_query(query_string)
|
||||
except ValueError:
|
||||
raise InvalidQueryError(f"Invalid search term: {search_term}")
|
||||
|
||||
if order_by_recent:
|
||||
collector = tantivy.TopDocs(
|
||||
max_results, order_by_field=self.timestamp_field
|
||||
)
|
||||
else:
|
||||
collector = tantivy.TopDocs(max_results)
|
||||
|
||||
result = self._searcher.search(query, collector)
|
||||
|
||||
retrieved_result = []
|
||||
|
||||
for score, doc_address in result:
|
||||
doc = self._searcher.doc(doc_address)
|
||||
column = doc.get_first(self.column_field)
|
||||
retrieved_result.append((score, column))
|
||||
|
||||
return retrieved_result
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, path=None, num_searchers=None):
|
||||
schema_builder = tantivy.SchemaBuilder()
|
||||
|
||||
self.body_field = schema_builder.add_text_field("body")
|
||||
self.name_field = schema_builder.add_text_field("name")
|
||||
self.topic_field = schema_builder.add_text_field("topic")
|
||||
|
||||
self.timestamp_field = schema_builder.add_unsigned_field(
|
||||
"server_timestamp", fast="single"
|
||||
)
|
||||
self.date_field = schema_builder.add_date_field("message_date")
|
||||
self.room_field = schema_builder.add_facet_field("room")
|
||||
|
||||
self.column_field = schema_builder.add_unsigned_field(
|
||||
"database_column", indexed=True, stored=True, fast="single"
|
||||
)
|
||||
|
||||
schema = schema_builder.build()
|
||||
|
||||
self.index = tantivy.Index(schema, path)
|
||||
|
||||
self.reader = self.index.reader(num_searchers=num_searchers)
|
||||
self.writer = self.index.writer()
|
||||
|
||||
def add_event(self, column_id, event, room_id):
|
||||
doc = tantivy.Document()
|
||||
|
||||
room_path = "/{}".format(sanitize_room_id(room_id))
|
||||
|
||||
room_facet = tantivy.Facet.from_string(room_path)
|
||||
|
||||
doc.add_unsigned(self.column_field, column_id)
|
||||
doc.add_facet(self.room_field, room_facet)
|
||||
doc.add_date(
|
||||
self.date_field,
|
||||
datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
)
|
||||
doc.add_unsigned(self.timestamp_field, event.server_timestamp)
|
||||
|
||||
if isinstance(event, RoomMessageText):
|
||||
doc.add_text(self.body_field, event.body)
|
||||
elif isinstance(event, (RoomMessageMedia, RoomEncryptedMedia)):
|
||||
doc.add_text(self.body_field, event.body)
|
||||
elif isinstance(event, RoomNameEvent):
|
||||
doc.add_text(self.name_field, event.name)
|
||||
elif isinstance(event, RoomTopicEvent):
|
||||
doc.add_text(self.topic_field, event.topic)
|
||||
else:
|
||||
raise ValueError("Invalid event passed.")
|
||||
|
||||
self.writer.add_document(doc)
|
||||
|
||||
def commit(self):
|
||||
self.writer.commit()
|
||||
|
||||
def searcher(self):
|
||||
self.reader.reload()
|
||||
return Searcher(
|
||||
self.index,
|
||||
self.body_field,
|
||||
self.name_field,
|
||||
self.topic_field,
|
||||
self.column_field,
|
||||
self.room_field,
|
||||
self.timestamp_field,
|
||||
self.reader.searcher(),
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class StoreItem:
|
||||
event = attr.ib()
|
||||
room_id = attr.ib()
|
||||
display_name = attr.ib(default=None)
|
||||
avatar_url = attr.ib(default=None)
|
||||
|
||||
|
||||
@attr.s
|
||||
class IndexStore:
|
||||
user = attr.ib(type=str)
|
||||
index_path = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str, default=None)
|
||||
store_name = attr.ib(default="events.db")
|
||||
|
||||
index = attr.ib(type=Index, init=False)
|
||||
store = attr.ib(type=MessageStore, init=False)
|
||||
event_queue = attr.ib(factory=list)
|
||||
write_lock = attr.ib(factory=asyncio.Lock)
|
||||
read_semaphore = attr.ib(type=asyncio.Semaphore, init=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.store_path = self.store_path or self.index_path
|
||||
num_searchers = os.cpu_count()
|
||||
self.index = Index(self.index_path, num_searchers)
|
||||
self.read_semaphore = asyncio.Semaphore(num_searchers or 1)
|
||||
self.store = MessageStore(self.user, self.store_path, self.store_name)
|
||||
|
||||
def add_event(self, event, room_id, display_name, avatar_url):
|
||||
item = StoreItem(event, room_id, display_name, avatar_url)
|
||||
self.event_queue.append(item)
|
||||
|
||||
@staticmethod
|
||||
def write_events(store, index, event_queue):
|
||||
with store.database.bind_ctx(store.models):
|
||||
with store.database.atomic():
|
||||
for item in event_queue:
|
||||
column_id = store.save_event(item.event, item.room_id)
|
||||
|
||||
if column_id:
|
||||
index.add_event(column_id, item.event, item.room_id)
|
||||
index.commit()
|
||||
|
||||
async def commit_events(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
event_queue = self.event_queue
|
||||
|
||||
if not event_queue:
|
||||
return
|
||||
|
||||
self.event_queue = []
|
||||
|
||||
async with self.write_lock:
|
||||
write_func = partial(
|
||||
IndexStore.write_events, self.store, self.index, event_queue
|
||||
)
|
||||
await loop.run_in_executor(None, write_func)
|
||||
|
||||
def event_in_store(self, event_id, room_id):
|
||||
return self.store.event_in_store(event_id, room_id)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
search_term, # type: str
|
||||
room=None, # type: Optional[str]
|
||||
max_results=10, # type: int
|
||||
order_by_recent=False, # type: bool
|
||||
include_profile=False, # type: bool
|
||||
before_limit=0, # type: int
|
||||
after_limit=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
"""Search the indexstore for an event."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Getting a searcher from tantivy may block if there is no searcher
|
||||
# available. To avoid blocking we set up the number of searchers to be
|
||||
# the number of CPUs and the semaphore has the same counter value.
|
||||
async with self.read_semaphore:
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(
|
||||
searcher.search,
|
||||
search_term,
|
||||
room=room,
|
||||
max_results=max_results,
|
||||
order_by_recent=order_by_recent,
|
||||
Returns the score and the column id for the event.
|
||||
"""
|
||||
queryparser = tantivy.QueryParser.for_index(
|
||||
self._index,
|
||||
[self.body_field, self.name_field, self.topic_field, self.room_field],
|
||||
)
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
# This currently supports only a single room since the query parser
|
||||
# doesn't seem to work with multiple room fields here.
|
||||
if room:
|
||||
query_string = "{} AND room:{}".format(
|
||||
search_term, sanitize_room_id(room)
|
||||
)
|
||||
else:
|
||||
query_string = search_term
|
||||
|
||||
load_event_func = partial(
|
||||
self.store.load_events,
|
||||
result,
|
||||
include_profile,
|
||||
order_by_recent,
|
||||
before_limit,
|
||||
after_limit,
|
||||
try:
|
||||
query = queryparser.parse_query(query_string)
|
||||
except ValueError:
|
||||
raise InvalidQueryError(f"Invalid search term: {search_term}")
|
||||
|
||||
if order_by_recent:
|
||||
collector = tantivy.TopDocs(
|
||||
max_results, order_by_field=self.timestamp_field
|
||||
)
|
||||
else:
|
||||
collector = tantivy.TopDocs(max_results)
|
||||
|
||||
result = self._searcher.search(query, collector)
|
||||
|
||||
retrieved_result = []
|
||||
|
||||
for score, doc_address in result:
|
||||
doc = self._searcher.doc(doc_address)
|
||||
column = doc.get_first(self.column_field)
|
||||
retrieved_result.append((score, column))
|
||||
|
||||
return retrieved_result
|
||||
|
||||
class Index:
|
||||
def __init__(self, path=None, num_searchers=None):
|
||||
schema_builder = tantivy.SchemaBuilder()
|
||||
|
||||
self.body_field = schema_builder.add_text_field("body")
|
||||
self.name_field = schema_builder.add_text_field("name")
|
||||
self.topic_field = schema_builder.add_text_field("topic")
|
||||
|
||||
self.timestamp_field = schema_builder.add_unsigned_field(
|
||||
"server_timestamp", fast="single"
|
||||
)
|
||||
self.date_field = schema_builder.add_date_field("message_date")
|
||||
self.room_field = schema_builder.add_facet_field("room")
|
||||
|
||||
self.column_field = schema_builder.add_unsigned_field(
|
||||
"database_column", indexed=True, stored=True, fast="single"
|
||||
)
|
||||
|
||||
search_result = await loop.run_in_executor(None, load_event_func)
|
||||
schema = schema_builder.build()
|
||||
|
||||
search_result["count"] = len(search_result["results"])
|
||||
search_result["highlights"] = []
|
||||
self.index = tantivy.Index(schema, path)
|
||||
|
||||
return search_result
|
||||
self.reader = self.index.reader(num_searchers=num_searchers)
|
||||
self.writer = self.index.writer()
|
||||
|
||||
def add_event(self, column_id, event, room_id):
|
||||
doc = tantivy.Document()
|
||||
|
||||
room_path = "/{}".format(sanitize_room_id(room_id))
|
||||
|
||||
room_facet = tantivy.Facet.from_string(room_path)
|
||||
|
||||
doc.add_unsigned(self.column_field, column_id)
|
||||
doc.add_facet(self.room_field, room_facet)
|
||||
doc.add_date(
|
||||
self.date_field,
|
||||
datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
)
|
||||
doc.add_unsigned(self.timestamp_field, event.server_timestamp)
|
||||
|
||||
if isinstance(event, RoomMessageText):
|
||||
doc.add_text(self.body_field, event.body)
|
||||
elif isinstance(event, (RoomMessageMedia, RoomEncryptedMedia)):
|
||||
doc.add_text(self.body_field, event.body)
|
||||
elif isinstance(event, RoomNameEvent):
|
||||
doc.add_text(self.name_field, event.name)
|
||||
elif isinstance(event, RoomTopicEvent):
|
||||
doc.add_text(self.topic_field, event.topic)
|
||||
else:
|
||||
raise ValueError("Invalid event passed.")
|
||||
|
||||
self.writer.add_document(doc)
|
||||
|
||||
def commit(self):
|
||||
self.writer.commit()
|
||||
|
||||
def searcher(self):
|
||||
self.reader.reload()
|
||||
return Searcher(
|
||||
self.index,
|
||||
self.body_field,
|
||||
self.name_field,
|
||||
self.topic_field,
|
||||
self.column_field,
|
||||
self.room_field,
|
||||
self.timestamp_field,
|
||||
self.reader.searcher(),
|
||||
)
|
||||
|
||||
@attr.s
|
||||
class StoreItem:
|
||||
event = attr.ib()
|
||||
room_id = attr.ib()
|
||||
display_name = attr.ib(default=None)
|
||||
avatar_url = attr.ib(default=None)
|
||||
|
||||
@attr.s
|
||||
class IndexStore:
|
||||
user = attr.ib(type=str)
|
||||
index_path = attr.ib(type=str)
|
||||
store_path = attr.ib(type=str, default=None)
|
||||
store_name = attr.ib(default="events.db")
|
||||
|
||||
index = attr.ib(type=Index, init=False)
|
||||
store = attr.ib(type=MessageStore, init=False)
|
||||
event_queue = attr.ib(factory=list)
|
||||
write_lock = attr.ib(factory=asyncio.Lock)
|
||||
read_semaphore = attr.ib(type=asyncio.Semaphore, init=False)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.store_path = self.store_path or self.index_path
|
||||
num_searchers = os.cpu_count()
|
||||
self.index = Index(self.index_path, num_searchers)
|
||||
self.read_semaphore = asyncio.Semaphore(num_searchers or 1)
|
||||
self.store = MessageStore(self.user, self.store_path, self.store_name)
|
||||
|
||||
def add_event(self, event, room_id, display_name, avatar_url):
|
||||
item = StoreItem(event, room_id, display_name, avatar_url)
|
||||
self.event_queue.append(item)
|
||||
|
||||
@staticmethod
|
||||
def write_events(store, index, event_queue):
|
||||
with store.database.bind_ctx(store.models):
|
||||
with store.database.atomic():
|
||||
for item in event_queue:
|
||||
column_id = store.save_event(item.event, item.room_id)
|
||||
|
||||
if column_id:
|
||||
index.add_event(column_id, item.event, item.room_id)
|
||||
index.commit()
|
||||
|
||||
async def commit_events(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
event_queue = self.event_queue
|
||||
|
||||
if not event_queue:
|
||||
return
|
||||
|
||||
self.event_queue = []
|
||||
|
||||
async with self.write_lock:
|
||||
write_func = partial(
|
||||
IndexStore.write_events, self.store, self.index, event_queue
|
||||
)
|
||||
await loop.run_in_executor(None, write_func)
|
||||
|
||||
def event_in_store(self, event_id, room_id):
|
||||
return self.store.event_in_store(event_id, room_id)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
search_term, # type: str
|
||||
room=None, # type: Optional[str]
|
||||
max_results=10, # type: int
|
||||
order_by_recent=False, # type: bool
|
||||
include_profile=False, # type: bool
|
||||
before_limit=0, # type: int
|
||||
after_limit=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
"""Search the indexstore for an event."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Getting a searcher from tantivy may block if there is no searcher
|
||||
# available. To avoid blocking we set up the number of searchers to be
|
||||
# the number of CPUs and the semaphore has the same counter value.
|
||||
async with self.read_semaphore:
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(
|
||||
searcher.search,
|
||||
search_term,
|
||||
room=room,
|
||||
max_results=max_results,
|
||||
order_by_recent=order_by_recent,
|
||||
)
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
|
||||
load_event_func = partial(
|
||||
self.store.load_events,
|
||||
result,
|
||||
include_profile,
|
||||
order_by_recent,
|
||||
before_limit,
|
||||
after_limit,
|
||||
)
|
||||
|
||||
search_result = await loop.run_in_executor(None, load_event_func)
|
||||
|
||||
search_result["count"] = len(search_result["results"])
|
||||
search_result["highlights"] = []
|
||||
|
||||
return search_result
|
||||
|
||||
|
||||
else:
|
||||
INDEXING_ENABLED = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user