store: Add methods to store media info in the store.

This commit is contained in:
Damir Jelić 2020-02-20 13:14:18 +01:00
parent 99a1bccbc8
commit d25989cfd7
3 changed files with 130 additions and 2 deletions

View file

@ -27,6 +27,10 @@ from nio.store import (
use_database_atomic, use_database_atomic,
) )
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
from cachetools import LRUCache
MAX_LOADED_MEDIA = 10000
@attr.s @attr.s
@ -35,6 +39,15 @@ class FetchTask:
token = attr.ib(type=str) token = attr.ib(type=str)
@attr.s
class MediaInfo:
mxc_server = attr.ib(type=str)
mxc_path = attr.ib(type=str)
key = attr.ib(type=dict)
iv = attr.ib(type=str)
hashes = attr.ib(type=dict)
class DictField(TextField): class DictField(TextField):
def python_value(self, value): # pragma: no cover def python_value(self, value): # pragma: no cover
return json.loads(value) return json.loads(value)
@ -86,6 +99,20 @@ class PanFetcherTasks(Model):
constraints = [SQL("UNIQUE(user_id, room_id, token)")] constraints = [SQL("UNIQUE(user_id, room_id, token)")]
class PanMediaInfo(Model):
server = ForeignKeyField(
model=Servers, column_name="server_id", backref="media", on_delete="CASCADE"
)
mxc_server = TextField()
mxc_path = TextField()
key = DictField()
hashes = DictField()
iv = TextField()
class Meta:
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
@attr.s @attr.s
class ClientInfo: class ClientInfo:
user_id = attr.ib(type=str) user_id = attr.ib(type=str)
@ -107,6 +134,7 @@ class PanStore:
DeviceTrustState, DeviceTrustState,
PanSyncTokens, PanSyncTokens,
PanFetcherTasks, PanFetcherTasks,
PanMediaInfo,
] ]
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -134,6 +162,46 @@ class PanStore:
except DoesNotExist: except DoesNotExist:
return None return None
@use_database
def save_media(self, server, media):
server = Servers.get(name=server)
PanMediaInfo.insert(
server=server,
mxc_server=media.mxc_server,
mxc_path=media.mxc_path,
key=media.key,
iv=media.iv,
hashes=media.hashes,
).on_conflict_ignore().execute()
@use_database
def load_media(self, server, mxc_server=None, mxc_path=None):
server, _ = Servers.get_or_create(name=server)
if not mxc_path:
media_cache = LRUCache(maxsize=MAX_LOADED_MEDIA)
for i, m in enumerate(server.media):
if i > MAX_LOADED_MEDIA:
break
media = MediaInfo(m.mxc_server, m.mxc_path, m.key, m.iv, m.hashes)
media_cache[(m.mxc_server, m.mxc_path)] = media
return media_cache
else:
m = PanMediaInfo.get_or_none(
PanMediaInfo.server == server,
PanMediaInfo.mxc_server == mxc_server,
PanMediaInfo.mxc_path == mxc_path,
)
if not m:
return None
return MediaInfo(m.mxc_server, m.mxc_path, m.key, m.iv, m.hashes)
@use_database_atomic @use_database_atomic
def replace_fetcher_task(self, server, pan_user, old_task, new_task): def replace_fetcher_task(self, server, pan_user, old_task, new_task):
server = Servers.get(name=server) server = Servers.get(name=server)
@ -169,6 +237,7 @@ class PanStore:
return tasks return tasks
@use_database
def delete_fetcher_task(self, server, pan_user, task): def delete_fetcher_task(self, server, pan_user, task):
server = Servers.get(name=server) server = Servers.get(name=server)
user = ServerUsers.get(server=server, user_id=pan_user) user = ServerUsers.get(server=server, user_id=pan_user)

View file

@ -26,6 +26,7 @@ setup(
"logbook", "logbook",
"peewee", "peewee",
"janus", "janus",
"cachetools >= 3.0.0"
"prompt_toolkit>2<4", "prompt_toolkit>2<4",
"typing;python_version<'3.5'", "typing;python_version<'3.5'",
"matrix-nio[e2e] >= 0.4.1" "matrix-nio[e2e] >= 0.4.1"

View file

@ -3,11 +3,12 @@ import pdb
import pprint import pprint
import pytest import pytest
from nio import RoomMessage from nio import RoomMessage, RoomEncryptedMedia
from urllib.parse import urlparse
from conftest import faker from conftest import faker
from pantalaimon.index import INDEXING_ENABLED from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.store import FetchTask from pantalaimon.store import FetchTask, MediaInfo
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost" TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
TEST_ROOM2 = "!testroom:localhost" TEST_ROOM2 = "!testroom:localhost"
@ -46,6 +47,36 @@ class TestClass(object):
} }
) )
@property
def encrypted_media_event(self):
return RoomEncryptedMedia.from_dict({
"room_id": "!testroom:localhost",
"event_id": "$15163622445EBvZK:localhost",
"origin_server_ts": 1516362244030,
"sender": "@example2:localhost",
"type": "m.room.message",
"content": {
"body": "orange_cat.jpg",
"msgtype": "m.image",
"file": {
"v": "v2",
"key": {
"alg": "A256CTR",
"ext": True,
"k": "yx0QvkgYlasdWEsdalkejaHBzCkKEBAp3tB7dGtWgrs",
"key_ops": ["encrypt", "decrypt"],
"kty": "oct"
},
"iv": "0pglXX7fspIBBBBAEERLFd",
"hashes": {
"sha256": "eXRDFvh+aXsQRj8a+5ZVVWUQ9Y6u9DYiz4tq1NvbLu8"
},
"url": "mxc://localhost/maDtasSiPFjROFMnlwxIhhyW",
"mimetype": "image/jpeg"
}
}
})
def test_account_loading(self, panstore): def test_account_loading(self, panstore):
accounts = panstore.load_all_users() accounts = panstore.load_all_users()
# pdb.set_trace() # pdb.set_trace()
@ -119,3 +150,30 @@ class TestClass(object):
assert result["results"][0]["result"] == self.test_event.source assert result["results"][0]["result"] == self.test_event.source
assert (result["results"][0]["context"]["events_after"][0] assert (result["results"][0]["context"]["events_after"][0]
== self.another_event.source) == self.another_event.source)
def test_media_storage(self, panstore):
server_name = "test"
media_cache = panstore.load_media(server_name)
assert not media_cache
event = self.encrypted_media_event
mxc = urlparse(event.url)
assert mxc
mxc_server = mxc.netloc
mxc_path = mxc.path
assert not panstore.load_media(server_name, mxc_server, mxc_path)
media = MediaInfo(mxc_server, mxc_path, event.key, event.iv, event.hashes)
panstore.save_media(server_name, media)
media_cache = panstore.load_media(server_name)
assert (mxc_server, mxc_path) in media_cache
media_info = media_cache[(mxc_server, mxc_path)]
assert media_info == media
assert media_info == panstore.load_media(server_name, mxc_server, mxc_path)