Add get_raw method to SimpleHttpClient, use this in CAS auth rather than requests

This commit is contained in:
Steven Hammerton 2015-10-09 11:02:56 +01:00
parent 22112f8d14
commit 625e13bfde
2 changed files with 47 additions and 27 deletions

View File

@ -160,27 +160,8 @@ class SimpleHttpClient(object):
On a non-2xx HTTP response. The response body will be used as the On a non-2xx HTTP response. The response body will be used as the
error message. error message.
""" """
if len(args): body = yield self.get_raw(uri, args)
query_bytes = urllib.urlencode(args, True) defer.returnValue(json.loads(body))
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}): def put_json(self, uri, json_body, args={}):
@ -209,7 +190,7 @@ class SimpleHttpClient(object):
"PUT", "PUT",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"User-Agent": [self.user_agent], b"User-Agent": [self.version_string],
"Content-Type": ["application/json"] "Content-Type": ["application/json"]
}), }),
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
@ -225,6 +206,42 @@ class SimpleHttpClient(object):
# load it into JSON if they want. # load it into JSON if they want.
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def get_raw(self, uri, args={}):
""" Gets raw text from the given URI.
Args:
uri (str): The URI to request, not including query parameters
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text.
Raises:
On a non-2xx HTTP response. The response body will be used as the
error message.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.version_string],
})
)
body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300:
defer.returnValue(body)
else:
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
""" """

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
@ -28,7 +29,6 @@ from saml2 import config
from saml2.client import Saml2Client from saml2.client import Saml2Client
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
import requests
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -79,13 +79,16 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
elif self.cas_enabled and (login_submission["type"] == elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE): LoginRestServlet.CAS_TYPE):
url = "%s/proxyValidate" % (self.cas_server_url) # TODO: get this from the homeserver rather than creating a new one for
parameters = { # each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"], "ticket": login_submission["ticket"],
"service": login_submission["service"] "service": login_submission["service"]
} }
response = requests.get(url, verify=False, params=parameters) body = yield http_client.get_raw(uri, args)
result = yield self.do_cas_login(response.text) result = yield self.do_cas_login(body)
defer.returnValue(result) defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")