Use prefix matching

This commit is contained in:
Erik Johnston 2017-05-31 18:07:12 +01:00
parent f5cc22bdc6
commit a757dd4863

View File

@ -21,6 +21,8 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
import re
class UserDirectoryStore(SQLBaseStore):
@ -272,17 +274,17 @@ class UserDirectoryStore(SQLBaseStore):
]
}
"""
search_query = _parse_query(self.database_engine, search_term)
if isinstance(self.database_engine, PostgresEngine):
sql = """
SELECT user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory USING (user_id)
WHERE vector @@ plainto_tsquery('english', ?)
ORDER BY ts_rank_cd(vector, plainto_tsquery('english', ?)) DESC
WHERE vector @@ to_tsquery('english', ?)
ORDER BY ts_rank_cd(vector, to_tsquery('english', ?)) DESC
LIMIT ?
"""
args = (search_term, search_term, limit + 1,)
args = (search_query, search_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
SELECT user_id, display_name, avatar_url
@ -292,7 +294,7 @@ class UserDirectoryStore(SQLBaseStore):
ORDER BY rank(matchinfo(user_directory)) DESC
LIMIT ?
"""
args = (search_term, limit + 1)
args = (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@ -307,3 +309,25 @@ class UserDirectoryStore(SQLBaseStore):
"limited": limited,
"results": results,
})
def _parse_query(database_engine, search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
We specifically add both a prefix and non prefix matching term so that
exact matches get ranked higher.
"""
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
if isinstance(database_engine, PostgresEngine):
return " & ".join("%s:* & %s" % (result, result,) for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join("%s* & %s" % (result, result,) for result in results)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")