fix bug #2926 (loading all state for a given type from the DB if the state_key is None) (#2990)

Fixes a regression that had crept in where the caching layer upholds requests for loading state which is filtered by type (but not by state_key), but the DB layer itself would interpret a missing state_key as a request to filter by null state_key rather than returning all state_keys.
This commit is contained in:
Matthew Hodgson 2018-03-13 22:36:04 +00:00 committed by GitHub
parent 1a69c6d590
commit d144ed6ffb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -240,6 +240,9 @@ class StateGroupWorkerStore(SQLBaseStore):
( (
"AND type = ? AND state_key = ?", "AND type = ? AND state_key = ?",
(etype, state_key) (etype, state_key)
) if state_key is not None else (
"AND type = ?",
(etype,)
) )
for etype, state_key in types for etype, state_key in types
] ]
@ -259,10 +262,19 @@ class StateGroupWorkerStore(SQLBaseStore):
key = (typ, state_key) key = (typ, state_key)
results[group][key] = event_id results[group][key] = event_id
else: else:
where_args = []
where_clauses = []
wildcard_types = False
if types is not None: if types is not None:
where_clause = "AND (%s)" % ( for typ in types:
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)), if typ[1] is None:
) where_clauses.append("(type = ?)")
where_args.extend(typ[0])
wildcard_types = True
else:
where_clauses.append("(type = ? AND state_key = ?)")
where_args.extend([typ[0], typ[1]])
where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else: else:
where_clause = "" where_clause = ""
@ -279,7 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# after we finish deduping state, which requires this func) # after we finish deduping state, which requires this func)
args = [next_group] args = [next_group]
if types: if types:
args.extend(i for typ in types for i in typ) args.extend(where_args)
txn.execute( txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state" "SELECT type, state_key, event_id FROM state_groups_state"
@ -292,9 +304,17 @@ class StateGroupWorkerStore(SQLBaseStore):
if (typ, state_key) not in results[group] if (typ, state_key) not in results[group]
) )
# If the lengths match then we must have all the types, # If the number of entries in the (type,state_key)->event_id dict
# so no need to go walk further down the tree. # matches the number of (type,state_keys) types we were searching
if types is not None and len(results[group]) == len(types): # for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
types is not None and
not wildcard_types and
len(results[group]) == len(types)
):
break break
next_group = self._simple_select_one_onecol_txn( next_group = self._simple_select_one_onecol_txn(