mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
2557531f0f
This was caused by accidentally overwriting a `last_seen` variable in a for loop, causing the wrong value to be written to the progress table. As a result, sections of the table were not scanned when searching for duplicates, so some duplicates were not deleted.
394 lines
13 KiB
Python
394 lines
13 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Copyright 2016 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
|
|
from six import iteritems
|
|
|
|
from twisted.internet import defer
|
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
|
from synapse.util.caches import CACHE_SIZE_FACTOR
|
|
|
|
from . import background_updates
|
|
from ._base import Cache
|
|
|
|
logger = logging.getLogger(__name__)

# Minimum interval, in milliseconds, between two recorded "last seen" writes
# for the same (user, access token, IP). Smaller values mean more database
# inserts, even for read-only API hits.
LAST_SEEN_GRANULARITY = 2 * 60 * 1000  # two minutes
|
|
|
|
|
|
class ClientIpStore(background_updates.BackgroundUpdateStore):
    """Tracks which IPs, user agents and devices clients connect from.

    Sightings are buffered in memory in ``_batch_row_update`` and written to
    the ``user_ips`` table by ``_update_client_ips_batch``. That runs on a
    looping call and once more before reactor shutdown. Read methods merge
    the buffered (not yet flushed) rows with what is already stored.

    Also registers the background updates that:
    - index ``user_ips``,
    - remove duplicate (user_id, access_token, ip) rows,
    - drop the old non-unique index.
    """

    def __init__(self, db_conn, hs):
        # (user_id, access_token, ip) -> timestamp (ms) of the last queued
        # write; used by insert_client_ip to rate-limit writes.
        # NOTE(review): keys are 3-tuples but keylen=4 -- confirm intended.
        self.client_ip_last_seen = Cache(
            name="client_ip_last_seen",
            keylen=4,
            max_entries=50000 * CACHE_SIZE_FACTOR,
        )

        super(ClientIpStore, self).__init__(db_conn, hs)

        self.register_background_index_update(
            "user_ips_device_index",
            index_name="user_ips_device_id",
            table="user_ips",
            columns=["user_id", "device_id", "last_seen"],
        )

        self.register_background_index_update(
            "user_ips_last_seen_index",
            index_name="user_ips_last_seen",
            table="user_ips",
            columns=["user_id", "last_seen"],
        )

        self.register_background_index_update(
            "user_ips_last_seen_only_index",
            index_name="user_ips_last_seen_only",
            table="user_ips",
            columns=["last_seen"],
        )

        self.register_background_update_handler(
            "user_ips_remove_dupes",
            self._remove_user_ip_dupes,
        )

        # Register a unique index
        self.register_background_index_update(
            "user_ips_device_unique_index",
            index_name="user_ips_user_token_ip_unique_index",
            table="user_ips",
            columns=["user_id", "access_token", "ip"],
            unique=True,
        )

        # Drop the old non-unique index
        self.register_background_update_handler(
            "user_ips_drop_nonunique_index",
            self._remove_user_ip_nonunique,
        )

        # Pending writes not yet flushed to the database:
        # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
        self._batch_row_update = {}

        # Flush pending writes periodically, and once more before shutdown so
        # that they are not lost.
        self._client_ip_looper = self._clock.looping_call(
            self._update_client_ips_batch, 5 * 1000
        )
        self.hs.get_reactor().addSystemEventTrigger(
            "before", "shutdown", self._update_client_ips_batch
        )

    @defer.inlineCallbacks
    def _remove_user_ip_nonunique(self, progress, batch_size):
        """Background update: drop the old non-unique ``user_ips_user_ip`` index.

        Runs the DROP on a raw connection obtained via ``runWithConnection``,
        then marks the background update as finished.

        Args:
            progress (dict): stored progress (unused).
            batch_size (int): requested batch size (unused).

        Returns:
            Deferred[int]: 1
        """
        def f(conn):
            txn = conn.cursor()
            try:
                txn.execute(
                    "DROP INDEX IF EXISTS user_ips_user_ip"
                )
            finally:
                # Always release the cursor, even if the DROP fails.
                txn.close()

        yield self.runWithConnection(f)
        yield self._end_background_update("user_ips_drop_nonunique_index")
        defer.returnValue(1)

    @defer.inlineCallbacks
    def _remove_user_ip_dupes(self, progress, batch_size):
        """Background update: collapse duplicate (user_id, access_token, ip) rows.

        Scans ``user_ips`` in batches ordered by ``last_seen``. For each
        (user_id, access_token, ip) tuple in the current batch, the rest of
        the table is searched for duplicates. If there are any, they are all
        deleted and replaced by a single row. That row carries the maximum
        device_id, user_agent and last_seen across the duplicates.

        Progress is stored as ``{"last_seen": end}``, where ``end`` is the
        exclusive upper bound of the batch just processed. The next run
        resumes from there.

        Args:
            progress (dict): previously stored progress; may hold "last_seen".
            batch_size (int): approximate number of rows to scan per run.

        Returns:
            Deferred[int]: ``batch_size``.
        """
        # Inclusive lower bound of this batch.
        begin_last_seen = progress.get("last_seen", 0)

        def get_last_seen(txn):
            # Find a last_seen value roughly `batch_size` rows past
            # `begin_last_seen`, to be used as this batch's exclusive upper
            # bound. Returns None if there are not that many rows left.
            txn.execute(
                """
                SELECT last_seen FROM user_ips
                WHERE last_seen > ?
                ORDER BY last_seen
                LIMIT 1
                OFFSET ?
                """,
                (begin_last_seen, batch_size)
            )
            row = txn.fetchone()
            if row:
                return row[0]
            else:
                return None

        end_last_seen = yield self.runInteraction(
            "user_ips_dups_get_last_seen", get_last_seen
        )

        if end_last_seen is None:
            # If we get a None then we're reaching the end and just need to
            # delete the last batch.
            last = True

            # Fake not having an upper bound by using a future date (the
            # current time multiplied by two). This must be stored in
            # `end_last_seen`. Leaving it as None would make the
            # `last_seen < ?` filter in `remove` match no rows, so the final
            # batch would never be de-duplicated.
            end_last_seen = int(self._clock.time_msec()) * 2
        else:
            last = False

        def remove(txn, begin_last_seen, end_last_seen):
            # This looks at every row with
            # begin_last_seen <= last_seen < end_last_seen. For each
            # (user_id, access_token, ip) tuple in that range, it checks for
            # duplicates anywhere in the table (via the join).
            #
            # Only tuples that have duplicates are returned. Each comes with
            # the max device_id, user_agent and last_seen across all its
            # duplicates; these are used to build the single replacement row.
            #
            # It is efficient due to the existence of (user_id, access_token,
            # ip) and (last_seen) indices. The sub-select needs no ORDER BY:
            # its rows are only joined and grouped.
            txn.execute(
                """
                SELECT user_id, access_token, ip,
                       MAX(device_id), MAX(user_agent), MAX(last_seen)
                FROM (
                    SELECT user_id, access_token, ip
                    FROM user_ips
                    WHERE ? <= last_seen AND last_seen < ?
                ) c
                INNER JOIN user_ips USING (user_id, access_token, ip)
                GROUP BY user_id, access_token, ip
                HAVING count(*) > 1""",
                (begin_last_seen, end_last_seen)
            )
            res = txn.fetchall()

            # We've got some duplicates
            for row in res:
                # `max_last_seen` is deliberately not called `last_seen` (or
                # anything else from the enclosing scope) so the batch bounds
                # cannot be clobbered by the loop.
                (
                    user_id, access_token, ip,
                    device_id, user_agent, max_last_seen,
                ) = row

                # Drop all the duplicates
                txn.execute(
                    """
                    DELETE FROM user_ips
                    WHERE user_id = ? AND access_token = ? AND ip = ?
                    """,
                    (user_id, access_token, ip)
                )

                # Add in one to be the last_seen
                txn.execute(
                    """
                    INSERT INTO user_ips
                    (user_id, access_token, ip, device_id, user_agent, last_seen)
                    VALUES (?, ?, ?, ?, ?, ?)
                    """,
                    (user_id, access_token, ip, device_id, user_agent,
                     max_last_seen)
                )

            self._background_update_progress_txn(
                txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
            )

        yield self.runInteraction(
            "user_ips_dups_remove", remove, begin_last_seen, end_last_seen
        )
        if last:
            yield self._end_background_update("user_ips_remove_dupes")

        defer.returnValue(batch_size)

    @defer.inlineCallbacks
    def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
                         now=None):
        """Record that a client was seen, queueing a write to ``user_ips``.

        Writes for the same (user_id, access_token, ip) are skipped if one was
        queued less than LAST_SEEN_GRANULARITY ms ago. Accepted writes are
        buffered in ``_batch_row_update`` and flushed later by
        ``_update_client_ips_batch``. ``populate_monthly_active_users`` is
        only called when the key is missing from ``client_ip_last_seen``.

        Args:
            user_id (str)
            access_token (str)
            ip (str)
            user_agent (str)
            device_id (str|None)
            now (int|None): timestamp in ms; a falsy value means "current
                time".
        """
        if not now:
            now = int(self._clock.time_msec())
        key = (user_id, access_token, ip)

        try:
            last_seen = self.client_ip_last_seen.get(key)
        except KeyError:
            last_seen = None
            yield self.populate_monthly_active_users(user_id)
        # Rate-limited inserts
        if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
            return

        self.client_ip_last_seen.prefill(key, now)

        self._batch_row_update[key] = (user_agent, device_id, now)

    def _update_client_ips_batch(self):
        """Flush buffered writes to the database as a background process.

        The pending dict is swapped for a fresh one before the transaction
        runs, so new sightings queue up while the flush is in progress.
        Does nothing if the database pool is no longer running.
        """

        # If the DB pool has already terminated, don't try updating
        if not self.hs.get_db_pool().running:
            return

        def update():
            to_update = self._batch_row_update
            self._batch_row_update = {}
            return self.runInteraction(
                "_update_client_ips_batch", self._update_client_ips_batch_txn,
                to_update,
            )

        return run_as_background_process(
            "update_client_ips", update,
        )

    def _update_client_ips_batch_txn(self, txn, to_update):
        """Upsert every buffered sighting into ``user_ips``.

        Takes the table lock once, then upserts each entry without further
        locking. A failing entry is logged and the rest are still attempted.

        Args:
            txn: database transaction.
            to_update (dict): (user_id, access_token, ip) ->
                (user_agent, device_id, last_seen)
        """
        self.database_engine.lock_table(txn, "user_ips")

        for entry in iteritems(to_update):
            (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry

            try:
                self._simple_upsert_txn(
                    txn,
                    table="user_ips",
                    keyvalues={
                        "user_id": user_id,
                        "access_token": access_token,
                        "ip": ip,
                    },
                    values={
                        "user_agent": user_agent,
                        "device_id": device_id,
                        "last_seen": last_seen,
                    },
                    lock=False,
                )
            except Exception as e:
                # Failed to upsert, log and continue
                logger.error("Failed to insert client IP %r: %r", entry, e)

    @defer.inlineCallbacks
    def get_last_client_ip_by_device(self, user_id, device_id):
        """For each of the user's devices, find the entry it was last seen with.

        Combines the stored rows with buffered, not-yet-flushed sightings.
        Where both exist for a device, the one with the larger last_seen wins.

        Args:
            user_id (str)
            device_id (str): If None fetches all devices for the user

        Returns:
            defer.Deferred: resolves to a dict, where the keys
            are (user_id, device_id) tuples. The values are also dicts, with
            keys giving the column names
        """
        # Snapshot the pending rows for this user before yielding.
        # `_update_client_ips_batch` replaces `_batch_row_update` with a fresh
        # dict when it flushes, which may happen while we wait on the
        # database.
        pending = [
            (access_token, ip, user_agent, did, last_seen)
            for (uid, access_token, ip), (user_agent, did, last_seen)
            in iteritems(self._batch_row_update)
            if uid == user_id and (not device_id or did == device_id)
        ]

        res = yield self.runInteraction(
            "get_last_client_ip_by_device",
            self._get_last_client_ip_by_device_txn,
            user_id, device_id,
            retcols=(
                "user_id",
                "access_token",
                "ip",
                "user_agent",
                "device_id",
                "last_seen",
            ),
        )

        ret = {(d["user_id"], d["device_id"]): d for d in res}
        for access_token, ip, user_agent, did, last_seen in pending:
            # Key by the pending row's own device: when `device_id` is None,
            # using it here would collapse every device into one entry.
            key = (user_id, did)
            existing = ret.get(key)
            if existing is not None and existing["last_seen"] > last_seen:
                # Keep whichever sighting is most recent.
                continue
            ret[key] = {
                "user_id": user_id,
                "access_token": access_token,
                "ip": ip,
                "user_agent": user_agent,
                "device_id": did,
                "last_seen": last_seen,
            }
        defer.returnValue(ret)

    @classmethod
    def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
        """Fetch, per (user_id, device_id), the user_ips row(s) with the
        greatest last_seen.

        Args:
            txn: database transaction.
            user_id (str)
            device_id (str|None): restrict to this device; None means all of
                the user's devices.
            retcols (iterable[str]): user_ips columns to return.

        Returns:
            The rows as converted by ``cursor_to_dict`` (dict-like, keyed by
            column name).
        """
        if device_id is None:
            where_clause = "user_id = ?"
            bindings = (user_id,)
        else:
            where_clause = "(user_id = ? AND device_id = ?)"
            bindings = (user_id, device_id)

        inner_select = (
            "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
            "WHERE %(where)s "
            "GROUP BY user_id, device_id"
        ) % {
            "where": where_clause,
        }

        # Join back onto user_ips to get the full row(s) matching each
        # (user_id, device_id)'s max last_seen; NULL device_ids are matched
        # explicitly because NULL = NULL is not true in SQL.
        sql = (
            "SELECT %(retcols)s FROM user_ips "
            "JOIN (%(inner_select)s) ips ON"
            " user_ips.last_seen = ips.mls AND"
            " user_ips.user_id = ips.user_id AND"
            " (user_ips.device_id = ips.device_id OR"
            "  (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
            " )"
        ) % {
            "retcols": ",".join("user_ips." + c for c in retcols),
            "inner_select": inner_select,
        }

        txn.execute(sql, bindings)
        return cls.cursor_to_dict(txn)

    @defer.inlineCallbacks
    def get_user_ip_and_agents(self, user):
        """List every (access_token, ip) a user has been seen from.

        Args:
            user: the user; ``user.to_string()`` must return the user ID.

        Returns:
            Deferred[list[dict]]: one dict per (access_token, ip), with keys
            "access_token", "ip", "user_agent" and "last_seen". Buffered,
            not-yet-flushed sightings take precedence over stored rows.
        """
        user_id = user.to_string()

        # Snapshot pending rows before yielding. `_update_client_ips_batch`
        # replaces `_batch_row_update` with a fresh dict when it flushes,
        # which may happen while we wait on the database.
        pending = {
            (access_token, ip): (user_agent, last_seen)
            for (uid, access_token, ip), (user_agent, _, last_seen)
            in iteritems(self._batch_row_update)
            if uid == user_id
        }

        rows = yield self._simple_select_list(
            table="user_ips",
            keyvalues={"user_id": user_id},
            retcols=[
                "access_token", "ip", "user_agent", "last_seen"
            ],
            desc="get_user_ip_and_agents",
        )

        results = {
            (row["access_token"], row["ip"]): (row["user_agent"], row["last_seen"])
            for row in rows
        }
        # Pending sightings are queued after whatever was previously flushed,
        # so they must override the stored rows rather than the other way
        # round.
        results.update(pending)

        defer.returnValue(list(
            {
                "access_token": access_token,
                "ip": ip,
                "user_agent": user_agent,
                "last_seen": last_seen,
            }
            for (access_token, ip), (user_agent, last_seen) in iteritems(results)
        ))
|