Clean up the database pagination code (#5007)

* rewrite & simplify

* changelog

* cleanup potential sql injection
This commit is contained in:
Amber Brown 2019-04-05 00:21:16 +11:00 committed by GitHub
parent 616e6a10bd
commit a33a5abc4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 68 deletions

1
changelog.d/5007.misc Normal file
View File

@ -0,0 +1 @@
Refactor synapse.storage._base._simple_select_list_paginate.

View File

@ -18,6 +18,8 @@ import calendar
import logging import logging
import time import time
from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore from synapse.storage.user_erasure_store import UserErasureStore
@ -453,6 +455,7 @@ class DataStore(
desc="get_users", desc="get_users",
) )
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit): def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from """Function to reterive a paginated list of users from
users list. This will return a json object, which contains users list. This will return a json object, which contains
@ -465,16 +468,19 @@ class DataStore(
Returns: Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count} defer.Deferred: resolves to json object {list[dict[str, Any]], count}
""" """
is_guest = 0 users = yield self.runInteraction(
i_start = (int)(start) "get_users_paginate",
i_limit = (int)(limit) self._simple_select_list_paginate_txn,
return self.get_user_list_paginate(
table="users", table="users",
keyvalues={"is_guest": is_guest}, keyvalues={"is_guest": False},
pagevalues=[order, i_limit, i_start], orderby=order,
start=start,
limit=limit,
retcols=["name", "password_hash", "is_guest", "admin"], retcols=["name", "password_hash", "is_guest", "admin"],
desc="get_users_paginate",
) )
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
defer.returnValue(retval)
def search_users(self, term): def search_users(self, term):
"""Function to search users list for one or more users with """Function to search users list for one or more users with

View File

@ -595,7 +595,7 @@ class SQLBaseStore(object):
Args: Args:
table (str): The table to upsert into table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when insertion_values (dict): additional key/values to use only when
inserting inserting
@ -627,7 +627,7 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry. # presumably we raced with another transaction: let's retry.
logger.warn( logger.warn(
"%s when upserting into %s; retrying: %s", e.__name__, table, e "IntegrityError when upserting into %s; retrying: %s", table, e
) )
def _simple_upsert_txn( def _simple_upsert_txn(
@ -1398,21 +1398,31 @@ class SQLBaseStore(object):
return 0 return 0
def _simple_select_list_paginate( def _simple_select_list_paginate(
self, table, keyvalues, pagevalues, retcols, desc="_simple_select_list_paginate" self,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
desc="_simple_select_list_paginate",
): ):
"""Executes a SELECT query on the named table with start and limit, """
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
Args: Args:
table (str): the table name table (str): the table name
keyvalues (dict[str, Any] | None): keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return retcols (iterable[str]): the names of the columns to return
order (str): order the select by this column order_direction (str): Whether the results should be ordered "ASC" or "DESC".
start (int): start number to begin the query from
limit (int): number of rows to reterive
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
@ -1421,15 +1431,27 @@ class SQLBaseStore(object):
self._simple_select_list_paginate_txn, self._simple_select_list_paginate_txn,
table, table,
keyvalues, keyvalues,
pagevalues, orderby,
start,
limit,
retcols, retcols,
order_direction=order_direction,
) )
@classmethod @classmethod
def _simple_select_list_paginate_txn( def _simple_select_list_paginate_txn(
cls, txn, table, keyvalues, pagevalues, retcols cls,
txn,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
): ):
"""Executes a SELECT query on the named table with start and limit, """
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
@ -1439,65 +1461,33 @@ class SQLBaseStore(object):
keyvalues (dict[str, T] | None): keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
pagevalues ([]): orderby (str): Column to order the results by.
order (str): order the select by this column start (int): Index to begin the query at.
start (int): start number to begin the query from limit (int): Number of results to return.
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return retcols (iterable[str]): the names of the columns to return
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
""" """
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
if keyvalues: if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?",
)
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else: else:
sql = "SELECT %s FROM %s ORDER BY %s" % ( where_clause = ""
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" ? ASC LIMIT ? OFFSET ?", where_clause,
orderby,
order_direction,
) )
txn.execute(sql, pagevalues) txn.execute(sql, list(keyvalues.values()) + [limit, start])
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_list_paginate(
self, table, keyvalues, pagevalues, retcols, desc="get_user_list_paginate"
):
"""Get a list of users from start row to a limit number of rows. This will
return a json object with users and total number of users in users list.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
users = yield self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
table,
keyvalues,
pagevalues,
retcols,
)
count = yield self.runInteraction(desc, self.get_user_count_txn)
retval = {"users": users, "total": count}
defer.returnValue(retval)
def get_user_count_txn(self, txn): def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list. """Get a total number of registered users in the users list.