Fix unit tests

This commit is contained in:
Mark Haines 2016-09-12 10:46:02 +01:00
parent 3ddec016ff
commit ec609f8094
6 changed files with 30 additions and 15 deletions

View File

@ -1180,7 +1180,7 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
auth_headers = request.requestHeaders.getRawHeaders("Authorization") auth_headers = request.requestHeaders.getRawHeaders("Authorization")
query_params = request.args.get("access_token") query_params = request.args.get("access_token")
if auth_headers is not None: if auth_headers:
# Try the get the access_token from a "Authorization: Bearer" # Try the get the access_token from a "Authorization: Bearer"
# header # header
if query_params is not None: if query_params is not None:

View File

@ -20,7 +20,7 @@ from mock import Mock
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.types import UserID from synapse.types import UserID
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver, mock_getRawHeaders
import pymacaroons import pymacaroons
@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user) self.assertEquals(requester.user.to_string(), self.test_user)
@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id) self.assertEquals(requester.user.to_string(), masquerading_user_id)
@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)

View File

@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
"user_id": self.u_onion.to_string(), "user_id": self.u_onion.to_string(),
"typing": True, "typing": True,
} }
) ),
federation_auth=True,
) )
self.on_new_event.assert_has_calls([ self.on_new_event.assert_has_calls([

View File

@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders
import json import json
@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase):
path='/_matrix/client/api/v1/createUser' path='/_matrix/client/api/v1/createUser'
) )
self.request.args = {} self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(get_appservice_by_req=Mock(

View File

@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
from tests.utils import mock_getRawHeaders
import json import json
@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
path='/_matrix/api/v2_alpha/register' path='/_matrix/api/v2_alpha/register'
) )
self.request.args = {} self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock( self.auth = Mock(get_appservice_by_req=Mock(

View File

@ -115,6 +115,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs) return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
def getRawHeaders(name, default=None):
return headers.get(name, default)
return getRawHeaders
# This is a mock /resource/ not an entire server # This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer): class MockHttpResource(HttpServer):
@ -127,7 +136,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request') @patch('twisted.web.http.Request')
@defer.inlineCallbacks @defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request): def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event. """ Fire an HTTP event.
Args: Args:
@ -155,9 +164,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-" mock_request.getClientIP.return_value = "-"
mock_request.requestHeaders.getRawHeaders.return_value = [ headers = {}
"X-Matrix origin=test,key=,sig=" if federation_auth:
] headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it # return the right path if the event requires it
mock_request.path = path mock_request.path = path