mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Clean up the database pagination code (#5007)
* rewrite & simplify * changelog * cleanup potential sql injection
This commit is contained in:
parent
616e6a10bd
commit
a33a5abc4c
1
changelog.d/5007.misc
Normal file
1
changelog.d/5007.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor synapse.storage._base._simple_select_list_paginate.
|
@ -18,6 +18,8 @@ import calendar
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import PresenceState
|
from synapse.api.constants import PresenceState
|
||||||
from synapse.storage.devices import DeviceStore
|
from synapse.storage.devices import DeviceStore
|
||||||
from synapse.storage.user_erasure_store import UserErasureStore
|
from synapse.storage.user_erasure_store import UserErasureStore
|
||||||
@ -453,6 +455,7 @@ class DataStore(
|
|||||||
desc="get_users",
|
desc="get_users",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def get_users_paginate(self, order, start, limit):
|
def get_users_paginate(self, order, start, limit):
|
||||||
"""Function to reterive a paginated list of users from
|
"""Function to reterive a paginated list of users from
|
||||||
users list. This will return a json object, which contains
|
users list. This will return a json object, which contains
|
||||||
@ -465,16 +468,19 @@ class DataStore(
|
|||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
|
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
|
||||||
"""
|
"""
|
||||||
is_guest = 0
|
users = yield self.runInteraction(
|
||||||
i_start = (int)(start)
|
"get_users_paginate",
|
||||||
i_limit = (int)(limit)
|
self._simple_select_list_paginate_txn,
|
||||||
return self.get_user_list_paginate(
|
|
||||||
table="users",
|
table="users",
|
||||||
keyvalues={"is_guest": is_guest},
|
keyvalues={"is_guest": False},
|
||||||
pagevalues=[order, i_limit, i_start],
|
orderby=order,
|
||||||
|
start=start,
|
||||||
|
limit=limit,
|
||||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||||
desc="get_users_paginate",
|
|
||||||
)
|
)
|
||||||
|
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
|
||||||
|
retval = {"users": users, "total": count}
|
||||||
|
defer.returnValue(retval)
|
||||||
|
|
||||||
def search_users(self, term):
|
def search_users(self, term):
|
||||||
"""Function to search users list for one or more users with
|
"""Function to search users list for one or more users with
|
||||||
|
@ -595,7 +595,7 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
table (str): The table to upsert into
|
table (str): The table to upsert into
|
||||||
keyvalues (dict): The unique key tables and their new values
|
keyvalues (dict): The unique key columns and their new values
|
||||||
values (dict): The nonunique columns and their new values
|
values (dict): The nonunique columns and their new values
|
||||||
insertion_values (dict): additional key/values to use only when
|
insertion_values (dict): additional key/values to use only when
|
||||||
inserting
|
inserting
|
||||||
@ -627,7 +627,7 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
# presumably we raced with another transaction: let's retry.
|
# presumably we raced with another transaction: let's retry.
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"%s when upserting into %s; retrying: %s", e.__name__, table, e
|
"IntegrityError when upserting into %s; retrying: %s", table, e
|
||||||
)
|
)
|
||||||
|
|
||||||
def _simple_upsert_txn(
|
def _simple_upsert_txn(
|
||||||
@ -1398,21 +1398,31 @@ class SQLBaseStore(object):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _simple_select_list_paginate(
|
def _simple_select_list_paginate(
|
||||||
self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate"
|
self,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
orderby,
|
||||||
|
start,
|
||||||
|
limit,
|
||||||
|
retcols,
|
||||||
|
order_direction="ASC",
|
||||||
|
desc="_simple_select_list_paginate",
|
||||||
):
|
):
|
||||||
"""Executes a SELECT query on the named table with start and limit,
|
"""
|
||||||
|
Executes a SELECT query on the named table with start and limit,
|
||||||
of row numbers, which may return zero or number of rows from start to limit,
|
of row numbers, which may return zero or number of rows from start to limit,
|
||||||
returning the result as a list of dicts.
|
returning the result as a list of dicts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table (str): the table name
|
table (str): the table name
|
||||||
keyvalues (dict[str, Any] | None):
|
keyvalues (dict[str, T] | None):
|
||||||
column names and values to select the rows with, or None to not
|
column names and values to select the rows with, or None to not
|
||||||
apply a WHERE clause.
|
apply a WHERE clause.
|
||||||
|
orderby (str): Column to order the results by.
|
||||||
|
start (int): Index to begin the query at.
|
||||||
|
limit (int): Number of results to return.
|
||||||
retcols (iterable[str]): the names of the columns to return
|
retcols (iterable[str]): the names of the columns to return
|
||||||
order (str): order the select by this column
|
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||||
start (int): start number to begin the query from
|
|
||||||
limit (int): number of rows to reterive
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to list[dict[str, Any]]
|
defer.Deferred: resolves to list[dict[str, Any]]
|
||||||
"""
|
"""
|
||||||
@ -1421,15 +1431,27 @@ class SQLBaseStore(object):
|
|||||||
self._simple_select_list_paginate_txn,
|
self._simple_select_list_paginate_txn,
|
||||||
table,
|
table,
|
||||||
keyvalues,
|
keyvalues,
|
||||||
pagevalues,
|
orderby,
|
||||||
|
start,
|
||||||
|
limit,
|
||||||
retcols,
|
retcols,
|
||||||
|
order_direction=order_direction,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _simple_select_list_paginate_txn(
|
def _simple_select_list_paginate_txn(
|
||||||
cls, txn, table, keyvalues, pagevalues, retcols
|
cls,
|
||||||
|
txn,
|
||||||
|
table,
|
||||||
|
keyvalues,
|
||||||
|
orderby,
|
||||||
|
start,
|
||||||
|
limit,
|
||||||
|
retcols,
|
||||||
|
order_direction="ASC",
|
||||||
):
|
):
|
||||||
"""Executes a SELECT query on the named table with start and limit,
|
"""
|
||||||
|
Executes a SELECT query on the named table with start and limit,
|
||||||
of row numbers, which may return zero or number of rows from start to limit,
|
of row numbers, which may return zero or number of rows from start to limit,
|
||||||
returning the result as a list of dicts.
|
returning the result as a list of dicts.
|
||||||
|
|
||||||
@ -1439,65 +1461,33 @@ class SQLBaseStore(object):
|
|||||||
keyvalues (dict[str, T] | None):
|
keyvalues (dict[str, T] | None):
|
||||||
column names and values to select the rows with, or None to not
|
column names and values to select the rows with, or None to not
|
||||||
apply a WHERE clause.
|
apply a WHERE clause.
|
||||||
pagevalues ([]):
|
orderby (str): Column to order the results by.
|
||||||
order (str): order the select by this column
|
start (int): Index to begin the query at.
|
||||||
start (int): start number to begin the query from
|
limit (int): Number of results to return.
|
||||||
limit (int): number of rows to reterive
|
|
||||||
retcols (iterable[str]): the names of the columns to return
|
retcols (iterable[str]): the names of the columns to return
|
||||||
|
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to list[dict[str, Any]]
|
defer.Deferred: resolves to list[dict[str, Any]]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if order_direction not in ["ASC", "DESC"]:
|
||||||
|
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||||
|
|
||||||
if keyvalues:
|
if keyvalues:
|
||||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
|
where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||||
", ".join(retcols),
|
|
||||||
table,
|
|
||||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
|
||||||
" ? ASC LIMIT ? OFFSET ?",
|
|
||||||
)
|
|
||||||
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
|
|
||||||
else:
|
else:
|
||||||
sql = "SELECT %s FROM %s ORDER BY %s" % (
|
where_clause = ""
|
||||||
", ".join(retcols),
|
|
||||||
table,
|
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
|
||||||
" ? ASC LIMIT ? OFFSET ?",
|
", ".join(retcols),
|
||||||
)
|
table,
|
||||||
txn.execute(sql, pagevalues)
|
where_clause,
|
||||||
|
orderby,
|
||||||
|
order_direction,
|
||||||
|
)
|
||||||
|
txn.execute(sql, list(keyvalues.values()) + [limit, start])
|
||||||
|
|
||||||
return cls.cursor_to_dict(txn)
|
return cls.cursor_to_dict(txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_user_list_paginate(
|
|
||||||
self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
|
|
||||||
):
|
|
||||||
"""Get a list of users from start row to a limit number of rows. This will
|
|
||||||
return a json object with users and total number of users in users list.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table (str): the table name
|
|
||||||
keyvalues (dict[str, Any] | None):
|
|
||||||
column names and values to select the rows with, or None to not
|
|
||||||
apply a WHERE clause.
|
|
||||||
pagevalues ([]):
|
|
||||||
order (str): order the select by this column
|
|
||||||
start (int): start number to begin the query from
|
|
||||||
limit (int): number of rows to reterive
|
|
||||||
retcols (iterable[str]): the names of the columns to return
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
|
|
||||||
"""
|
|
||||||
users = yield self.runInteraction(
|
|
||||||
desc,
|
|
||||||
self._simple_select_list_paginate_txn,
|
|
||||||
table,
|
|
||||||
keyvalues,
|
|
||||||
pagevalues,
|
|
||||||
retcols,
|
|
||||||
)
|
|
||||||
count = yield self.runInteraction(desc, self.get_user_count_txn)
|
|
||||||
retval = {"users": users, "total": count}
|
|
||||||
defer.returnValue(retval)
|
|
||||||
|
|
||||||
def get_user_count_txn(self, txn):
|
def get_user_count_txn(self, txn):
|
||||||
"""Get a total number of registered users in the users list.
|
"""Get a total number of registered users in the users list.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user