Factor _AccountHandler proxy out to ModuleApi

We're going to need to use this from places that aren't password auth, so let's
move it to a proper class.
This commit is contained in:
Richard van der Hoff 2017-11-02 14:13:25 +00:00
parent b19d9e2174
commit 1189be43a2
3 changed files with 83 additions and 70 deletions

View File

@ -27,7 +27,7 @@ Password auth provider classes must provide the following methods:
*class* ``SomeProvider``\(*config*, *account_handler*) *class* ``SomeProvider``\(*config*, *account_handler*)
The constructor is passed the config object returned by ``parse_config``, The constructor is passed the config object returned by ``parse_config``,
and a ``synapse.handlers.auth._AccountHandler`` object which allows the and a ``synapse.module_api.ModuleApi`` object which allows the
password provider to check if accounts exist and/or create new ones. password provider to check if accounts exist and/or create new ones.
Optional methods Optional methods

View File

@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -63,10 +63,7 @@ class AuthHandler(BaseHandler):
reset_expiry_on_get=True, reset_expiry_on_get=True,
) )
account_handler = _AccountHandler( account_handler = ModuleApi(hs, self)
hs, check_user_exists=self.check_user_exists
)
self.password_providers = [ self.password_providers = [
module(config=config, account_handler=account_handler) module(config=config, account_handler=account_handler)
for module, config in hs.config.password_providers for module, config in hs.config.password_providers
@ -843,66 +840,3 @@ class MacaroonGeneartor(object):
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, check_user_exists):
self.hs = hs
self._check_user_exists = check_user_exists
self._store = hs.get_datastore()
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns:
str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection
Args:
desc (str): description for the transaction, for metrics etc
func (func): function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
Returns:
Deferred[object]: result of func
"""
return self._store.runInteraction(desc, func, *args, **kwargs)

View File

@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import UserID
class ModuleApi(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, auth_handler):
self.hs = hs
self._store = hs.get_datastore()
self._auth_handler = auth_handler
def get_qualified_user_id(self, username):
"""Qualify a user id, if necessary
Takes a user id provided by the user and adds the @ and :domain to
qualify it, if necessary
Args:
username (str): provided user id
Returns:
str: qualified @user:id
"""
if username.startswith('@'):
return username
return UserID(username, self.hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
Args:
user_id (str): Complete @user:id
Returns:
Deferred[str|None]: Canonical (case-corrected) user_id, or None
if the user is not registered.
"""
return self._auth_handler.check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)
def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection
Args:
desc (str): description for the transaction, for metrics etc
func (func): function to be run. Passed a database cursor object
as well as *args and **kwargs
*args: positional args to be passed to func
**kwargs: named args to be passed to func
Returns:
Deferred[object]: result of func
"""
return self._store.runInteraction(desc, func, *args, **kwargs)