Track IS used to bind 3PIDs

This will then be used to know which IS to default to when unbinding the
threepid.
This commit is contained in:
Erik Johnston 2019-04-01 10:16:13 +01:00
parent 215c15d049
commit 1666c0696a
3 changed files with 119 additions and 0 deletions

View File

@ -132,6 +132,14 @@ class IdentityHandler(BaseHandler):
} }
) )
logger.debug("bound threepid %r to %s", creds, mxid) logger.debug("bound threepid %r to %s", creds, mxid)
# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
id_server=id_server,
)
except CodeMessageException as e: except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT? data = json.loads(e.msg) # XXX WAT?
defer.returnValue(data) defer.returnValue(data)
@ -188,6 +196,13 @@ class IdentityHandler(BaseHandler):
content, content,
headers, headers,
) )
yield self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
id_server=id_server,
)
except HttpResponseException as e: except HttpResponseException as e:
if e.code in (400, 404, 501,): if e.code in (400, 404, 501,):
# The remote server probably doesn't support unbinding (yet) # The remote server probably doesn't support unbinding (yet)

View File

@ -328,6 +328,83 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="user_delete_threepids", desc="user_delete_threepids",
) )
def add_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
user_id (str)
medium (str)
address (str)
id_server (str)
Returns:
Deferred
"""
# We need to use an upsert, in case they user had already bound the
# threepid
return self._simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
"id_server": id_server,
},
values={},
insertion_values={},
desc="add_user_bound_threepid",
)
def remove_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
user_id (str)
medium (str)
address (str)
id_server (str)
Returns:
Deferred
"""
return self._simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
"id_server": id_server,
},
desc="remove_user_bound_threepid",
)
def get_id_servers_user_bound(self, user_id, medium, address):
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
user_id (str)
medium (str)
address (str)
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
return self._simple_select_onecol(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
retcol="id_server",
desc="get_id_servers_user_bound",
)
class RegistrationStore(RegistrationWorkerStore, class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore): background_updates.BackgroundUpdateStore):

View File

@ -0,0 +1,27 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Tracks when
CREATE TABlE user_threepid_id_server (
user_id TEXT NOT NULL,
medium TEXT NOT NULL,
address TEXT NOT NULL,
id_server TEXT NOT NULL
);
CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server(
user_id, medium, address, id_server
);