Implement the SHHS complexity API (#5216)

This commit is contained in:
Amber Brown 2019-05-30 01:47:16 +10:00 committed by GitHub
parent 532b825ed9
commit 46c8f7a517
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 180 additions and 5 deletions

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import itertools
import logging
from collections import namedtuple
@ -614,7 +616,7 @@ class EventsWorkerStore(SQLBaseStore):
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_state_event_counts.
See get_total_state_event_counts.
"""
sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?"
txn.execute(sql, (room_id,))
@ -635,3 +637,49 @@ class EventsWorkerStore(SQLBaseStore):
"get_total_state_event_counts",
self._get_total_state_event_counts_txn, room_id
)
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
"""
sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_current_state_event_counts(self, room_id):
"""
Gets the current number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn, room_id
)
@defer.inlineCallbacks
def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
Higher complexity value indicates that being in the room will consume
more resources.
Args:
room_id (str)
Returns:
Deferred[dict[str:int]] of complexity version to complexity.
"""
state_events = yield self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
complexity_v1 = round(state_events / 500, 2)
defer.returnValue({"v1": complexity_v1})