mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-09-18 21:24:36 -04:00
Reference Matrix Home Server
This commit is contained in:
commit
4f475c7697
217 changed files with 48447 additions and 0 deletions
16
synapse/__init__.py
Normal file
16
synapse/__init__.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This is a reference implementation of a synapse home server.
|
||||
"""
|
14
synapse/api/__init__.py
Normal file
14
synapse/api/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
164
synapse/api/auth.py
Normal file
164
synapse/api/auth.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module contains classes for authenticating the user."""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import AuthError, StoreError
|
||||
from synapse.api.events.room import (RoomTopicEvent, RoomMemberEvent,
|
||||
MessageEvent, FeedbackEvent)
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Auth(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check(self, event, raises=False):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Returns:
|
||||
True if the auth checks pass.
|
||||
Raises:
|
||||
AuthError if there was a problem authorising this event. This will
|
||||
be raised only if raises=True.
|
||||
"""
|
||||
try:
|
||||
if event.type in [RoomTopicEvent.TYPE, MessageEvent.TYPE,
|
||||
FeedbackEvent.TYPE]:
|
||||
yield self.check_joined_room(event.room_id, event.user_id)
|
||||
defer.returnValue(True)
|
||||
elif event.type == RoomMemberEvent.TYPE:
|
||||
allowed = yield self.is_membership_change_allowed(event)
|
||||
defer.returnValue(allowed)
|
||||
else:
|
||||
raise AuthError(500, "Unknown event type %s" % event.type)
|
||||
except AuthError as e:
|
||||
logger.info("Event auth check failed on event %s with msg: %s",
|
||||
event, e.msg)
|
||||
if raises:
|
||||
raise e
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_joined_room(self, room_id, user_id):
|
||||
try:
|
||||
member = yield self.store.get_room_member(
|
||||
room_id=room_id,
|
||||
user_id=user_id
|
||||
)
|
||||
if not member or member.membership != Membership.JOIN:
|
||||
raise AuthError(403, "User %s not in room %s" %
|
||||
(user_id, room_id))
|
||||
defer.returnValue(member)
|
||||
except AttributeError:
|
||||
pass
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_membership_change_allowed(self, event):
|
||||
# does this room even exist
|
||||
room = yield self.store.get_room(event.room_id)
|
||||
if not room:
|
||||
raise AuthError(403, "Room does not exist")
|
||||
|
||||
# get info about the caller
|
||||
try:
|
||||
caller = yield self.store.get_room_member(
|
||||
user_id=event.user_id,
|
||||
room_id=event.room_id)
|
||||
except:
|
||||
caller = None
|
||||
caller_in_room = caller and caller.membership == "join"
|
||||
|
||||
# get info about the target
|
||||
try:
|
||||
target = yield self.store.get_room_member(
|
||||
user_id=event.target_user_id,
|
||||
room_id=event.room_id)
|
||||
except:
|
||||
target = None
|
||||
target_in_room = target and target.membership == "join"
|
||||
|
||||
membership = event.content["membership"]
|
||||
|
||||
if Membership.INVITE == membership:
|
||||
# Invites are valid iff caller is in the room and target isn't.
|
||||
if not caller_in_room: # caller isn't joined
|
||||
raise AuthError(403, "You are not in room %s." % event.room_id)
|
||||
elif target_in_room: # the target is already in the room.
|
||||
raise AuthError(403, "%s is already in the room." %
|
||||
event.target_user_id)
|
||||
elif Membership.JOIN == membership:
|
||||
# Joins are valid iff caller == target and they were:
|
||||
# invited: They are accepting the invitation
|
||||
# joined: It's a NOOP
|
||||
if event.user_id != event.target_user_id:
|
||||
raise AuthError(403, "Cannot force another user to join.")
|
||||
elif room.is_public:
|
||||
pass # anyone can join public rooms.
|
||||
elif (not caller or caller.membership not in
|
||||
[Membership.INVITE, Membership.JOIN]):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
elif Membership.LEAVE == membership:
|
||||
if not caller_in_room: # trying to leave a room you aren't joined
|
||||
raise AuthError(403, "You are not in room %s." % event.room_id)
|
||||
elif event.target_user_id != event.user_id:
|
||||
# trying to force another user to leave
|
||||
raise AuthError(403, "Cannot force %s to leave." %
|
||||
event.target_user_id)
|
||||
else:
|
||||
raise AuthError(500, "Unknown membership %s" % membership)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
def get_user_by_req(self, request):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
request - An HTTP request with an access_token query parameter.
|
||||
Returns:
|
||||
UserID : User ID object of the user making the request
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
# Can optionally look elsewhere in the request (e.g. headers)
|
||||
try:
|
||||
return self.get_user_by_token(request.args["access_token"][0])
|
||||
except KeyError:
|
||||
raise AuthError(403, "Missing access token.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_token(self, token):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
token (str)- The access token to get the user by.
|
||||
Returns:
|
||||
UserID : User ID object of the user who has that access token.
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
user_id = yield self.store.get_user_by_token(token=token)
|
||||
defer.returnValue(self.hs.parse_userid(user_id))
|
||||
except StoreError:
|
||||
raise AuthError(403, "Unrecognised access token.")
|
42
synapse/api/constants.py
Normal file
42
synapse/api/constants.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains constants from the specification."""
|
||||
|
||||
|
||||
class Membership(object):
|
||||
|
||||
"""Represents the membership states of a user in a room."""
|
||||
INVITE = u"invite"
|
||||
JOIN = u"join"
|
||||
KNOCK = u"knock"
|
||||
LEAVE = u"leave"
|
||||
|
||||
|
||||
class Feedback(object):
|
||||
|
||||
"""Represents the types of feedback a user can send in response to a
|
||||
message."""
|
||||
|
||||
DELIVERED = u"d"
|
||||
READ = u"r"
|
||||
LIST = (DELIVERED, READ)
|
||||
|
||||
|
||||
class PresenceState(object):
|
||||
"""Represents the presence state of a user."""
|
||||
OFFLINE = 0
|
||||
BUSY = 1
|
||||
ONLINE = 2
|
||||
FREE_FOR_CHAT = 3
|
114
synapse/api/errors.py
Normal file
114
synapse/api/errors.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains exceptions and error codes."""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class Codes(object):
|
||||
FORBIDDEN = "M_FORBIDDEN"
|
||||
BAD_JSON = "M_BAD_JSON"
|
||||
NOT_JSON = "M_NOT_JSON"
|
||||
USER_IN_USE = "M_USER_IN_USE"
|
||||
ROOM_IN_USE = "M_ROOM_IN_USE"
|
||||
BAD_PAGINATION = "M_BAD_PAGINATION"
|
||||
UNKNOWN = "M_UNKNOWN"
|
||||
NOT_FOUND = "M_NOT_FOUND"
|
||||
|
||||
|
||||
class CodeMessageException(Exception):
|
||||
"""An exception with integer code and message string attributes."""
|
||||
|
||||
def __init__(self, code, msg):
|
||||
logging.error("%s: %s, %s", type(self).__name__, code, msg)
|
||||
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
|
||||
self.code = code
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class SynapseError(CodeMessageException):
|
||||
"""A base error which can be caught for all synapse events."""
|
||||
def __init__(self, code, msg, errcode=""):
|
||||
"""Constructs a synapse error.
|
||||
|
||||
Args:
|
||||
code (int): The integer error code (typically an HTTP response code)
|
||||
msg (str): The human-readable error message.
|
||||
err (str): The error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
super(SynapseError, self).__init__(code, msg)
|
||||
self.errcode = errcode
|
||||
|
||||
|
||||
class RoomError(SynapseError):
|
||||
"""An error raised when a room event fails."""
|
||||
pass
|
||||
|
||||
|
||||
class RegistrationError(SynapseError):
|
||||
"""An error raised when a registration event fails."""
|
||||
pass
|
||||
|
||||
|
||||
class AuthError(SynapseError):
|
||||
"""An error raised when there was a problem authorising an event."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "errcode" not in kwargs:
|
||||
kwargs["errcode"] = Codes.FORBIDDEN
|
||||
super(AuthError, self).__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class EventStreamError(SynapseError):
|
||||
"""An error raised when there a problem with the event stream."""
|
||||
pass
|
||||
|
||||
|
||||
class LoginError(SynapseError):
|
||||
"""An error raised when there was a problem logging in."""
|
||||
pass
|
||||
|
||||
|
||||
class StoreError(SynapseError):
|
||||
"""An error raised when there was a problem storing some data."""
|
||||
pass
|
||||
|
||||
|
||||
def cs_exception(exception):
|
||||
if isinstance(exception, SynapseError):
|
||||
return cs_error(
|
||||
exception.msg,
|
||||
Codes.UNKNOWN if not exception.errcode else exception.errcode)
|
||||
elif isinstance(exception, CodeMessageException):
|
||||
return cs_error(exception.msg)
|
||||
else:
|
||||
logging.error("Unknown exception type: %s", type(exception))
|
||||
|
||||
|
||||
def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
|
||||
""" Utility method for constructing an error response for client-server
|
||||
interactions.
|
||||
|
||||
Args:
|
||||
msg (str): The error message.
|
||||
code (int): The error code.
|
||||
kwargs : Additional keys to add to the response.
|
||||
Returns:
|
||||
A dict representing the error response JSON.
|
||||
"""
|
||||
err = {"error": msg, "errcode": code}
|
||||
for key, value in kwargs.iteritems():
|
||||
err[key] = value
|
||||
return err
|
152
synapse/api/events/__init__.py
Normal file
152
synapse/api/events/__init__.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.util.jsonobject import JsonEncodedObject
|
||||
|
||||
|
||||
class SynapseEvent(JsonEncodedObject):
|
||||
|
||||
"""Base class for Synapse events. These are JSON objects which must abide
|
||||
by a certain well-defined structure.
|
||||
"""
|
||||
|
||||
# Attributes that are currently assumed by the federation side:
|
||||
# Mandatory:
|
||||
# - event_id
|
||||
# - room_id
|
||||
# - type
|
||||
# - is_state
|
||||
#
|
||||
# Optional:
|
||||
# - state_key (mandatory when is_state is True)
|
||||
# - prev_events (these can be filled out by the federation layer itself.)
|
||||
# - prev_state
|
||||
|
||||
valid_keys = [
|
||||
"event_id",
|
||||
"type",
|
||||
"room_id",
|
||||
"user_id", # sender/initiator
|
||||
"content", # HTTP body, JSON
|
||||
]
|
||||
|
||||
internal_keys = [
|
||||
"is_state",
|
||||
"state_key",
|
||||
"prev_events",
|
||||
"prev_state",
|
||||
"depth",
|
||||
"destinations",
|
||||
"origin",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
"event_id",
|
||||
"room_id",
|
||||
"content",
|
||||
]
|
||||
|
||||
def __init__(self, raises=True, **kwargs):
|
||||
super(SynapseEvent, self).__init__(**kwargs)
|
||||
if "content" in kwargs:
|
||||
self.check_json(self.content, raises=raises)
|
||||
|
||||
def get_content_template(self):
|
||||
""" Retrieve the JSON template for this event as a dict.
|
||||
|
||||
The template must be a dict representing the JSON to match. Only
|
||||
required keys should be present. The values of the keys in the template
|
||||
are checked via type() to the values of the same keys in the actual
|
||||
event JSON.
|
||||
|
||||
NB: If loading content via json.loads, you MUST define strings as
|
||||
unicode.
|
||||
|
||||
For example:
|
||||
Content:
|
||||
{
|
||||
"name": u"bob",
|
||||
"age": 18,
|
||||
"friends": [u"mike", u"jill"]
|
||||
}
|
||||
Template:
|
||||
{
|
||||
"name": u"string",
|
||||
"age": 0,
|
||||
"friends": [u"string"]
|
||||
}
|
||||
The values "string" and 0 could be anything, so long as the types
|
||||
are the same as the content.
|
||||
"""
|
||||
raise NotImplementedError("get_content_template not implemented.")
|
||||
|
||||
def check_json(self, content, raises=True):
|
||||
"""Checks the given JSON content abides by the rules of the template.
|
||||
|
||||
Args:
|
||||
content : A JSON object to check.
|
||||
raises: True to raise a SynapseError if the check fails.
|
||||
Returns:
|
||||
True if the content passes the template. Returns False if the check
|
||||
fails and raises=False.
|
||||
Raises:
|
||||
SynapseError if the check fails and raises=True.
|
||||
"""
|
||||
# recursively call to inspect each layer
|
||||
err_msg = self._check_json(content, self.get_content_template())
|
||||
if err_msg:
|
||||
if raises:
|
||||
raise SynapseError(400, err_msg, Codes.BAD_JSON)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def _check_json(self, content, template):
|
||||
"""Check content and template matches.
|
||||
|
||||
If the template is a dict, each key in the dict will be validated with
|
||||
the content, else it will just compare the types of content and
|
||||
template. This basic type check is required because this function will
|
||||
be recursively called and could be called with just strs or ints.
|
||||
|
||||
Args:
|
||||
content: The content to validate.
|
||||
template: The validation template.
|
||||
Returns:
|
||||
str: An error message if the validation fails, else None.
|
||||
"""
|
||||
if type(content) != type(template):
|
||||
return "Mismatched types: %s" % template
|
||||
|
||||
if type(template) == dict:
|
||||
for key in template:
|
||||
if key not in content:
|
||||
return "Missing %s key" % key
|
||||
|
||||
if type(content[key]) != type(template[key]):
|
||||
return "Key %s is of the wrong type." % key
|
||||
|
||||
if type(content[key]) == dict:
|
||||
# we must go deeper
|
||||
msg = self._check_json(content[key], template[key])
|
||||
if msg:
|
||||
return msg
|
||||
elif type(content[key]) == list:
|
||||
# make sure each item type in content matches the template
|
||||
for entry in content[key]:
|
||||
msg = self._check_json(entry, template[key][0])
|
||||
if msg:
|
||||
return msg
|
50
synapse/api/events/factory.py
Normal file
50
synapse/api/events/factory.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.events.room import (
|
||||
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
|
||||
InviteJoinEvent, RoomConfigEvent
|
||||
)
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
||||
class EventFactory(object):
|
||||
|
||||
_event_classes = [
|
||||
RoomTopicEvent,
|
||||
MessageEvent,
|
||||
RoomMemberEvent,
|
||||
FeedbackEvent,
|
||||
InviteJoinEvent,
|
||||
RoomConfigEvent
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
self._event_list = {} # dict of TYPE to event class
|
||||
for event_class in EventFactory._event_classes:
|
||||
self._event_list[event_class.TYPE] = event_class
|
||||
|
||||
def create_event(self, etype=None, **kwargs):
|
||||
kwargs["type"] = etype
|
||||
if "event_id" not in kwargs:
|
||||
kwargs["event_id"] = random_string(10)
|
||||
|
||||
try:
|
||||
handler = self._event_list[etype]
|
||||
except KeyError: # unknown event type
|
||||
# TODO allow custom event types.
|
||||
raise NotImplementedError("Unknown etype=%s" % etype)
|
||||
|
||||
return handler(**kwargs)
|
99
synapse/api/events/room.py
Normal file
99
synapse/api/events/room.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from . import SynapseEvent
|
||||
|
||||
|
||||
class RoomTopicEvent(SynapseEvent):
|
||||
TYPE = "m.room.topic"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["state_key"] = ""
|
||||
super(RoomTopicEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"topic": u"string"}
|
||||
|
||||
|
||||
class RoomMemberEvent(SynapseEvent):
|
||||
TYPE = "m.room.member"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
"target_user_id", # target
|
||||
"membership", # action
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "target_user_id" in kwargs:
|
||||
kwargs["state_key"] = kwargs["target_user_id"]
|
||||
super(RoomMemberEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"membership": u"string"}
|
||||
|
||||
|
||||
class MessageEvent(SynapseEvent):
|
||||
TYPE = "m.room.message"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
"msg_id", # unique per room + user combo
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MessageEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"msgtype": u"string"}
|
||||
|
||||
|
||||
class FeedbackEvent(SynapseEvent):
|
||||
TYPE = "m.room.message.feedback"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
"msg_id", # the message ID being acknowledged
|
||||
"msg_sender_id", # person who is sending the feedback is 'user_id'
|
||||
"feedback_type", # the type of feedback (delivery, read, etc)
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(FeedbackEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class InviteJoinEvent(SynapseEvent):
|
||||
TYPE = "m.room.invite_join"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
"target_user_id",
|
||||
"target_host",
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(InviteJoinEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomConfigEvent(SynapseEvent):
|
||||
TYPE = "m.room.config"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["state_key"] = ""
|
||||
super(RoomConfigEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
186
synapse/api/notifier.py
Normal file
186
synapse/api/notifier.py
Normal file
|
@ -0,0 +1,186 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.events.room import RoomMemberEvent
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet import reactor
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Notifier(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.stored_event_listeners = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_room_event(self, event, store_id):
|
||||
"""Called when there is a new room event which may potentially be sent
|
||||
down listening users' event streams.
|
||||
|
||||
This function looks for interested *users* who may want to be notified
|
||||
for this event. This is different to users requesting from the event
|
||||
stream which looks for interested *events* for this user.
|
||||
|
||||
Args:
|
||||
event (SynapseEvent): The new event, which must have a room_id
|
||||
store_id (int): The ID of this event after it was stored with the
|
||||
data store.
|
||||
'"""
|
||||
member_list = yield self.store.get_room_members(room_id=event.room_id,
|
||||
membership="join")
|
||||
if not member_list:
|
||||
member_list = []
|
||||
|
||||
member_list = [u.user_id for u in member_list]
|
||||
|
||||
# invites MUST prod the person being invited, who won't be in the room.
|
||||
if (event.type == RoomMemberEvent.TYPE and
|
||||
event.content["membership"] == Membership.INVITE):
|
||||
member_list.append(event.target_user_id)
|
||||
|
||||
for user_id in member_list:
|
||||
if user_id in self.stored_event_listeners:
|
||||
self._notify_and_callback(
|
||||
user_id=user_id,
|
||||
event_data=event.get_dict(),
|
||||
stream_type=event.type,
|
||||
store_id=store_id)
|
||||
|
||||
def on_new_user_event(self, user_id, event_data, stream_type, store_id):
|
||||
if user_id in self.stored_event_listeners:
|
||||
self._notify_and_callback(
|
||||
user_id=user_id,
|
||||
event_data=event_data,
|
||||
stream_type=stream_type,
|
||||
store_id=store_id
|
||||
)
|
||||
|
||||
def _notify_and_callback(self, user_id, event_data, stream_type, store_id):
|
||||
logger.debug(
|
||||
"Notifying %s of a new event.",
|
||||
user_id
|
||||
)
|
||||
|
||||
stream_ids = list(self.stored_event_listeners[user_id])
|
||||
for stream_id in stream_ids:
|
||||
self._notify_and_callback_stream(user_id, stream_id, event_data,
|
||||
stream_type, store_id)
|
||||
|
||||
if not self.stored_event_listeners[user_id]:
|
||||
del self.stored_event_listeners[user_id]
|
||||
|
||||
def _notify_and_callback_stream(self, user_id, stream_id, event_data,
|
||||
stream_type, store_id):
|
||||
|
||||
event_listener = self.stored_event_listeners[user_id].pop(stream_id)
|
||||
return_event_object = {
|
||||
k: event_listener[k] for k in ["start", "chunk", "end"]
|
||||
}
|
||||
|
||||
# work out the new end token
|
||||
token = event_listener["start"]
|
||||
end = self._next_token(stream_type, store_id, token)
|
||||
return_event_object["end"] = end
|
||||
|
||||
# add the event to the chunk
|
||||
chunk = event_listener["chunk"]
|
||||
chunk.append(event_data)
|
||||
|
||||
# callback the defer. We know this can't have been resolved before as
|
||||
# we always remove the event_listener from the map before resolving.
|
||||
event_listener["defer"].callback(return_event_object)
|
||||
|
||||
def _next_token(self, stream_type, store_id, current_token):
|
||||
stream_handler = self.hs.get_handlers().event_stream_handler
|
||||
return stream_handler.get_event_stream_token(
|
||||
stream_type,
|
||||
store_id,
|
||||
current_token
|
||||
)
|
||||
|
||||
def store_events_for(self, user_id=None, stream_id=None, from_tok=None):
|
||||
"""Store all incoming events for this user. This should be paired with
|
||||
get_events_for to return chunked data.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to monitor incoming events for.
|
||||
stream (object): The stream that is receiving events
|
||||
from_tok (str): The token to monitor incoming events from.
|
||||
"""
|
||||
event_listener = {
|
||||
"start": from_tok,
|
||||
"chunk": [],
|
||||
"end": from_tok,
|
||||
"defer": defer.Deferred(),
|
||||
}
|
||||
|
||||
if user_id not in self.stored_event_listeners:
|
||||
self.stored_event_listeners[user_id] = {stream_id: event_listener}
|
||||
else:
|
||||
self.stored_event_listeners[user_id][stream_id] = event_listener
|
||||
|
||||
def purge_events_for(self, user_id=None, stream_id=None):
|
||||
"""Purges any stored events for this user.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to purge stored events for.
|
||||
"""
|
||||
try:
|
||||
del self.stored_event_listeners[user_id][stream_id]
|
||||
if not self.stored_event_listeners[user_id]:
|
||||
del self.stored_event_listeners[user_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def get_events_for(self, user_id=None, stream_id=None, timeout=0):
|
||||
"""Retrieve stored events for this user, waiting if necessary.
|
||||
|
||||
It is advisable to wrap this call in a maybeDeferred.
|
||||
|
||||
Args:
|
||||
user_id (str): The user to get events for.
|
||||
timeout (int): The time in seconds to wait before giving up.
|
||||
Returns:
|
||||
A Deferred or a dict containing the chunk data, depending on if
|
||||
there was data to return yet. The Deferred callback may be None if
|
||||
there were no events before the timeout expired.
|
||||
"""
|
||||
logger.debug("%s is listening for events.", user_id)
|
||||
|
||||
if len(self.stored_event_listeners[user_id][stream_id]["chunk"]) > 0:
|
||||
logger.debug("%s returning existing chunk.", user_id)
|
||||
return self.stored_event_listeners[user_id][stream_id]
|
||||
|
||||
reactor.callLater(
|
||||
(timeout / 1000.0), self._timeout, user_id, stream_id
|
||||
)
|
||||
return self.stored_event_listeners[user_id][stream_id]["defer"]
|
||||
|
||||
def _timeout(self, user_id, stream_id):
|
||||
try:
|
||||
# We remove the event_listener from the map so that we can't
|
||||
# resolve the deferred twice.
|
||||
event_listeners = self.stored_event_listeners[user_id]
|
||||
event_listener = event_listeners.pop(stream_id)
|
||||
event_listener["defer"].callback(None)
|
||||
logger.debug("%s event listening timed out.", user_id)
|
||||
except KeyError:
|
||||
pass
|
96
synapse/api/streams/__init__.py
Normal file
96
synapse/api/streams/__init__.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
|
||||
class PaginationConfig(object):
|
||||
|
||||
"""A configuration object which stores pagination parameters."""
|
||||
|
||||
def __init__(self, from_tok=None, to_tok=None, limit=0):
|
||||
self.from_tok = from_tok
|
||||
self.to_tok = to_tok
|
||||
self.limit = limit
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request, raise_invalid_params=True):
|
||||
params = {
|
||||
"from_tok": PaginationStream.TOK_START,
|
||||
"to_tok": PaginationStream.TOK_END,
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
query_param_mappings = [ # 3-tuple of qp_key, attribute, rules
|
||||
("from", "from_tok", lambda x: type(x) == str),
|
||||
("to", "to_tok", lambda x: type(x) == str),
|
||||
("limit", "limit", lambda x: x.isdigit())
|
||||
]
|
||||
|
||||
for qp, attr, is_valid in query_param_mappings:
|
||||
if qp in request.args:
|
||||
if is_valid(request.args[qp][0]):
|
||||
params[attr] = request.args[qp][0]
|
||||
elif raise_invalid_params:
|
||||
raise SynapseError(400, "%s parameter is invalid." % qp)
|
||||
|
||||
return PaginationConfig(**params)
|
||||
|
||||
|
||||
class PaginationStream(object):
|
||||
|
||||
""" An interface for streaming data as chunks. """
|
||||
|
||||
TOK_START = "START"
|
||||
TOK_END = "END"
|
||||
|
||||
def get_chunk(self, config=None):
|
||||
""" Return the next chunk in the stream.
|
||||
|
||||
Args:
|
||||
config (PaginationConfig): The config to aid which chunk to get.
|
||||
Returns:
|
||||
A dict containing the new start token "start", the new end token
|
||||
"end" and the data "chunk" as a list.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamData(object):
|
||||
|
||||
""" An interface for obtaining streaming data from a table. """
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
def get_rows(self, user_id, from_pkey, to_pkey, limit):
|
||||
""" Get event stream data between the specified pkeys.
|
||||
|
||||
Args:
|
||||
user_id : The user's ID
|
||||
from_pkey : The starting pkey.
|
||||
to_pkey : The end pkey. May be -1 to mean "latest".
|
||||
limit: The max number of results to return.
|
||||
Returns:
|
||||
A tuple containing the list of event stream data and the last pkey.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def max_token(self):
|
||||
""" Get the latest currently-valid token.
|
||||
|
||||
Returns:
|
||||
The latest token."""
|
||||
raise NotImplementedError()
|
247
synapse/api/streams/event.py
Normal file
247
synapse/api/streams/event.py
Normal file
|
@ -0,0 +1,247 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module contains classes for streaming from the event stream: /events.
|
||||
"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import EventStreamError
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, MessageEvent, FeedbackEvent, RoomTopicEvent
|
||||
)
|
||||
from synapse.api.streams import PaginationStream, StreamData
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessagesStreamData(StreamData):
|
||||
EVENT_TYPE = MessageEvent.TYPE
|
||||
|
||||
def __init__(self, hs, room_id=None, feedback=False):
|
||||
super(MessagesStreamData, self).__init__(hs)
|
||||
self.room_id = room_id
|
||||
self.with_feedback = feedback
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rows(self, user_id, from_key, to_key, limit):
|
||||
(data, latest_ver) = yield self.store.get_message_stream(
|
||||
user_id=user_id,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
limit=limit,
|
||||
room_id=self.room_id,
|
||||
with_feedback=self.with_feedback
|
||||
)
|
||||
defer.returnValue((data, latest_ver))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def max_token(self):
|
||||
val = yield self.store.get_max_message_id()
|
||||
defer.returnValue(val)
|
||||
|
||||
|
||||
class RoomMemberStreamData(StreamData):
|
||||
EVENT_TYPE = RoomMemberEvent.TYPE
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rows(self, user_id, from_key, to_key, limit):
|
||||
(data, latest_ver) = yield self.store.get_room_member_stream(
|
||||
user_id=user_id,
|
||||
from_key=from_key,
|
||||
to_key=to_key
|
||||
)
|
||||
|
||||
defer.returnValue((data, latest_ver))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def max_token(self):
|
||||
val = yield self.store.get_max_room_member_id()
|
||||
defer.returnValue(val)
|
||||
|
||||
|
||||
class FeedbackStreamData(StreamData):
|
||||
EVENT_TYPE = FeedbackEvent.TYPE
|
||||
|
||||
def __init__(self, hs, room_id=None):
|
||||
super(FeedbackStreamData, self).__init__(hs)
|
||||
self.room_id = room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rows(self, user_id, from_key, to_key, limit):
|
||||
(data, latest_ver) = yield self.store.get_feedback_stream(
|
||||
user_id=user_id,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
limit=limit,
|
||||
room_id=self.room_id
|
||||
)
|
||||
defer.returnValue((data, latest_ver))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def max_token(self):
|
||||
val = yield self.store.get_max_feedback_id()
|
||||
defer.returnValue(val)
|
||||
|
||||
|
||||
class RoomDataStreamData(StreamData):
|
||||
EVENT_TYPE = RoomTopicEvent.TYPE # TODO need multiple event types
|
||||
|
||||
def __init__(self, hs, room_id=None):
|
||||
super(RoomDataStreamData, self).__init__(hs)
|
||||
self.room_id = room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rows(self, user_id, from_key, to_key, limit):
|
||||
(data, latest_ver) = yield self.store.get_room_data_stream(
|
||||
user_id=user_id,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
limit=limit,
|
||||
room_id=self.room_id
|
||||
)
|
||||
defer.returnValue((data, latest_ver))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def max_token(self):
|
||||
val = yield self.store.get_max_room_data_id()
|
||||
defer.returnValue(val)
|
||||
|
||||
|
||||
class EventStream(PaginationStream):
|
||||
|
||||
SEPARATOR = '_'
|
||||
|
||||
def __init__(self, user_id, stream_data_list):
|
||||
super(EventStream, self).__init__()
|
||||
self.user_id = user_id
|
||||
self.stream_data = stream_data_list
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fix_tokens(self, pagination_config):
|
||||
pagination_config.from_tok = yield self.fix_token(
|
||||
pagination_config.from_tok)
|
||||
pagination_config.to_tok = yield self.fix_token(
|
||||
pagination_config.to_tok)
|
||||
defer.returnValue(pagination_config)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fix_token(self, token):
|
||||
"""Fixes unknown values in a token to known values.
|
||||
|
||||
Args:
|
||||
token (str): The token to fix up.
|
||||
Returns:
|
||||
The fixed-up token, which may == token.
|
||||
"""
|
||||
# replace TOK_START and TOK_END with 0_0_0 or -1_-1_-1 depending.
|
||||
replacements = [
|
||||
(PaginationStream.TOK_START, "0"),
|
||||
(PaginationStream.TOK_END, "-1")
|
||||
]
|
||||
for magic_token, key in replacements:
|
||||
if magic_token == token:
|
||||
token = EventStream.SEPARATOR.join(
|
||||
[key] * len(self.stream_data)
|
||||
)
|
||||
|
||||
# replace -1 values with an actual pkey
|
||||
token_segments = self._split_token(token)
|
||||
for i, tok in enumerate(token_segments):
|
||||
if tok == -1:
|
||||
# add 1 to the max token because results are EXCLUSIVE from the
|
||||
# latest version.
|
||||
token_segments[i] = 1 + (yield self.stream_data[i].max_token())
|
||||
defer.returnValue(EventStream.SEPARATOR.join(
|
||||
str(x) for x in token_segments
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_chunk(self, config=None):
|
||||
# no support for limit on >1 streams, makes no sense.
|
||||
if config.limit and len(self.stream_data) > 1:
|
||||
raise EventStreamError(
|
||||
400, "Limit not supported on multiplexed streams."
|
||||
)
|
||||
|
||||
(chunk_data, next_tok) = yield self._get_chunk_data(config.from_tok,
|
||||
config.to_tok,
|
||||
config.limit)
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": chunk_data,
|
||||
"start": config.from_tok,
|
||||
"end": next_tok
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_chunk_data(self, from_tok, to_tok, limit):
|
||||
""" Get event data between the two tokens.
|
||||
|
||||
Tokens are SEPARATOR separated values representing pkey values of
|
||||
certain tables, and the position determines the StreamData invoked
|
||||
according to the STREAM_DATA list.
|
||||
|
||||
The magic value '-1' can be used to get the latest value.
|
||||
|
||||
Args:
|
||||
from_tok - The token to start from.
|
||||
to_tok - The token to end at. Must have values > from_tok or be -1.
|
||||
Returns:
|
||||
A list of event data.
|
||||
Raises:
|
||||
EventStreamError if something went wrong.
|
||||
"""
|
||||
# sanity check
|
||||
if (from_tok.count(EventStream.SEPARATOR) !=
|
||||
to_tok.count(EventStream.SEPARATOR) or
|
||||
(from_tok.count(EventStream.SEPARATOR) + 1) !=
|
||||
len(self.stream_data)):
|
||||
raise EventStreamError(400, "Token lengths don't match.")
|
||||
|
||||
chunk = []
|
||||
next_ver = []
|
||||
for i, (from_pkey, to_pkey) in enumerate(zip(
|
||||
self._split_token(from_tok),
|
||||
self._split_token(to_tok)
|
||||
)):
|
||||
if from_pkey == to_pkey:
|
||||
# tokens are the same, we have nothing to do.
|
||||
next_ver.append(str(to_pkey))
|
||||
continue
|
||||
|
||||
(event_chunk, max_pkey) = yield self.stream_data[i].get_rows(
|
||||
self.user_id, from_pkey, to_pkey, limit
|
||||
)
|
||||
|
||||
chunk += event_chunk
|
||||
next_ver.append(str(max_pkey))
|
||||
|
||||
defer.returnValue((chunk, EventStream.SEPARATOR.join(next_ver)))
|
||||
|
||||
def _split_token(self, token):
|
||||
"""Splits the given token into a list of pkeys.
|
||||
|
||||
Args:
|
||||
token (str): The token with SEPARATOR values.
|
||||
Returns:
|
||||
A list of ints.
|
||||
"""
|
||||
segments = token.split(EventStream.SEPARATOR)
|
||||
try:
|
||||
int_segments = [int(x) for x in segments]
|
||||
except ValueError:
|
||||
raise EventStreamError(400, "Bad token: %s" % token)
|
||||
return int_segments
|
14
synapse/app/__init__.py
Normal file
14
synapse/app/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
172
synapse/app/homeserver.py
Normal file
172
synapse/app/homeserver.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#!/usr/bin/env python
|
||||
|
||||
from synapse.storage import read_schema
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from twisted.internet import reactor
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.python.log import PythonLoggingObserver
|
||||
from synapse.http.server import TwistedHttpServer
|
||||
from synapse.http.client import TwistedHttpClient
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import logging.config
|
||||
import sqlite3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SynapseHomeServer(HomeServer):
|
||||
def build_http_server(self):
|
||||
return TwistedHttpServer()
|
||||
|
||||
def build_http_client(self):
|
||||
return TwistedHttpClient()
|
||||
|
||||
def build_db_pool(self):
|
||||
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
||||
don't have to worry about overwriting existing content.
|
||||
"""
|
||||
logging.info("Preparing database: %s...", self.db_name)
|
||||
pool = adbapi.ConnectionPool(
|
||||
'sqlite3', self.db_name, check_same_thread=False,
|
||||
cp_min=1, cp_max=1)
|
||||
|
||||
schemas = [
|
||||
"transactions",
|
||||
"pdu",
|
||||
"users",
|
||||
"profiles",
|
||||
"presence",
|
||||
"im",
|
||||
"room_aliases",
|
||||
]
|
||||
|
||||
for sql_loc in schemas:
|
||||
sql_script = read_schema(sql_loc)
|
||||
|
||||
with sqlite3.connect(self.db_name) as db_conn:
|
||||
c = db_conn.cursor()
|
||||
c.executescript(sql_script)
|
||||
c.close()
|
||||
db_conn.commit()
|
||||
|
||||
logging.info("Database prepared in %s.", self.db_name)
|
||||
|
||||
return pool
|
||||
|
||||
|
||||
def setup_logging(verbosity=0, filename=None, config_path=None):
|
||||
""" Sets up logging with verbosity levels.
|
||||
|
||||
Args:
|
||||
verbosity: The verbosity level.
|
||||
filename: Log to the given file rather than to the console.
|
||||
config_path: Path to a python logging config file.
|
||||
"""
|
||||
|
||||
if config_path is None:
|
||||
log_format = (
|
||||
'%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
if not verbosity or verbosity == 0:
|
||||
level = logging.WARNING
|
||||
elif verbosity == 1:
|
||||
level = logging.INFO
|
||||
else:
|
||||
level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(level=level, filename=filename, format=log_format)
|
||||
else:
|
||||
logging.config.fileConfig(config_path)
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
|
||||
|
||||
def run():
|
||||
reactor.run()
|
||||
|
||||
|
||||
def setup():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-p", "--port", dest="port", type=int, default=8080,
|
||||
help="The port to listen on.")
|
||||
parser.add_argument("-d", "--database", dest="db", default="homeserver.db",
|
||||
help="The database name.")
|
||||
parser.add_argument("-H", "--host", dest="host", default="localhost",
|
||||
help="The hostname of the server.")
|
||||
parser.add_argument('-v', '--verbose', dest="verbose", action='count',
|
||||
help="The verbosity level.")
|
||||
parser.add_argument('-f', '--log-file', dest="log_file", default=None,
|
||||
help="File to log to.")
|
||||
parser.add_argument('--log-config', dest="log_config", default=None,
|
||||
help="Python logging config")
|
||||
parser.add_argument('-D', '--daemonize', action='store_true',
|
||||
default=False, help="Daemonize the home server")
|
||||
parser.add_argument('--pid-file', dest="pid", help="When running as a "
|
||||
"daemon, the file to store the pid in",
|
||||
default="hs.pid")
|
||||
args = parser.parse_args()
|
||||
|
||||
verbosity = int(args.verbose) if args.verbose else None
|
||||
|
||||
setup_logging(
|
||||
verbosity=verbosity,
|
||||
filename=args.log_file,
|
||||
config_path=args.log_config,
|
||||
)
|
||||
|
||||
logger.info("Server hostname: %s", args.host)
|
||||
|
||||
hs = SynapseHomeServer(
|
||||
args.host,
|
||||
db_name=args.db
|
||||
)
|
||||
|
||||
# This object doesn't need to be saved because it's set as the handler for
|
||||
# the replication layer
|
||||
hs.get_federation()
|
||||
|
||||
hs.register_servlets()
|
||||
|
||||
hs.get_http_server().start_listening(args.port)
|
||||
|
||||
hs.build_db_pool()
|
||||
|
||||
if args.daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-homeserver",
|
||||
pid=args.pid,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
setup()
|
14
synapse/crypto/__init__.py
Normal file
14
synapse/crypto/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
159
synapse/crypto/config.py
Normal file
159
synapse/crypto/config.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ConfigParser as configparser
|
||||
import argparse
|
||||
import socket
|
||||
import sys
|
||||
import os
|
||||
from OpenSSL import crypto
|
||||
import nacl.signing
|
||||
from syutil.base64util import encode_base64
|
||||
import subprocess
|
||||
|
||||
|
||||
def load_config(description, argv):
|
||||
config_parser = argparse.ArgumentParser(add_help=False)
|
||||
config_parser.add_argument("-c", "--config-path", metavar="CONFIG_FILE",
|
||||
help="Specify config file")
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
if config_args.config_path:
|
||||
config = configparser.SafeConfigParser()
|
||||
config.read([config_args.config_path])
|
||||
defaults = dict(config.items("KeyServer"))
|
||||
else:
|
||||
defaults = {}
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[config_parser],
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.set_defaults(**defaults)
|
||||
parser.add_argument("--server-name", default=socket.getfqdn(),
|
||||
help="The name of the server")
|
||||
parser.add_argument("--signing-key-path",
|
||||
help="The signing key to sign responses with")
|
||||
parser.add_argument("--tls-certificate-path",
|
||||
help="PEM encoded X509 certificate for TLS")
|
||||
parser.add_argument("--tls-private-key-path",
|
||||
help="PEM encoded private key for TLS")
|
||||
parser.add_argument("--tls-dh-params-path",
|
||||
help="PEM encoded dh parameters for ephemeral keys")
|
||||
parser.add_argument("--bind-port", type=int,
|
||||
help="TCP port to listen on")
|
||||
parser.add_argument("--bind-host", default="",
|
||||
help="Local interface to listen on")
|
||||
|
||||
args = parser.parse_args(remaining_args)
|
||||
|
||||
server_config = vars(args)
|
||||
del server_config["config_path"]
|
||||
return server_config
|
||||
|
||||
|
||||
def generate_config(argv):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config-path", help="Specify config file",
|
||||
metavar="CONFIG_FILE", required=True)
|
||||
parser.add_argument("--server-name", default=socket.getfqdn(),
|
||||
help="The name of the server")
|
||||
parser.add_argument("--signing-key-path",
|
||||
help="The signing key to sign responses with")
|
||||
parser.add_argument("--tls-certificate-path",
|
||||
help="PEM encoded X509 certificate for TLS")
|
||||
parser.add_argument("--tls-private-key-path",
|
||||
help="PEM encoded private key for TLS")
|
||||
parser.add_argument("--tls-dh-params-path",
|
||||
help="PEM encoded dh parameters for ephemeral keys")
|
||||
parser.add_argument("--bind-port", type=int, required=True,
|
||||
help="TCP port to listen on")
|
||||
parser.add_argument("--bind-host", default="",
|
||||
help="Local interface to listen on")
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
dir_name = os.path.dirname(args.config_path)
|
||||
base_key_name = os.path.join(dir_name, args.server_name)
|
||||
|
||||
if args.signing_key_path is None:
|
||||
args.signing_key_path = base_key_name + ".signing.key"
|
||||
|
||||
if args.tls_certificate_path is None:
|
||||
args.tls_certificate_path = base_key_name + ".tls.crt"
|
||||
|
||||
if args.tls_private_key_path is None:
|
||||
args.tls_private_key_path = base_key_name + ".tls.key"
|
||||
|
||||
if args.tls_dh_params_path is None:
|
||||
args.tls_dh_params_path = base_key_name + ".tls.dh"
|
||||
|
||||
if not os.path.exists(args.signing_key_path):
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
key = nacl.signing.SigningKey.generate()
|
||||
signing_key_file.write(encode_base64(key.encode()))
|
||||
|
||||
if not os.path.exists(args.tls_private_key_path):
|
||||
with open(args.tls_private_key_path, "w") as private_key_file:
|
||||
tls_private_key = crypto.PKey()
|
||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||
private_key_pem = crypto.dump_privatekey(
|
||||
crypto.FILETYPE_PEM, tls_private_key
|
||||
)
|
||||
private_key_file.write(private_key_pem)
|
||||
else:
|
||||
with open(args.tls_private_key_path) as private_key_file:
|
||||
private_key_pem = private_key_file.read()
|
||||
tls_private_key = crypto.load_privatekey(
|
||||
crypto.FILETYPE_PEM, private_key_pem
|
||||
)
|
||||
|
||||
if not os.path.exists(args.tls_certificate_path):
|
||||
with open(args.tls_certificate_path, "w") as certifcate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
subject.CN = args.server_name
|
||||
|
||||
cert.set_serial_number(1000)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
cert.set_pubkey(tls_private_key)
|
||||
|
||||
cert.sign(tls_private_key, 'sha256')
|
||||
|
||||
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
|
||||
certifcate_file.write(cert_pem)
|
||||
|
||||
if not os.path.exists(args.tls_dh_params_path):
|
||||
subprocess.check_call([
|
||||
"openssl", "dhparam",
|
||||
"-outform", "PEM",
|
||||
"-out", args.tls_dh_params_path,
|
||||
"2048"
|
||||
])
|
||||
|
||||
config = configparser.SafeConfigParser()
|
||||
config.add_section("KeyServer")
|
||||
for key, value in vars(args).items():
|
||||
if key != "config_path":
|
||||
config.set("KeyServer", key, str(value))
|
||||
|
||||
with open(args.config_path, "w") as config_file:
|
||||
config.write(config_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_config(sys.argv[1:])
|
118
synapse/crypto/keyclient.py
Normal file
118
synapse/crypto/keyclient.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.web.http import HTTPClient
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.protocol import ClientFactory
|
||||
from twisted.names.srvconnect import SRVConnector
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_server_key(server_name, ssl_context_factory):
|
||||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
|
||||
SRVConnector(
|
||||
reactor, "matrix", server_name, factory,
|
||||
protocol="tcp", connectFuncName="connectSSL", defaultPort=443,
|
||||
connectFuncKwArgs=dict(contextFactory=ssl_context_factory)).connect()
|
||||
|
||||
server_key, server_certificate = yield factory.remote_key
|
||||
|
||||
defer.returnValue((server_key, server_certificate))
|
||||
|
||||
|
||||
class SynapseKeyClientError(Exception):
|
||||
"""The key wasn't retireved from the remote server."""
|
||||
pass
|
||||
|
||||
|
||||
class SynapseKeyClientProtocol(HTTPClient):
|
||||
"""Low level HTTPS client which retrieves an application/json response from
|
||||
the server and extracts the X.509 certificate for the remote peer from the
|
||||
SSL connection."""
|
||||
|
||||
def connectionMade(self):
|
||||
logger.debug("Connected to %s", self.transport.getHost())
|
||||
self.sendCommand(b"GET", b"/key")
|
||||
self.endHeaders()
|
||||
self.timer = reactor.callLater(
|
||||
self.factory.timeout_seconds,
|
||||
self.on_timeout
|
||||
)
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
if status != b"200":
|
||||
logger.info("Non-200 response from %s: %s %s",
|
||||
self.transport.getHost(), status, message)
|
||||
self.transport.abortConnection()
|
||||
|
||||
def handleResponse(self, response_body_bytes):
|
||||
try:
|
||||
json_response = json.loads(response_body_bytes)
|
||||
except ValueError:
|
||||
logger.info("Invalid JSON response from %s",
|
||||
self.transport.getHost())
|
||||
self.transport.abortConnection()
|
||||
return
|
||||
|
||||
certificate = self.transport.getPeerCertificate()
|
||||
self.factory.on_remote_key((json_response, certificate))
|
||||
self.transport.abortConnection()
|
||||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s",
|
||||
self.transport.getHost())
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
||||
class SynapseKeyClientFactory(ClientFactory):
|
||||
protocol = SynapseKeyClientProtocol
|
||||
max_retries = 5
|
||||
timeout_seconds = 30
|
||||
|
||||
def __init__(self):
|
||||
self.succeeded = False
|
||||
self.retries = 0
|
||||
self.remote_key = defer.Deferred()
|
||||
|
||||
def on_remote_key(self, key):
|
||||
self.succeeded = True
|
||||
self.remote_key.callback(key)
|
||||
|
||||
def retry_connection(self, connector):
|
||||
self.retries += 1
|
||||
if self.retries < self.max_retries:
|
||||
connector.connector = None
|
||||
connector.connect()
|
||||
else:
|
||||
self.remote_key.errback(
|
||||
SynapseKeyClientError("Max retries exceeded"))
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
logger.info("Connection failed %s", reason)
|
||||
self.retry_connection(connector)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
logger.info("Connection lost %s", reason)
|
||||
if not self.succeeded:
|
||||
self.retry_connection(connector)
|
110
synapse/crypto/keyserver.py
Normal file
110
synapse/crypto/keyserver.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import reactor, ssl
|
||||
from twisted.web import server
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.python.log import PythonLoggingObserver
|
||||
|
||||
from synapse.crypto.resource.key import LocalKey
|
||||
from synapse.crypto.config import load_config
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
|
||||
from OpenSSL import crypto, SSL
|
||||
|
||||
import logging
|
||||
import nacl.signing
|
||||
import sys
|
||||
|
||||
|
||||
class KeyServerSSLContextFactory(ssl.ContextFactory):
|
||||
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
|
||||
connections and to make connections to remote servers."""
|
||||
|
||||
def __init__(self, key_server):
|
||||
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
self.configure_context(self._context, key_server)
|
||||
|
||||
@staticmethod
|
||||
def configure_context(context, key_server):
|
||||
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
|
||||
context.use_certificate(key_server.tls_certificate)
|
||||
context.use_privatekey(key_server.tls_private_key)
|
||||
context.load_tmp_dh(key_server.tls_dh_params_path)
|
||||
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
|
||||
|
||||
def getContext(self):
|
||||
return self._context
|
||||
|
||||
|
||||
class KeyServer(object):
|
||||
"""An HTTPS server serving LocalKey and RemoteKey resources."""
|
||||
|
||||
def __init__(self, server_name, tls_certificate_path, tls_private_key_path,
|
||||
tls_dh_params_path, signing_key_path, bind_host, bind_port):
|
||||
self.server_name = server_name
|
||||
self.tls_certificate = self.read_tls_certificate(tls_certificate_path)
|
||||
self.tls_private_key = self.read_tls_private_key(tls_private_key_path)
|
||||
self.tls_dh_params_path = tls_dh_params_path
|
||||
self.signing_key = self.read_signing_key(signing_key_path)
|
||||
self.bind_host = bind_host
|
||||
self.bind_port = int(bind_port)
|
||||
self.ssl_context_factory = KeyServerSSLContextFactory(self)
|
||||
|
||||
@staticmethod
|
||||
def read_tls_certificate(cert_path):
|
||||
with open(cert_path) as cert_file:
|
||||
cert_pem = cert_file.read()
|
||||
return crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
|
||||
|
||||
@staticmethod
|
||||
def read_tls_private_key(private_key_path):
|
||||
with open(private_key_path) as private_key_file:
|
||||
private_key_pem = private_key_file.read()
|
||||
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
|
||||
|
||||
@staticmethod
|
||||
def read_signing_key(signing_key_path):
|
||||
with open(signing_key_path) as signing_key_file:
|
||||
signing_key_b64 = signing_key_file.read()
|
||||
signing_key_bytes = decode_base64(signing_key_b64)
|
||||
return nacl.signing.SigningKey(signing_key_bytes)
|
||||
|
||||
def run(self):
|
||||
root = Resource()
|
||||
root.putChild("key", LocalKey(self))
|
||||
site = server.Site(root)
|
||||
reactor.listenSSL(
|
||||
self.bind_port,
|
||||
site,
|
||||
self.ssl_context_factory,
|
||||
interface=self.bind_host
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
|
||||
reactor.run()
|
||||
|
||||
|
||||
def main():
|
||||
key_server = KeyServer(**load_config(__doc__, sys.argv[1:]))
|
||||
key_server.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
14
synapse/crypto/resource/__init__.py
Normal file
14
synapse/crypto/resource/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
160
synapse/crypto/resource/key.py
Normal file
160
synapse/crypto/resource/key.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
from synapse.http.server import respond_with_json_bytes
|
||||
from synapse.crypto.keyclient import fetch_server_key
|
||||
from syutil.crypto.jsonsign import sign_json, verify_signed_json
|
||||
from syutil.base64util import encode_base64, decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
from OpenSSL import crypto
|
||||
from nacl.signing import VerifyKey
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalKey(Resource):
|
||||
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
|
||||
signature verification keys for this server::
|
||||
|
||||
GET /key HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
{
|
||||
"server_name": "this.server.example.com"
|
||||
"signature_verify_key": # base64 encoded NACL verification key.
|
||||
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
|
||||
"signatures": {
|
||||
"this.server.example.com": # NACL signature for this server.
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, key_server):
|
||||
self.key_server = key_server
|
||||
self.response_body = encode_canonical_json(
|
||||
self.response_json_object(key_server)
|
||||
)
|
||||
Resource.__init__(self)
|
||||
|
||||
@staticmethod
|
||||
def response_json_object(key_server):
|
||||
verify_key_bytes = key_server.signing_key.verify_key.encode()
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1,
|
||||
key_server.tls_certificate
|
||||
)
|
||||
json_object = {
|
||||
u"server_name": key_server.server_name,
|
||||
u"signature_verify_key": encode_base64(verify_key_bytes),
|
||||
u"tls_certificate": encode_base64(x509_certificate_bytes)
|
||||
}
|
||||
signed_json = sign_json(
|
||||
json_object,
|
||||
key_server.server_name,
|
||||
key_server.signing_key
|
||||
)
|
||||
return signed_json
|
||||
|
||||
def getChild(self, name, request):
|
||||
logger.info("getChild %s %s", name, request)
|
||||
if name == '':
|
||||
return self
|
||||
else:
|
||||
return RemoteKey(name, self.key_server)
|
||||
|
||||
def render_GET(self, request):
|
||||
return respond_with_json_bytes(request, 200, self.response_body)
|
||||
|
||||
|
||||
class RemoteKey(Resource):
|
||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||
verification keys for a another server. Checks that the reported X.509 TLS
|
||||
certificate matches the one used in the HTTPS connection. Checks that the
|
||||
NACL signature for the remote server is valid. Returns JSON signed by both
|
||||
the remote server and by this server.
|
||||
|
||||
GET /key/remote.server.example.com HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
{
|
||||
"server_name": "remote.server.example.com"
|
||||
"signature_verify_key": # base64 encoded NACL verification key.
|
||||
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
|
||||
"signatures": {
|
||||
"remote.server.example.com": # NACL signature for remote server.
|
||||
"this.server.example.com": # NACL signature for this server.
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, server_name, key_server):
|
||||
self.server_name = server_name
|
||||
self.key_server = key_server
|
||||
Resource.__init__(self)
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
try:
|
||||
server_keys, certificate = yield fetch_server_key(
|
||||
self.server_name,
|
||||
self.key_server.ssl_context_factory
|
||||
)
|
||||
|
||||
resp_server_name = server_keys[u"server_name"]
|
||||
verify_key_b64 = server_keys[u"signature_verify_key"]
|
||||
tls_certificate_b64 = server_keys[u"tls_certificate"]
|
||||
verify_key = VerifyKey(decode_base64(verify_key_b64))
|
||||
|
||||
if resp_server_name != self.server_name:
|
||||
raise ValueError("Wrong server name '%s' != '%s'" %
|
||||
(resp_server_name, self.server_name))
|
||||
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1,
|
||||
certificate
|
||||
)
|
||||
|
||||
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
||||
raise ValueError("TLS certificate doesn't match")
|
||||
|
||||
verify_signed_json(server_keys, self.server_name, verify_key)
|
||||
|
||||
signed_json = sign_json(
|
||||
server_keys,
|
||||
self.key_server.server_name,
|
||||
self.key_server.signing_key
|
||||
)
|
||||
|
||||
json_bytes = encode_canonical_json(signed_json)
|
||||
respond_with_json_bytes(request, 200, json_bytes)
|
||||
|
||||
except Exception as e:
|
||||
json_bytes = encode_canonical_json({
|
||||
u"error": {u"code": 502, u"message": e.message}
|
||||
})
|
||||
respond_with_json_bytes(request, 502, json_bytes)
|
29
synapse/federation/__init__.py
Normal file
29
synapse/federation/__init__.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This package includes all the federation specific logic.
|
||||
"""
|
||||
|
||||
from .replication import ReplicationLayer
|
||||
from .transport import TransportLayer
|
||||
|
||||
|
||||
def initialize_http_replication(homeserver):
|
||||
transport = TransportLayer(
|
||||
homeserver.hostname,
|
||||
server=homeserver.get_http_server(),
|
||||
client=homeserver.get_http_client()
|
||||
)
|
||||
|
||||
return ReplicationLayer(homeserver, transport)
|
148
synapse/federation/handler.py
Normal file
148
synapse/federation/handler.py
Normal file
|
@ -0,0 +1,148 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .pdu_codec import PduCodec
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationEventHandler(object):
|
||||
""" Responsible for:
|
||||
a) handling received Pdus before handing them on as Events to the rest
|
||||
of the home server (including auth and state conflict resoultion)
|
||||
b) converting events that were produced by local clients that may need
|
||||
to be sent to remote home servers.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.replication_layer = hs.get_replication_layer()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
# self.auth_handler = gs.get_auth_handler()
|
||||
self.event_handler = hs.get_handlers().federation_handler
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.lock_manager = hs.get_room_lock_manager()
|
||||
|
||||
self.replication_layer.set_handler(self)
|
||||
|
||||
self.pdu_codec = PduCodec(hs)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_event(self, event):
|
||||
""" Takes in an event from the client to server side, that has already
|
||||
been authed and handled by the state module, and sends it to any
|
||||
remote home servers that may be interested.
|
||||
|
||||
Args:
|
||||
event
|
||||
|
||||
Returns:
|
||||
Deferred: Resolved when it has successfully been queued for
|
||||
processing.
|
||||
"""
|
||||
yield self._fill_out_prev_events(event)
|
||||
|
||||
pdu = self.pdu_codec.pdu_from_event(event)
|
||||
|
||||
if not hasattr(pdu, "destinations") or not pdu.destinations:
|
||||
pdu.destinations = []
|
||||
|
||||
yield self.replication_layer.send_pdu(pdu)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, room_id, limit):
|
||||
# TODO: Work out which destinations to ask for pagination
|
||||
# self.replication_layer.paginate(dest, room_id, limit)
|
||||
pass
|
||||
|
||||
@log_function
|
||||
def get_state_for_room(self, destination, room_id):
|
||||
return self.replication_layer.get_state_for_context(
|
||||
destination, room_id
|
||||
)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, pdu):
|
||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||
do auth checks and put it throught the StateHandler.
|
||||
"""
|
||||
event = self.pdu_codec.event_from_pdu(pdu)
|
||||
|
||||
try:
|
||||
with (yield self.lock_manager.lock(pdu.context)):
|
||||
if event.is_state:
|
||||
is_new_state = yield self.state_handler.handle_new_state(
|
||||
pdu
|
||||
)
|
||||
if not is_new_state:
|
||||
return
|
||||
else:
|
||||
is_new_state = False
|
||||
|
||||
yield self.event_handler.on_receive(event, is_new_state)
|
||||
|
||||
except AuthError:
|
||||
# TODO: Implement something in federation that allows us to
|
||||
# respond to PDU.
|
||||
raise
|
||||
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_new_state(self, pdu, new_state_event):
|
||||
# TODO: Do any store stuff here. Notifiy C2S about this new
|
||||
# state.
|
||||
|
||||
yield self.store.update_current_state(
|
||||
pdu_id=pdu.pdu_id,
|
||||
origin=pdu.origin,
|
||||
context=pdu.context,
|
||||
pdu_type=pdu.pdu_type,
|
||||
state_key=pdu.state_key
|
||||
)
|
||||
|
||||
yield self.event_handler.on_receive(new_state_event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_prev_events(self, event):
|
||||
if hasattr(event, "prev_events"):
|
||||
return
|
||||
|
||||
results = yield self.store.get_latest_pdus_in_context(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
es = [
|
||||
"%s@%s" % (p_id, origin) for p_id, origin, _ in results
|
||||
]
|
||||
|
||||
event.prev_events = [e for e in es if e != event.event_id]
|
||||
|
||||
if results:
|
||||
event.depth = max([int(v) for _, _, v in results]) + 1
|
||||
else:
|
||||
event.depth = 0
|
101
synapse/federation/pdu_codec.py
Normal file
101
synapse/federation/pdu_codec.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .units import Pdu
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
def decode_event_id(event_id, server_name):
|
||||
parts = event_id.split("@")
|
||||
if len(parts) < 2:
|
||||
return (event_id, server_name)
|
||||
else:
|
||||
return (parts[0], "".join(parts[1:]))
|
||||
|
||||
|
||||
def encode_event_id(pdu_id, origin):
|
||||
return "%s@%s" % (pdu_id, origin)
|
||||
|
||||
|
||||
class PduCodec(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.server_name = hs.hostname
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def event_from_pdu(self, pdu):
|
||||
kwargs = {}
|
||||
|
||||
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
|
||||
kwargs["room_id"] = pdu.context
|
||||
kwargs["etype"] = pdu.pdu_type
|
||||
kwargs["prev_events"] = [
|
||||
encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
|
||||
]
|
||||
|
||||
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
|
||||
kwargs["prev_state"] = encode_event_id(
|
||||
pdu.prev_state_id, pdu.prev_state_origin
|
||||
)
|
||||
|
||||
kwargs.update({
|
||||
k: v
|
||||
for k, v in pdu.get_full_dict().items()
|
||||
if k not in [
|
||||
"pdu_id",
|
||||
"context",
|
||||
"pdu_type",
|
||||
"prev_pdus",
|
||||
"prev_state_id",
|
||||
"prev_state_origin",
|
||||
]
|
||||
})
|
||||
|
||||
return self.event_factory.create_event(**kwargs)
|
||||
|
||||
def pdu_from_event(self, event):
|
||||
d = event.get_full_dict()
|
||||
|
||||
d["pdu_id"], d["origin"] = decode_event_id(
|
||||
event.event_id, self.server_name
|
||||
)
|
||||
d["context"] = event.room_id
|
||||
d["pdu_type"] = event.type
|
||||
|
||||
if hasattr(event, "prev_events"):
|
||||
d["prev_pdus"] = [
|
||||
decode_event_id(e, self.server_name)
|
||||
for e in event.prev_events
|
||||
]
|
||||
|
||||
if hasattr(event, "prev_state"):
|
||||
d["prev_state_id"], d["prev_state_origin"] = (
|
||||
decode_event_id(event.prev_state, self.server_name)
|
||||
)
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
d["is_state"] = True
|
||||
|
||||
kwargs = copy.deepcopy(event.unrecognized_keys)
|
||||
kwargs.update({
|
||||
k: v for k, v in d.items()
|
||||
if k not in ["event_id", "room_id", "type", "prev_events"]
|
||||
})
|
||||
|
||||
if "ts" not in kwargs:
|
||||
kwargs["ts"] = int(self.clock.time_msec())
|
||||
|
||||
return Pdu(**kwargs)
|
240
synapse/federation/persistence.py
Normal file
240
synapse/federation/persistence.py
Normal file
|
@ -0,0 +1,240 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This module contains all the persistence actions done by the federation
|
||||
package.
|
||||
|
||||
These actions are mostly only used by the :py:mod:`.replication` module.
|
||||
"""
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .units import Pdu
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PduActions(object):
|
||||
""" Defines persistence actions that relate to handling PDUs.
|
||||
"""
|
||||
|
||||
def __init__(self, datastore):
|
||||
self.store = datastore
|
||||
|
||||
@log_function
|
||||
def persist_received(self, pdu):
|
||||
""" Persists the given `Pdu` that was received from a remote home
|
||||
server.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
return self._persist(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_outgoing(self, pdu):
|
||||
""" Persists the given `Pdu` that this home server created.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
ret = yield self._persist(pdu)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@log_function
|
||||
def mark_as_processed(self, pdu):
|
||||
""" Persist the fact that we have fully processed the given `Pdu`
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def populate_previous_pdus(self, pdu):
|
||||
""" Given an outgoing `Pdu` fill out its `prev_ids` key with the `Pdu`s
|
||||
that we have received.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
results = yield self.store.get_latest_pdus_in_context(pdu.context)
|
||||
|
||||
pdu.prev_pdus = [(p_id, origin) for p_id, origin, _ in results]
|
||||
|
||||
vs = [int(v) for _, _, v in results]
|
||||
if vs:
|
||||
pdu.depth = max(vs) + 1
|
||||
else:
|
||||
pdu.depth = 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def after_transaction(self, transaction_id, destination, origin):
|
||||
""" Returns all `Pdu`s that we sent to the given remote home server
|
||||
after a given transaction id.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of `Pdu`s
|
||||
"""
|
||||
results = yield self.store.get_pdus_after_transaction(
|
||||
transaction_id,
|
||||
destination
|
||||
)
|
||||
|
||||
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_all_pdus_from_context(self, context):
|
||||
results = yield self.store.get_all_pdus_from_context(context)
|
||||
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def paginate(self, context, pdu_list, limit):
|
||||
""" For a given list of PDU id and origins return the proceeding
|
||||
`limit` `Pdu`s in the given `context`.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of `Pdu`s.
|
||||
"""
|
||||
results = yield self.store.get_pagination(
|
||||
context, pdu_list, limit
|
||||
)
|
||||
|
||||
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
|
||||
|
||||
@log_function
|
||||
def is_new(self, pdu):
|
||||
""" When we receive a `Pdu` from a remote home server, we want to
|
||||
figure out whether it is `new`, i.e. it is not some historic PDU that
|
||||
we haven't seen simply because we haven't paginated back that far.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a `bool`
|
||||
"""
|
||||
return self.store.is_pdu_new(
|
||||
pdu_id=pdu.pdu_id,
|
||||
origin=pdu.origin,
|
||||
context=pdu.context,
|
||||
depth=pdu.depth
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _persist(self, pdu):
|
||||
kwargs = copy.copy(pdu.__dict__)
|
||||
unrec_keys = copy.copy(pdu.unrecognized_keys)
|
||||
del kwargs["content"]
|
||||
kwargs["content_json"] = json.dumps(pdu.content)
|
||||
kwargs["unrecognized_keys"] = json.dumps(unrec_keys)
|
||||
|
||||
logger.debug("Persisting: %s", repr(kwargs))
|
||||
|
||||
if pdu.is_state:
|
||||
ret = yield self.store.persist_state(**kwargs)
|
||||
else:
|
||||
ret = yield self.store.persist_pdu(**kwargs)
|
||||
|
||||
yield self.store.update_min_depth_for_context(
|
||||
pdu.context, pdu.depth
|
||||
)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
|
||||
class TransactionActions(object):
|
||||
""" Defines persistence actions that relate to handling Transactions.
|
||||
"""
|
||||
|
||||
def __init__(self, datastore):
|
||||
self.store = datastore
|
||||
|
||||
@log_function
|
||||
def have_responded(self, transaction):
|
||||
""" Have we already responded to a transaction with the same id and
|
||||
origin?
|
||||
|
||||
Returns:
|
||||
Deferred: Results in `None` if we have not previously responded to
|
||||
this transaction or a 2-tuple of `(int, dict)` representing the
|
||||
response code and response body.
|
||||
"""
|
||||
if not transaction.transaction_id:
|
||||
raise RuntimeError("Cannot persist a transaction with no "
|
||||
"transaction_id")
|
||||
|
||||
return self.store.get_received_txn_response(
|
||||
transaction.transaction_id, transaction.origin
|
||||
)
|
||||
|
||||
@log_function
|
||||
def set_response(self, transaction, code, response):
|
||||
""" Persist how we responded to a transaction.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
if not transaction.transaction_id:
|
||||
raise RuntimeError("Cannot persist a transaction with no "
|
||||
"transaction_id")
|
||||
|
||||
return self.store.set_received_txn_response(
|
||||
transaction.transaction_id,
|
||||
transaction.origin,
|
||||
code,
|
||||
json.dumps(response)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def prepare_to_send(self, transaction):
|
||||
""" Persists the `Transaction` we are about to send and works out the
|
||||
correct value for the `prev_ids` key.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
transaction.prev_ids = yield self.store.prep_send_transaction(
|
||||
transaction.transaction_id,
|
||||
transaction.destination,
|
||||
transaction.ts,
|
||||
[(p["pdu_id"], p["origin"]) for p in transaction.pdus]
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delivered(self, transaction, response_code, response_dict):
|
||||
""" Marks the given `Transaction` as having been successfully
|
||||
delivered to the remote homeserver, and what the response was.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
return self.store.delivered_txn(
|
||||
transaction.transaction_id,
|
||||
transaction.destination,
|
||||
response_code,
|
||||
json.dumps(response_dict)
|
||||
)
|
582
synapse/federation/replication.py
Normal file
582
synapse/federation/replication.py
Normal file
|
@ -0,0 +1,582 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This layer is responsible for replicating with remote home servers using
|
||||
a given transport.
|
||||
"""
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .units import Transaction, Pdu, Edu
|
||||
|
||||
from .persistence import PduActions, TransactionActions
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationLayer(object):
|
||||
"""This layer is responsible for replicating with remote home servers over
|
||||
the given transport. I.e., does the sending and receiving of PDUs to
|
||||
remote home servers.
|
||||
|
||||
The layer communicates with the rest of the server via a registered
|
||||
ReplicationHandler.
|
||||
|
||||
In more detail, the layer:
|
||||
* Receives incoming data and processes it into transactions and pdus.
|
||||
* Fetches any PDUs it thinks it might have missed.
|
||||
* Keeps the current state for contexts up to date by applying the
|
||||
suitable conflict resolution.
|
||||
* Sends outgoing pdus wrapped in transactions.
|
||||
* Fills out the references to previous pdus/transactions appropriately
|
||||
for outgoing data.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, transport_layer):
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.transport_layer = transport_layer
|
||||
self.transport_layer.register_received_handler(self)
|
||||
self.transport_layer.register_request_handler(self)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.pdu_actions = PduActions(self.store)
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
|
||||
self._transaction_queue = _TransactionQueue(
|
||||
hs, self.transaction_actions, transport_layer
|
||||
)
|
||||
|
||||
self.handler = None
|
||||
self.edu_handlers = {}
|
||||
|
||||
self._order = 0
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
def set_handler(self, handler):
|
||||
"""Sets the handler that the replication layer will use to communicate
|
||||
receipt of new PDUs from other home servers. The required methods are
|
||||
documented on :py:class:`.ReplicationHandler`.
|
||||
"""
|
||||
self.handler = handler
|
||||
|
||||
def register_edu_handler(self, edu_type, handler):
|
||||
if edu_type in self.edu_handlers:
|
||||
raise KeyError("Already have an EDU handler for %s" % (edu_type))
|
||||
|
||||
self.edu_handlers[edu_type] = handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_pdu(self, pdu):
|
||||
"""Informs the replication layer about a new PDU generated within the
|
||||
home server that should be transmitted to others.
|
||||
|
||||
This will fill out various attributes on the PDU object, e.g. the
|
||||
`prev_pdus` key.
|
||||
|
||||
*Note:* The home server should always call `send_pdu` even if it knows
|
||||
that it does not need to be replicated to other home servers. This is
|
||||
in case e.g. someone else joins via a remote home server and then
|
||||
paginates.
|
||||
|
||||
TODO: Figure out when we should actually resolve the deferred.
|
||||
|
||||
Args:
|
||||
pdu (Pdu): The new Pdu.
|
||||
|
||||
Returns:
|
||||
Deferred: Completes when we have successfully processed the PDU
|
||||
and replicated it to any interested remote home servers.
|
||||
"""
|
||||
order = self._order
|
||||
self._order += 1
|
||||
|
||||
logger.debug("[%s] Persisting PDU", pdu.pdu_id)
|
||||
|
||||
#yield self.pdu_actions.populate_previous_pdus(pdu)
|
||||
|
||||
# Save *before* trying to send
|
||||
yield self.pdu_actions.persist_outgoing(pdu)
|
||||
|
||||
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
|
||||
|
||||
# TODO, add errback, etc.
|
||||
self._transaction_queue.enqueue_pdu(pdu, order)
|
||||
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
|
||||
|
||||
@log_function
|
||||
def send_edu(self, destination, edu_type, content):
|
||||
edu = Edu(
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
edu_type=edu_type,
|
||||
content=content,
|
||||
)
|
||||
|
||||
# TODO, add errback, etc.
|
||||
self._transaction_queue.enqueue_edu(edu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def paginate(self, dest, context, limit):
|
||||
"""Requests some more historic PDUs for the given context from the
|
||||
given destination server.
|
||||
|
||||
Args:
|
||||
dest (str): The remote home server to ask.
|
||||
context (str): The context to paginate back on.
|
||||
limit (int): The maximum number of PDUs to return.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in the received PDUs.
|
||||
"""
|
||||
extremities = yield self.store.get_oldest_pdus_in_context(context)
|
||||
|
||||
logger.debug("paginate extrem=%s", extremities)
|
||||
|
||||
# If there are no extremeties then we've (probably) reached the start.
|
||||
if not extremities:
|
||||
return
|
||||
|
||||
transaction_data = yield self.transport_layer.paginate(
|
||||
dest, context, extremities, limit)
|
||||
|
||||
logger.debug("paginate transaction_data=%s", repr(transaction_data))
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
|
||||
for pdu in pdus:
|
||||
yield self._handle_new_pdu(pdu)
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
|
||||
"""Requests the PDU with given origin and ID from the remote home
|
||||
server.
|
||||
|
||||
This will persist the PDU locally upon receipt.
|
||||
|
||||
Args:
|
||||
destination (str): Which home server to query
|
||||
pdu_origin (str): The home server that originally sent the pdu.
|
||||
pdu_id (str)
|
||||
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
||||
it's from an arbitary point in the context as opposed to part
|
||||
of the current block of PDUs. Defaults to `False`
|
||||
|
||||
Returns:
|
||||
Deferred: Results in the requested PDU.
|
||||
"""
|
||||
|
||||
transaction_data = yield self.transport_layer.get_pdu(
|
||||
destination, pdu_origin, pdu_id)
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
pdu_list = [Pdu(outlier=outlier, **p) for p in transaction.pdus]
|
||||
|
||||
pdu = None
|
||||
if pdu_list:
|
||||
pdu = pdu_list[0]
|
||||
yield self._handle_new_pdu(pdu)
|
||||
|
||||
defer.returnValue(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_state_for_context(self, destination, context):
|
||||
"""Requests all of the `current` state PDUs for a given context from
|
||||
a remote home server.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
context (str): The context we're interested in.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
transaction_data = yield self.transport_layer.get_context_state(
|
||||
destination, context)
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
|
||||
for pdu in pdus:
|
||||
yield self._handle_new_pdu(pdu)
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_context_pdus_request(self, context):
|
||||
pdus = yield self.pdu_actions.get_all_pdus_from_context(
|
||||
context
|
||||
)
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_paginate_request(self, context, versions, limit):
|
||||
|
||||
pdus = yield self.pdu_actions.paginate(context, versions, limit)
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_incoming_transaction(self, transaction_data):
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
||||
|
||||
response = yield self.transaction_actions.have_responded(transaction)
|
||||
|
||||
if response:
|
||||
logger.debug("[%s] We've already responed to this request",
|
||||
transaction.transaction_id)
|
||||
defer.returnValue(response)
|
||||
return
|
||||
|
||||
logger.debug("[%s] Transacition is new", transaction.transaction_id)
|
||||
|
||||
pdu_list = [Pdu(**p) for p in transaction.pdus]
|
||||
|
||||
dl = []
|
||||
for pdu in pdu_list:
|
||||
dl.append(self._handle_new_pdu(pdu))
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
self.received_edu(edu.origin, edu.edu_type, edu.content)
|
||||
|
||||
results = yield defer.DeferredList(dl)
|
||||
|
||||
ret = []
|
||||
for r in results:
|
||||
if r[0]:
|
||||
ret.append({})
|
||||
else:
|
||||
logger.exception(r[1])
|
||||
ret.append({"error": str(r[1])})
|
||||
|
||||
logger.debug("Returning: %s", str(ret))
|
||||
|
||||
yield self.transaction_actions.set_response(
|
||||
transaction,
|
||||
200, response
|
||||
)
|
||||
defer.returnValue((200, response))
|
||||
|
||||
def received_edu(self, origin, edu_type, content):
|
||||
if edu_type in self.edu_handlers:
|
||||
self.edu_handlers[edu_type](origin, content)
|
||||
else:
|
||||
logger.warn("Received EDU of type %s with no handler", edu_type)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_context_state_request(self, context):
|
||||
results = yield self.store.get_current_state_for_context(
|
||||
context
|
||||
)
|
||||
|
||||
logger.debug("Context returning %d results", len(results))
|
||||
|
||||
pdus = [Pdu.from_pdu_tuple(p) for p in results]
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_pdu_request(self, pdu_origin, pdu_id):
|
||||
pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
|
||||
|
||||
if pdu:
|
||||
defer.returnValue(
|
||||
(200, self._transaction_from_pdus([pdu]).get_dict())
|
||||
)
|
||||
else:
|
||||
defer.returnValue((404, ""))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_pull_request(self, origin, versions):
|
||||
transaction_id = max([int(v) for v in versions])
|
||||
|
||||
response = yield self.pdu_actions.after_transaction(
|
||||
transaction_id,
|
||||
origin,
|
||||
self.server_name
|
||||
)
|
||||
|
||||
if not response:
|
||||
response = []
|
||||
|
||||
defer.returnValue(
|
||||
(200, self._transaction_from_pdus(response).get_dict())
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _get_persisted_pdu(self, pdu_id, pdu_origin):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a `Pdu`.
|
||||
"""
|
||||
pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
|
||||
|
||||
defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
|
||||
|
||||
def _transaction_from_pdus(self, pdu_list):
|
||||
"""Returns a new Transaction containing the given PDUs suitable for
|
||||
transmission.
|
||||
"""
|
||||
return Transaction(
|
||||
pdus=[p.get_dict() for p in pdu_list],
|
||||
origin=self.server_name,
|
||||
ts=int(self._clock.time_msec()),
|
||||
destination=None,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _handle_new_pdu(self, pdu):
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
|
||||
|
||||
if existing and (not existing.outlier or pdu.outlier):
|
||||
logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
# Get missing pdus if necessary.
|
||||
is_new = yield self.pdu_actions.is_new(pdu)
|
||||
if is_new and not pdu.outlier:
|
||||
# We only paginate backwards to the min depth.
|
||||
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
|
||||
|
||||
if min_depth and pdu.depth > min_depth:
|
||||
for pdu_id, origin in pdu.prev_pdus:
|
||||
exists = yield self._get_persisted_pdu(pdu_id, origin)
|
||||
|
||||
if not exists:
|
||||
logger.debug("Requesting pdu %s %s", pdu_id, origin)
|
||||
|
||||
try:
|
||||
yield self.get_pdu(
|
||||
pdu.origin,
|
||||
pdu_id=pdu_id,
|
||||
pdu_origin=origin
|
||||
)
|
||||
logger.debug("Processed pdu %s %s", pdu_id, origin)
|
||||
except:
|
||||
# TODO(erikj): Do some more intelligent retries.
|
||||
logger.exception("Failed to get PDU")
|
||||
|
||||
# Persist the Pdu, but don't mark it as processed yet.
|
||||
yield self.pdu_actions.persist_received(pdu)
|
||||
|
||||
ret = yield self.handler.on_receive_pdu(pdu)
|
||||
|
||||
yield self.pdu_actions.mark_as_processed(pdu)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
|
||||
class ReplicationHandler(object):
|
||||
"""This defines the methods that the :py:class:`.ReplicationLayer` will
|
||||
use to communicate with the rest of the home server.
|
||||
"""
|
||||
def on_receive_pdu(self, pdu):
|
||||
raise NotImplementedError("on_receive_pdu")
|
||||
|
||||
|
||||
class _TransactionQueue(object):
|
||||
"""This class makes sure we only have one transaction in flight at
|
||||
a time for a given destination.
|
||||
|
||||
It batches pending PDUs into single transactions.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, transaction_actions, transport_layer):
|
||||
|
||||
self.server_name = hs.hostname
|
||||
self.transaction_actions = transaction_actions
|
||||
self.transport_layer = transport_layer
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
# Is a mapping from destinations -> deferreds. Used to keep track
|
||||
# of which destinations have transactions in flight and when they are
|
||||
# done
|
||||
self.pending_transactions = {}
|
||||
|
||||
# Is a mapping from destination -> list of
|
||||
# tuple(pending pdus, deferred, order)
|
||||
self.pending_pdus_by_dest = {}
|
||||
# destination -> list of tuple(edu, deferred)
|
||||
self.pending_edus_by_dest = {}
|
||||
|
||||
# HACK to get unique tx id
|
||||
self._next_txn_id = int(self._clock.time_msec())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def enqueue_pdu(self, pdu, order):
|
||||
# We loop through all destinations to see whether we already have
|
||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||
# table and we'll get back to it later.
|
||||
|
||||
destinations = [
|
||||
d for d in pdu.destinations
|
||||
if d != self.server_name
|
||||
]
|
||||
|
||||
logger.debug("Sending to: %s", str(destinations))
|
||||
|
||||
if not destinations:
|
||||
return
|
||||
|
||||
deferreds = []
|
||||
|
||||
for destination in destinations:
|
||||
deferred = defer.Deferred()
|
||||
self.pending_pdus_by_dest.setdefault(destination, []).append(
|
||||
(pdu, deferred, order)
|
||||
)
|
||||
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
deferreds.append(deferred)
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
# NO inlineCallbacks
|
||||
def enqueue_edu(self, edu):
|
||||
destination = edu.destination
|
||||
|
||||
deferred = defer.Deferred()
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(
|
||||
(edu, deferred)
|
||||
)
|
||||
|
||||
def eb(failure):
|
||||
deferred.errback(failure)
|
||||
self._attempt_new_transaction(destination).addErrback(eb)
|
||||
|
||||
return deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _attempt_new_transaction(self, destination):
|
||||
if destination in self.pending_transactions:
|
||||
return
|
||||
|
||||
# list of (pending_pdu, deferred, order)
|
||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||
|
||||
if not pending_pdus and not pending_edus:
|
||||
return
|
||||
|
||||
logger.debug("TX [%s] Attempting new transaction", destination)
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[2])
|
||||
|
||||
pdus = [x[0] for x in pending_pdus]
|
||||
edus = [x[0] for x in pending_edus]
|
||||
deferreds = [x[1] for x in pending_pdus + pending_edus]
|
||||
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
ts=self._clock.time_msec(),
|
||||
transaction_id=self._next_txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
pdus=pdus,
|
||||
edus=edus,
|
||||
)
|
||||
|
||||
self._next_txn_id += 1
|
||||
|
||||
yield self.transaction_actions.prepare_to_send(transaction)
|
||||
|
||||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.debug("TX [%s] Sending transaction...", destination)
|
||||
|
||||
# Actually send the transaction
|
||||
code, response = yield self.transport_layer.send_transaction(
|
||||
transaction
|
||||
)
|
||||
|
||||
logger.debug("TX [%s] Sent transaction", destination)
|
||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||
|
||||
yield self.transaction_actions.delivered(
|
||||
transaction, code, response
|
||||
)
|
||||
|
||||
logger.debug("TX [%s] Marked as delivered", destination)
|
||||
logger.debug("TX [%s] Yielding to callbacks...", destination)
|
||||
|
||||
for deferred in deferreds:
|
||||
if code == 200:
|
||||
deferred.callback(None)
|
||||
else:
|
||||
deferred.errback(RuntimeError("Got status %d" % code))
|
||||
|
||||
# Ensures we don't continue until all callbacks on that
|
||||
# deferred have fired
|
||||
yield deferred
|
||||
|
||||
logger.debug("TX [%s] Yielded to callbacks", destination)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("TX Problem in _attempt_transaction")
|
||||
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
logger.exception(e)
|
||||
|
||||
for deferred in deferreds:
|
||||
deferred.errback(e)
|
||||
yield deferred
|
||||
|
||||
finally:
|
||||
# We want to be *very* sure we delete this after we stop processing
|
||||
self.pending_transactions.pop(destination, None)
|
||||
|
||||
# Check to see if there is anything else to send.
|
||||
self._attempt_new_transaction(destination)
|
454
synapse/federation/transport.py
Normal file
454
synapse/federation/transport.py
Normal file
|
@ -0,0 +1,454 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The transport layer is responsible for both sending transactions to remote
|
||||
home servers and receiving a variety of requests from other home servers.
|
||||
|
||||
Typically, this is done over HTTP (and all home servers are required to
|
||||
support HTTP), however individual pairings of servers may decide to communicate
|
||||
over a different (albeit still reliable) protocol.
|
||||
"""
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
import logging
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransportLayer(object):
|
||||
"""This is a basic implementation of the transport layer that translates
|
||||
transactions and other requests to/from HTTP.
|
||||
|
||||
Attributes:
|
||||
server_name (str): Local home server host
|
||||
|
||||
server (synapse.http.server.HttpServer): the http server to
|
||||
register listeners on
|
||||
|
||||
client (synapse.http.client.HttpClient): the http client used to
|
||||
send requests
|
||||
|
||||
request_handler (TransportRequestHandler): The handler to fire when we
|
||||
receive requests for data.
|
||||
|
||||
received_handler (TransportReceivedHandler): The handler to fire when
|
||||
we receive data.
|
||||
"""
|
||||
|
||||
def __init__(self, server_name, server, client):
|
||||
"""
|
||||
Args:
|
||||
server_name (str): Local home server host
|
||||
server (synapse.protocol.http.HttpServer): the http server to
|
||||
register listeners on
|
||||
client (synapse.protocol.http.HttpClient): the http client used to
|
||||
send requests
|
||||
"""
|
||||
self.server_name = server_name
|
||||
self.server = server
|
||||
self.client = client
|
||||
self.request_handler = None
|
||||
self.received_handler = None
|
||||
|
||||
@log_function
|
||||
def get_context_state(self, destination, context):
|
||||
""" Requests all state for a given context (i.e. room) from the
|
||||
given server.
|
||||
|
||||
Args:
|
||||
destination (str): The host name of the remote home server we want
|
||||
to get the state from.
|
||||
context (str): The name of the context we want the state of
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_context_state dest=%s, context=%s",
|
||||
destination, context)
|
||||
|
||||
path = "/state/%s/" % context
|
||||
|
||||
return self._do_request_for_transaction(destination, path)
|
||||
|
||||
@log_function
|
||||
def get_pdu(self, destination, pdu_origin, pdu_id):
|
||||
""" Requests the pdu with give id and origin from the given server.
|
||||
|
||||
Args:
|
||||
destination (str): The host name of the remote home server we want
|
||||
to get the state from.
|
||||
pdu_origin (str): The home server which created the PDU.
|
||||
pdu_id (str): The id of the PDU being requested.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
|
||||
destination, pdu_origin, pdu_id)
|
||||
|
||||
path = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
|
||||
|
||||
return self._do_request_for_transaction(destination, path)
|
||||
|
||||
@log_function
|
||||
def paginate(self, dest, context, pdu_tuples, limit):
|
||||
""" Requests `limit` previous PDUs in a given context before list of
|
||||
PDUs.
|
||||
|
||||
Args:
|
||||
dest (str)
|
||||
context (str)
|
||||
pdu_tuples (list)
|
||||
limt (int)
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug(
|
||||
"paginate dest=%s, context=%s, pdu_tuples=%s, limit=%s",
|
||||
dest, context, repr(pdu_tuples), str(limit)
|
||||
)
|
||||
|
||||
if not pdu_tuples:
|
||||
return
|
||||
|
||||
path = "/paginate/%s/" % context
|
||||
|
||||
args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
|
||||
args["limit"] = limit
|
||||
|
||||
return self._do_request_for_transaction(
|
||||
dest,
|
||||
path,
|
||||
args=args,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_transaction(self, transaction):
|
||||
""" Sends the given Transaction to it's destination
|
||||
|
||||
Args:
|
||||
transaction (Transaction)
|
||||
|
||||
Returns:
|
||||
Deferred: Results of the deferred is a tuple in the form of
|
||||
(response_code, response_body) where the response_body is a
|
||||
python dict decoded from json
|
||||
"""
|
||||
logger.debug(
|
||||
"send_data dest=%s, txid=%s",
|
||||
transaction.destination, transaction.transaction_id
|
||||
)
|
||||
|
||||
if transaction.destination == self.server_name:
|
||||
raise RuntimeError("Transport layer cannot send to itself!")
|
||||
|
||||
data = transaction.get_dict()
|
||||
|
||||
code, response = yield self.client.put_json(
|
||||
transaction.destination,
|
||||
path="/send/%s/" % transaction.transaction_id,
|
||||
data=data
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"send_data dest=%s, txid=%s, got response: %d",
|
||||
transaction.destination, transaction.transaction_id, code
|
||||
)
|
||||
|
||||
defer.returnValue((code, response))
|
||||
|
||||
@log_function
|
||||
def register_received_handler(self, handler):
|
||||
""" Register a handler that will be fired when we receive data.
|
||||
|
||||
Args:
|
||||
handler (TransportReceivedHandler)
|
||||
"""
|
||||
self.received_handler = handler
|
||||
|
||||
# This is when someone is trying to send us a bunch of data.
|
||||
self.server.register_path(
|
||||
"PUT",
|
||||
re.compile("^/send/([^/]*)/$"),
|
||||
self._on_send_request
|
||||
)
|
||||
|
||||
@log_function
|
||||
def register_request_handler(self, handler):
|
||||
""" Register a handler that will be fired when we get asked for data.
|
||||
|
||||
Args:
|
||||
handler (TransportRequestHandler)
|
||||
"""
|
||||
self.request_handler = handler
|
||||
|
||||
# TODO(markjh): Namespace the federation URI paths
|
||||
|
||||
# This is for when someone asks us for everything since version X
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^/pull/$"),
|
||||
lambda request: handler.on_pull_request(
|
||||
request.args["origin"][0],
|
||||
request.args["v"]
|
||||
)
|
||||
)
|
||||
|
||||
# This is when someone asks for a data item for a given server
|
||||
# data_id pair.
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^/pdu/([^/]*)/([^/]*)/$"),
|
||||
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
|
||||
pdu_origin, pdu_id
|
||||
)
|
||||
)
|
||||
|
||||
# This is when someone asks for all data for a given context.
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^/state/([^/]*)/$"),
|
||||
lambda request, context: handler.on_context_state_request(
|
||||
context
|
||||
)
|
||||
)
|
||||
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^/paginate/([^/]*)/$"),
|
||||
lambda request, context: self._on_paginate_request(
|
||||
context, request.args["v"],
|
||||
request.args["limit"]
|
||||
)
|
||||
)
|
||||
|
||||
self.server.register_path(
|
||||
"GET",
|
||||
re.compile("^/context/([^/]*)/$"),
|
||||
lambda request, context: handler.on_context_pdus_request(context)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _on_send_request(self, request, transaction_id):
|
||||
""" Called on PUT /send/<transaction_id>/
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The HTTP request.
|
||||
transaction_id (str): The transaction_id associated with this
|
||||
request. This is *not* None.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a tuple of `(code, response)`, where
|
||||
`response` is a python dict to be converted into JSON that is
|
||||
used as the response body.
|
||||
"""
|
||||
# Parse the request
|
||||
try:
|
||||
data = request.content.read()
|
||||
|
||||
l = data[:20].encode("string_escape")
|
||||
logger.debug("Got data: \"%s\"", l)
|
||||
|
||||
transaction_data = json.loads(data)
|
||||
|
||||
logger.debug(
|
||||
"Decoded %s: %s",
|
||||
transaction_id, str(transaction_data)
|
||||
)
|
||||
|
||||
# We should ideally be getting this from the security layer.
|
||||
# origin = body["origin"]
|
||||
|
||||
# Add some extra data to the transaction dict that isn't included
|
||||
# in the request body.
|
||||
transaction_data.update(
|
||||
transaction_id=transaction_id,
|
||||
destination=self.server_name
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
defer.returnValue((400, {"error": "Invalid transaction"}))
|
||||
return
|
||||
|
||||
code, response = yield self.received_handler.on_incoming_transaction(
|
||||
transaction_data
|
||||
)
|
||||
|
||||
defer.returnValue((code, response))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _do_request_for_transaction(self, destination, path, args={}):
|
||||
"""
|
||||
Args:
|
||||
destination (str)
|
||||
path (str)
|
||||
args (dict): This is parsed directly to the HttpClient.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict.
|
||||
"""
|
||||
|
||||
data = yield self.client.get_json(
|
||||
destination,
|
||||
path=path,
|
||||
args=args,
|
||||
)
|
||||
|
||||
# Add certain keys to the JSON, ready for decoding as a Transaction
|
||||
data.update(
|
||||
origin=destination,
|
||||
destination=self.server_name,
|
||||
transaction_id=None
|
||||
)
|
||||
|
||||
defer.returnValue(data)
|
||||
|
||||
@log_function
|
||||
def _on_paginate_request(self, context, v_list, limits):
|
||||
if not limits:
|
||||
return defer.succeed(
|
||||
(400, {"error": "Did not include limit param"})
|
||||
)
|
||||
|
||||
limit = int(limits[-1])
|
||||
|
||||
versions = [v.split(",", 1) for v in v_list]
|
||||
|
||||
return self.request_handler.on_paginate_request(
|
||||
context, versions, limit)
|
||||
|
||||
|
||||
class TransportReceivedHandler(object):
|
||||
""" Callbacks used when we receive a transaction
|
||||
"""
|
||||
def on_incoming_transaction(self, transaction):
|
||||
""" Called on PUT /send/<transaction_id>, or on response to a request
|
||||
that we sent (e.g. a pagination request)
|
||||
|
||||
Args:
|
||||
transaction (synapse.transaction.Transaction): The transaction that
|
||||
was sent to us.
|
||||
|
||||
Returns:
|
||||
twisted.internet.defer.Deferred: A deferred that get's fired when
|
||||
the transaction has finished being processed.
|
||||
|
||||
The result should be a tuple in the form of
|
||||
`(response_code, respond_body)`, where `response_body` is a python
|
||||
dict that will get serialized to JSON.
|
||||
|
||||
On errors, the dict should have an `error` key with a brief message
|
||||
of what went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TransportRequestHandler(object):
|
||||
""" Handlers used when someone want's data from us
|
||||
"""
|
||||
def on_pull_request(self, versions):
|
||||
""" Called on GET /pull/?v=...
|
||||
|
||||
This is hit when a remote home server wants to get all data
|
||||
after a given transaction. Mainly used when a home server comes back
|
||||
online and wants to get everything it has missed.
|
||||
|
||||
Args:
|
||||
versions (list): A list of transaction_ids that should be used to
|
||||
determine what PDUs the remote side have not yet seen.
|
||||
|
||||
Returns:
|
||||
Deferred: Resultsin a tuple in the form of
|
||||
`(response_code, respond_body)`, where `response_body` is a python
|
||||
dict that will get serialized to JSON.
|
||||
|
||||
On errors, the dict should have an `error` key with a brief message
|
||||
of what went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_pdu_request(self, pdu_origin, pdu_id):
|
||||
""" Called on GET /pdu/<pdu_origin>/<pdu_id>/
|
||||
|
||||
Someone wants a particular PDU. This PDU may or may not have originated
|
||||
from us.
|
||||
|
||||
Args:
|
||||
pdu_origin (str)
|
||||
pdu_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred: Resultsin a tuple in the form of
|
||||
`(response_code, respond_body)`, where `response_body` is a python
|
||||
dict that will get serialized to JSON.
|
||||
|
||||
On errors, the dict should have an `error` key with a brief message
|
||||
of what went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_context_state_request(self, context):
|
||||
""" Called on GET /state/<context>/
|
||||
|
||||
Get's hit when someone wants all the *current* state for a given
|
||||
contexts.
|
||||
|
||||
Args:
|
||||
context (str): The name of the context that we're interested in.
|
||||
|
||||
Returns:
|
||||
twisted.internet.defer.Deferred: A deferred that get's fired when
|
||||
the transaction has finished being processed.
|
||||
|
||||
The result should be a tuple in the form of
|
||||
`(response_code, respond_body)`, where `response_body` is a python
|
||||
dict that will get serialized to JSON.
|
||||
|
||||
On errors, the dict should have an `error` key with a brief message
|
||||
of what went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_paginate_request(self, context, versions, limit):
|
||||
""" Called on GET /paginate/<context>/?v=...&limit=...
|
||||
|
||||
Get's hit when we want to paginate backwards on a given context from
|
||||
the given point.
|
||||
|
||||
Args:
|
||||
context (str): The context to paginate on
|
||||
versions (list): A list of 2-tuple's representing where to paginate
|
||||
from, in the form `(pdu_id, origin)`
|
||||
limit (int): How many pdus to return.
|
||||
|
||||
Returns:
|
||||
Deferred: Resultsin a tuple in the form of
|
||||
`(response_code, respond_body)`, where `response_body` is a python
|
||||
dict that will get serialized to JSON.
|
||||
|
||||
On errors, the dict should have an `error` key with a brief message
|
||||
of what went wrong.
|
||||
"""
|
||||
pass
|
236
synapse/federation/units.py
Normal file
236
synapse/federation/units.py
Normal file
|
@ -0,0 +1,236 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Defines the JSON structure of the protocol units used by the server to
|
||||
server protocol.
|
||||
"""
|
||||
|
||||
from synapse.util.jsonobject import JsonEncodedObject
|
||||
|
||||
import logging
|
||||
import json
|
||||
import copy
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Pdu(JsonEncodedObject):
|
||||
""" A Pdu represents a piece of data sent from a server and is associated
|
||||
with a context.
|
||||
|
||||
A Pdu can be classified as "state". For a given context, we can efficiently
|
||||
retrieve all state pdu's that haven't been clobbered. Clobbering is done
|
||||
via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
|
||||
is a state pdu if `is_state` is True.
|
||||
|
||||
Example pdu::
|
||||
|
||||
{
|
||||
"pdu_id": "78c",
|
||||
"ts": 1404835423000,
|
||||
"origin": "bar",
|
||||
"prev_ids": [
|
||||
["23b", "foo"],
|
||||
["56a", "bar"],
|
||||
],
|
||||
"content": { ... },
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
valid_keys = [
|
||||
"pdu_id",
|
||||
"context",
|
||||
"origin",
|
||||
"ts",
|
||||
"pdu_type",
|
||||
"destinations",
|
||||
"transaction_id",
|
||||
"prev_pdus",
|
||||
"depth",
|
||||
"content",
|
||||
"outlier",
|
||||
"is_state", # Below this are keys valid only for State Pdus.
|
||||
"state_key",
|
||||
"power_level",
|
||||
"prev_state_id",
|
||||
"prev_state_origin",
|
||||
]
|
||||
|
||||
internal_keys = [
|
||||
"destinations",
|
||||
"transaction_id",
|
||||
"outlier",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
"pdu_id",
|
||||
"context",
|
||||
"origin",
|
||||
"ts",
|
||||
"pdu_type",
|
||||
"content",
|
||||
]
|
||||
|
||||
# TODO: We need to make this properly load content rather than
|
||||
# just leaving it as a dict. (OR DO WE?!)
|
||||
|
||||
def __init__(self, destinations=[], is_state=False, prev_pdus=[],
|
||||
outlier=False, **kwargs):
|
||||
if is_state:
|
||||
for required_key in ["state_key"]:
|
||||
if required_key not in kwargs:
|
||||
raise RuntimeError("Key %s is required" % required_key)
|
||||
|
||||
super(Pdu, self).__init__(
|
||||
destinations=destinations,
|
||||
is_state=is_state,
|
||||
prev_pdus=prev_pdus,
|
||||
outlier=outlier,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pdu_tuple(cls, pdu_tuple):
|
||||
""" Converts a PduTuple to a Pdu
|
||||
|
||||
Args:
|
||||
pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
|
||||
convert
|
||||
|
||||
Returns:
|
||||
Pdu
|
||||
"""
|
||||
if pdu_tuple:
|
||||
d = copy.copy(pdu_tuple.pdu_entry._asdict())
|
||||
|
||||
d["content"] = json.loads(d["content_json"])
|
||||
del d["content_json"]
|
||||
|
||||
args = {f: d[f] for f in cls.valid_keys if f in d}
|
||||
if "unrecognized_keys" in d and d["unrecognized_keys"]:
|
||||
args.update(json.loads(d["unrecognized_keys"]))
|
||||
|
||||
return Pdu(
|
||||
prev_pdus=pdu_tuple.prev_pdu_list,
|
||||
**args
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s, %s>" % (self.__class__.__name__, repr(self.__dict__))
|
||||
|
||||
|
||||
class Edu(JsonEncodedObject):
|
||||
""" An Edu represents a piece of data sent from one homeserver to another.
|
||||
|
||||
In comparison to Pdus, Edus are not persisted for a long time on disk, are
|
||||
not meaningful beyond a given pair of homeservers, and don't have an
|
||||
internal ID or previous references graph.
|
||||
"""
|
||||
|
||||
valid_keys = [
|
||||
"origin",
|
||||
"destination",
|
||||
"edu_type",
|
||||
"content",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
"origin",
|
||||
"destination",
|
||||
"edu_type",
|
||||
]
|
||||
|
||||
|
||||
class Transaction(JsonEncodedObject):
|
||||
""" A transaction is a list of Pdus and Edus to be sent to a remote home
|
||||
server with some extra metadata.
|
||||
|
||||
Example transaction::
|
||||
|
||||
{
|
||||
"origin": "foo",
|
||||
"prev_ids": ["abc", "def"],
|
||||
"pdus": [
|
||||
...
|
||||
],
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
valid_keys = [
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"destination",
|
||||
"ts",
|
||||
"previous_ids",
|
||||
"pdus",
|
||||
"edus",
|
||||
]
|
||||
|
||||
internal_keys = [
|
||||
"transaction_id",
|
||||
"destination",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"destination",
|
||||
"ts",
|
||||
"pdus",
|
||||
]
|
||||
|
||||
def __init__(self, transaction_id=None, pdus=[], **kwargs):
|
||||
""" If we include a list of pdus then we decode then as PDU's
|
||||
automatically.
|
||||
"""
|
||||
|
||||
# If there's no EDUs then remove the arg
|
||||
if "edus" in kwargs and not kwargs["edus"]:
|
||||
del kwargs["edus"]
|
||||
|
||||
super(Transaction, self).__init__(
|
||||
transaction_id=transaction_id,
|
||||
pdus=pdus,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_new(pdus, **kwargs):
|
||||
""" Used to create a new transaction. Will auto fill out
|
||||
transaction_id and ts keys.
|
||||
"""
|
||||
if "ts" not in kwargs:
|
||||
raise KeyError("Require 'ts' to construct a Transaction")
|
||||
if "transaction_id" not in kwargs:
|
||||
raise KeyError(
|
||||
"Require 'transaction_id' to construct a Transaction"
|
||||
)
|
||||
|
||||
for p in pdus:
|
||||
p.transaction_id = kwargs["transaction_id"]
|
||||
|
||||
kwargs["pdus"] = [p.get_dict() for p in pdus]
|
||||
|
||||
return Transaction(**kwargs)
|
||||
|
||||
|
||||
|
46
synapse/handlers/__init__.py
Normal file
46
synapse/handlers/__init__.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .register import RegistrationHandler
|
||||
from .room import (
|
||||
MessageHandler, RoomCreationHandler, RoomMemberHandler, RoomListHandler
|
||||
)
|
||||
from .events import EventStreamHandler
|
||||
from .federation import FederationHandler
|
||||
from .login import LoginHandler
|
||||
from .profile import ProfileHandler
|
||||
from .presence import PresenceHandler
|
||||
from .directory import DirectoryHandler
|
||||
|
||||
|
||||
class Handlers(object):
|
||||
|
||||
""" A collection of all the event handlers.
|
||||
|
||||
There's no need to lazily create these; we'll just make them all eagerly
|
||||
at construction time.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.registration_handler = RegistrationHandler(hs)
|
||||
self.message_handler = MessageHandler(hs)
|
||||
self.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_member_handler = RoomMemberHandler(hs)
|
||||
self.event_stream_handler = EventStreamHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.presence_handler = PresenceHandler(hs)
|
||||
self.room_list_handler = RoomListHandler(hs)
|
||||
self.login_handler = LoginHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
26
synapse/handlers/_base.py
Normal file
26
synapse/handlers/_base.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class BaseHandler(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.room_lock = hs.get_room_lock_manager()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.hs = hs
|
100
synapse/handlers/directory.py
Normal file
100
synapse/handlers/directory.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
import logging
|
||||
import json
|
||||
import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(erikj): This needs to be factored out somewere
|
||||
PREFIX = "/matrix/client/api/v1"
|
||||
|
||||
|
||||
class DirectoryHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DirectoryHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.http_client = hs.get_http_client()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_association(self, room_alias, room_id, servers):
|
||||
# TODO(erikj): Do auth.
|
||||
|
||||
if not room_alias.is_mine:
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
# TODO(erikj): Change this.
|
||||
|
||||
# TODO(erikj): Add transactions.
|
||||
|
||||
# TODO(erikj): Check if there is a current association.
|
||||
|
||||
yield self.store.create_room_alias_association(
|
||||
room_alias,
|
||||
room_id,
|
||||
servers
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association(self, room_alias, local_only=False):
|
||||
# TODO(erikj): Do auth
|
||||
|
||||
room_id = None
|
||||
if room_alias.is_mine:
|
||||
result = yield self.store.get_association_from_room_alias(
|
||||
room_alias
|
||||
)
|
||||
|
||||
if result:
|
||||
room_id = result.room_id
|
||||
servers = result.servers
|
||||
elif not local_only:
|
||||
path = "%s/ds/room/%s?local_only=1" % (
|
||||
PREFIX,
|
||||
urllib.quote(room_alias.to_string())
|
||||
)
|
||||
|
||||
result = None
|
||||
try:
|
||||
result = yield self.http_client.get_json(
|
||||
destination=room_alias.domain,
|
||||
path=path,
|
||||
)
|
||||
except:
|
||||
# TODO(erikj): Handle this better?
|
||||
logger.exception("Failed to get remote room alias")
|
||||
|
||||
if result and "room_id" in result and "servers" in result:
|
||||
room_id = result["room_id"]
|
||||
servers = result["servers"]
|
||||
|
||||
if not room_id:
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
defer.returnValue({
|
||||
"room_id": room_id,
|
||||
"servers": servers,
|
||||
})
|
||||
return
|
149
synapse/handlers/events.py
Normal file
149
synapse/handlers/events.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.streams.event import (
|
||||
EventStream, MessagesStreamData, RoomMemberStreamData, FeedbackStreamData,
|
||||
RoomDataStreamData
|
||||
)
|
||||
from synapse.handlers.presence import PresenceStreamData
|
||||
|
||||
|
||||
class EventStreamHandler(BaseHandler):
|
||||
|
||||
stream_data_classes = [
|
||||
MessagesStreamData,
|
||||
RoomMemberStreamData,
|
||||
FeedbackStreamData,
|
||||
RoomDataStreamData,
|
||||
PresenceStreamData,
|
||||
]
|
||||
|
||||
def __init__(self, hs):
|
||||
super(EventStreamHandler, self).__init__(hs)
|
||||
|
||||
# Count of active streams per user
|
||||
self._streams_per_user = {}
|
||||
# Grace timers per user to delay the "stopped" signal
|
||||
self._stop_timer_per_user = {}
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("started_user_eventstream")
|
||||
self.distributor.declare("stopped_user_eventstream")
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def get_event_stream_token(self, stream_type, store_id, start_token):
|
||||
"""Return the next token after this event.
|
||||
|
||||
Args:
|
||||
stream_type (str): The StreamData.EVENT_TYPE
|
||||
store_id (int): The new storage ID assigned from the data store.
|
||||
start_token (str): The token the user started with.
|
||||
Returns:
|
||||
str: The end token.
|
||||
"""
|
||||
for i, stream_cls in enumerate(EventStreamHandler.stream_data_classes):
|
||||
if stream_cls.EVENT_TYPE == stream_type:
|
||||
# this is the stream for this event, so replace this part of
|
||||
# the token
|
||||
store_ids = start_token.split(EventStream.SEPARATOR)
|
||||
store_ids[i] = str(store_id)
|
||||
return EventStream.SEPARATOR.join(store_ids)
|
||||
raise RuntimeError("Didn't find a stream type %s" % stream_type)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_stream(self, auth_user_id, pagin_config, timeout=0):
|
||||
"""Gets events as an event stream for this user.
|
||||
|
||||
This function looks for interesting *events* for this user. This is
|
||||
different from the notifier, which looks for interested *users* who may
|
||||
want to know about a single event.
|
||||
|
||||
Args:
|
||||
auth_user_id (str): The user requesting their event stream.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The config to
|
||||
use when obtaining the stream.
|
||||
timeout (int): The max time to wait for an incoming event in ms.
|
||||
Returns:
|
||||
A pagination stream API dict
|
||||
"""
|
||||
auth_user = self.hs.parse_userid(auth_user_id)
|
||||
|
||||
stream_id = object()
|
||||
|
||||
try:
|
||||
if auth_user not in self._streams_per_user:
|
||||
self._streams_per_user[auth_user] = 0
|
||||
if auth_user in self._stop_timer_per_user:
|
||||
self.clock.cancel_call_later(
|
||||
self._stop_timer_per_user.pop(auth_user))
|
||||
else:
|
||||
self.distributor.fire(
|
||||
"started_user_eventstream", auth_user
|
||||
)
|
||||
self._streams_per_user[auth_user] += 1
|
||||
|
||||
# construct an event stream with the correct data ordering
|
||||
stream_data_list = []
|
||||
for stream_class in EventStreamHandler.stream_data_classes:
|
||||
stream_data_list.append(stream_class(self.hs))
|
||||
event_stream = EventStream(auth_user_id, stream_data_list)
|
||||
|
||||
# fix unknown tokens to known tokens
|
||||
pagin_config = yield event_stream.fix_tokens(pagin_config)
|
||||
|
||||
# register interest in receiving new events
|
||||
self.notifier.store_events_for(user_id=auth_user_id,
|
||||
stream_id=stream_id,
|
||||
from_tok=pagin_config.from_tok)
|
||||
|
||||
# see if we can grab a chunk now
|
||||
data_chunk = yield event_stream.get_chunk(config=pagin_config)
|
||||
|
||||
# if there are previous events, return those. If not, wait on the
|
||||
# new events for 'timeout' seconds.
|
||||
if len(data_chunk["chunk"]) == 0 and timeout != 0:
|
||||
results = yield defer.maybeDeferred(
|
||||
self.notifier.get_events_for,
|
||||
user_id=auth_user_id,
|
||||
stream_id=stream_id,
|
||||
timeout=timeout
|
||||
)
|
||||
if results:
|
||||
defer.returnValue(results)
|
||||
|
||||
defer.returnValue(data_chunk)
|
||||
finally:
|
||||
# cleanup
|
||||
self.notifier.purge_events_for(user_id=auth_user_id,
|
||||
stream_id=stream_id)
|
||||
|
||||
self._streams_per_user[auth_user] -= 1
|
||||
if not self._streams_per_user[auth_user]:
|
||||
del self._streams_per_user[auth_user]
|
||||
|
||||
# 10 seconds of grace to allow the client to reconnect again
|
||||
# before we think they're gone
|
||||
def _later():
|
||||
self.distributor.fire(
|
||||
"stopped_user_eventstream", auth_user
|
||||
)
|
||||
del self._stop_timer_per_user[auth_user]
|
||||
|
||||
self._stop_timer_per_user[auth_user] = (
|
||||
self.clock.call_later(5, _later)
|
||||
)
|
74
synapse/handlers/federation.py
Normal file
74
synapse/handlers/federation.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains handlers for federation events."""
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationHandler(BaseHandler):
|
||||
|
||||
"""Handles events that originated from federation."""
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive(self, event, is_new_state):
|
||||
if hasattr(event, "state_key") and not is_new_state:
|
||||
logger.debug("Ignoring old state.")
|
||||
return
|
||||
|
||||
target_is_mine = False
|
||||
if hasattr(event, "target_host"):
|
||||
target_is_mine = event.target_host == self.hs.hostname
|
||||
|
||||
if event.type == InviteJoinEvent.TYPE:
|
||||
if not target_is_mine:
|
||||
logger.debug("Ignoring invite/join event %s", event)
|
||||
return
|
||||
|
||||
# If we receive an invite/join event then we need to join the
|
||||
# sender to the given room.
|
||||
# TODO: We should probably auth this or some such
|
||||
content = event.content
|
||||
content.update({"membership": Membership.JOIN})
|
||||
new_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
target_user_id=event.user_id,
|
||||
room_id=event.room_id,
|
||||
user_id=event.user_id,
|
||||
membership=Membership.JOIN,
|
||||
content=content
|
||||
)
|
||||
|
||||
yield self.hs.get_handlers().room_member_handler.change_membership(
|
||||
new_event,
|
||||
True
|
||||
)
|
||||
|
||||
else:
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
||||
yield self.notifier.on_new_room_event(event, store_id)
|
64
synapse/handlers/login.py
Normal file
64
synapse/handlers/login.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.errors import LoginError
|
||||
|
||||
import bcrypt
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(LoginHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def login(self, user, password):
|
||||
"""Login as the specified user with the specified password.
|
||||
|
||||
Args:
|
||||
user (str): The user ID.
|
||||
password (str): The password.
|
||||
Returns:
|
||||
The newly allocated access token.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
# TODO do this better, it can't go in __init__ else it cyclic loops
|
||||
if not hasattr(self, "reg_handler"):
|
||||
self.reg_handler = self.hs.get_handlers().registration_handler
|
||||
|
||||
# pull out the hash for this user if they exist
|
||||
user_info = yield self.store.get_user_by_id(user_id=user)
|
||||
if not user_info:
|
||||
logger.warn("Attempted to login as %s but they do not exist.", user)
|
||||
raise LoginError(403, "")
|
||||
|
||||
stored_hash = user_info[0]["password_hash"]
|
||||
if bcrypt.checkpw(password, stored_hash):
|
||||
# generate an access token and store it.
|
||||
token = self.reg_handler._generate_token(user)
|
||||
logger.info("Adding token %s for user %s", token, user)
|
||||
yield self.store.add_access_token_to_user(user, token)
|
||||
defer.returnValue(token)
|
||||
else:
|
||||
logger.warn("Failed password login for user %s", user)
|
||||
raise LoginError(403, "")
|
697
synapse/handlers/presence.py
Normal file
697
synapse/handlers/presence.py
Normal file
|
@ -0,0 +1,697 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.api.streams import StreamData
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO(paul): Maybe there's one of these I can steal from somewhere
|
||||
def partition(l, func):
|
||||
"""Partition the list by the result of func applied to each element."""
|
||||
ret = {}
|
||||
|
||||
for x in l:
|
||||
key = func(x)
|
||||
if key not in ret:
|
||||
ret[key] = []
|
||||
ret[key].append(x)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def partitionbool(l, func):
|
||||
def boolfunc(x):
|
||||
return bool(func(x))
|
||||
|
||||
ret = partition(l, boolfunc)
|
||||
return ret.get(True, []), ret.get(False, [])
|
||||
|
||||
|
||||
class PresenceHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PresenceHandler, self).__init__(hs)
|
||||
|
||||
self.homeserver = hs
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
distributor.observe("registered_user", self.registered_user)
|
||||
|
||||
distributor.observe(
|
||||
"started_user_eventstream", self.started_user_eventstream
|
||||
)
|
||||
distributor.observe(
|
||||
"stopped_user_eventstream", self.stopped_user_eventstream
|
||||
)
|
||||
|
||||
distributor.observe("user_joined_room",
|
||||
self.user_joined_room
|
||||
)
|
||||
|
||||
distributor.declare("collect_presencelike_data")
|
||||
|
||||
distributor.declare("changed_presencelike_data")
|
||||
distributor.observe(
|
||||
"changed_presencelike_data", self.changed_presencelike_data
|
||||
)
|
||||
|
||||
self.distributor = distributor
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
|
||||
self.federation.register_edu_handler(
|
||||
"m.presence", self.incoming_presence
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
"m.presence_invite",
|
||||
lambda origin, content: self.invite_presence(
|
||||
observed_user=hs.parse_userid(content["observed_user"]),
|
||||
observer_user=hs.parse_userid(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
"m.presence_accept",
|
||||
lambda origin, content: self.accept_presence(
|
||||
observed_user=hs.parse_userid(content["observed_user"]),
|
||||
observer_user=hs.parse_userid(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
self.federation.register_edu_handler(
|
||||
"m.presence_deny",
|
||||
lambda origin, content: self.deny_presence(
|
||||
observed_user=hs.parse_userid(content["observed_user"]),
|
||||
observer_user=hs.parse_userid(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
|
||||
# IN-MEMORY store, mapping local userparts to sets of local users to
|
||||
# be informed of state changes.
|
||||
self._local_pushmap = {}
|
||||
# map local users to sets of remote /domain names/ who are interested
|
||||
# in them
|
||||
self._remote_sendmap = {}
|
||||
# map remote users to sets of local users who're interested in them
|
||||
self._remote_recvmap = {}
|
||||
|
||||
# map any user to a UserPresenceCache
|
||||
self._user_cachemap = {}
|
||||
self._user_cachemap_latest_serial = 0
|
||||
|
||||
def _get_or_make_usercache(self, user):
|
||||
"""If the cache entry doesn't exist, initialise a new one."""
|
||||
if user not in self._user_cachemap:
|
||||
self._user_cachemap[user] = UserPresenceCache()
|
||||
return self._user_cachemap[user]
|
||||
|
||||
def _get_or_offline_usercache(self, user):
|
||||
"""If the cache entry doesn't exist, return an OFFLINE one but do not
|
||||
store it into the cache."""
|
||||
if user in self._user_cachemap:
|
||||
return self._user_cachemap[user]
|
||||
else:
|
||||
statuscache = UserPresenceCache()
|
||||
statuscache.update({"state": PresenceState.OFFLINE}, user)
|
||||
return statuscache
|
||||
|
||||
def registered_user(self, user):
|
||||
self.store.create_presence(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_presence_visible(self, observer_user, observed_user):
|
||||
assert(observed_user.is_mine)
|
||||
|
||||
if observer_user == observed_user:
|
||||
defer.returnValue(True)
|
||||
|
||||
allowed_by_subscription = yield self.store.is_presence_visible(
|
||||
observed_localpart=observed_user.localpart,
|
||||
observer_userid=observer_user.to_string(),
|
||||
)
|
||||
|
||||
if allowed_by_subscription:
|
||||
defer.returnValue(True)
|
||||
|
||||
# TODO(paul): Check same channel
|
||||
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state(self, target_user, auth_user):
|
||||
if target_user.is_mine:
|
||||
visible = yield self.is_presence_visible(observer_user=auth_user,
|
||||
observed_user=target_user
|
||||
)
|
||||
|
||||
if visible:
|
||||
state = yield self.store.get_presence_state(
|
||||
target_user.localpart
|
||||
)
|
||||
defer.returnValue(state)
|
||||
else:
|
||||
raise SynapseError(404, "Presence information not visible")
|
||||
else:
|
||||
# TODO(paul): Have remote server send us permissions set
|
||||
defer.returnValue(
|
||||
self._get_or_offline_usercache(target_user).get_state()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_state(self, target_user, auth_user, state):
|
||||
if not target_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
# TODO(paul): Sanity-check 'state'
|
||||
if "status_msg" not in state:
|
||||
state["status_msg"] = None
|
||||
|
||||
for k in state.keys():
|
||||
if k not in ("state", "status_msg"):
|
||||
raise SynapseError(
|
||||
400, "Unexpected presence state key '%s'" % (k,)
|
||||
)
|
||||
|
||||
logger.debug("Updating presence state of %s to %s",
|
||||
target_user.localpart, state["state"])
|
||||
|
||||
state_to_store = dict(state)
|
||||
|
||||
yield defer.DeferredList([
|
||||
self.store.set_presence_state(
|
||||
target_user.localpart, state_to_store
|
||||
),
|
||||
self.distributor.fire(
|
||||
"collect_presencelike_data", target_user, state
|
||||
),
|
||||
])
|
||||
|
||||
now_online = state["state"] != PresenceState.OFFLINE
|
||||
was_polling = target_user in self._user_cachemap
|
||||
|
||||
if now_online and not was_polling:
|
||||
self.start_polling_presence(target_user, state=state)
|
||||
elif not now_online and was_polling:
|
||||
self.stop_polling_presence(target_user)
|
||||
|
||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||
# we don't have to do this all the time
|
||||
self.changed_presencelike_data(target_user, state)
|
||||
|
||||
if not now_online:
|
||||
del self._user_cachemap[target_user]
|
||||
|
||||
def changed_presencelike_data(self, user, state):
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
self._user_cachemap_latest_serial += 1
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
self.push_presence(user, statuscache=statuscache)
|
||||
|
||||
def started_user_eventstream(self, user):
|
||||
# TODO(paul): Use "last online" state
|
||||
self.set_state(user, user, {"state": PresenceState.ONLINE})
|
||||
|
||||
def stopped_user_eventstream(self, user):
|
||||
# TODO(paul): Save current state as "last online" state
|
||||
self.set_state(user, user, {"state": PresenceState.OFFLINE})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_joined_room(self, user, room_id):
|
||||
localusers = set()
|
||||
remotedomains = set()
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(room_id,
|
||||
localusers=localusers, remotedomains=remotedomains,
|
||||
ignore_user=user)
|
||||
|
||||
if user.is_mine:
|
||||
yield self._send_presence_to_distribution(srcuser=user,
|
||||
localusers=localusers, remotedomains=remotedomains,
|
||||
statuscache=self._get_or_offline_usercache(user),
|
||||
)
|
||||
|
||||
for srcuser in localusers:
|
||||
yield self._send_presence(srcuser=srcuser, destuser=user,
|
||||
statuscache=self._get_or_offline_usercache(srcuser),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, observer_user, observed_user):
|
||||
if not observer_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.add_presence_list_pending(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
if observed_user.is_mine:
|
||||
yield self.invite_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.federation.send_edu(
|
||||
destination=observed_user.domain,
|
||||
edu_type="m.presence_invite",
|
||||
content={
|
||||
"observed_user": observed_user.to_string(),
|
||||
"observer_user": observer_user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _should_accept_invite(self, observed_user, observer_user):
|
||||
if not observed_user.is_mine:
|
||||
defer.returnValue(False)
|
||||
|
||||
row = yield self.store.has_presence_state(observed_user.localpart)
|
||||
if not row:
|
||||
defer.returnValue(False)
|
||||
|
||||
# TODO(paul): Eventually we'll ask the user's permission for this
|
||||
# before accepting. For now just accept any invite request
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite_presence(self, observed_user, observer_user):
|
||||
accept = yield self._should_accept_invite(observed_user, observer_user)
|
||||
|
||||
if accept:
|
||||
yield self.store.allow_presence_visible(
|
||||
observed_user.localpart, observer_user.to_string()
|
||||
)
|
||||
|
||||
if observer_user.is_mine:
|
||||
if accept:
|
||||
yield self.accept_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.deny_presence(observed_user, observer_user)
|
||||
else:
|
||||
edu_type = "m.presence_accept" if accept else "m.presence_deny"
|
||||
|
||||
yield self.federation.send_edu(
|
||||
destination=observer_user.domain,
|
||||
edu_type=edu_type,
|
||||
content={
|
||||
"observed_user": observed_user.to_string(),
|
||||
"observer_user": observer_user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_presence(self, observed_user, observer_user):
|
||||
yield self.store.set_presence_list_accepted(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.start_polling_presence(observer_user, target_user=observed_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deny_presence(self, observed_user, observer_user):
|
||||
yield self.store.del_presence_list(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
# TODO(paul): Inform the user somehow?
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def drop(self, observed_user, observer_user):
|
||||
if not observer_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.del_presence_list(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.stop_polling_presence(observer_user, target_user=observed_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
if not observer_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
presence = yield self.store.get_presence_list(
|
||||
observer_user.localpart, accepted=accepted
|
||||
)
|
||||
|
||||
for p in presence:
|
||||
observed_user = self.hs.parse_userid(p.pop("observed_user_id"))
|
||||
p["observed_user"] = observed_user
|
||||
p.update(self._get_or_offline_usercache(observed_user).get_state())
|
||||
|
||||
defer.returnValue(presence)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_polling_presence(self, user, target_user=None, state=None):
|
||||
logger.debug("Start polling for presence from %s", user)
|
||||
|
||||
if target_user:
|
||||
target_users = [target_user]
|
||||
else:
|
||||
presence = yield self.store.get_presence_list(
|
||||
user.localpart, accepted=True
|
||||
)
|
||||
target_users = [
|
||||
self.hs.parse_userid(x["observed_user_id"]) for x in presence
|
||||
]
|
||||
|
||||
if state is None:
|
||||
state = yield self.store.get_presence_state(user.localpart)
|
||||
|
||||
localusers, remoteusers = partitionbool(
|
||||
target_users,
|
||||
lambda u: u.is_mine
|
||||
)
|
||||
|
||||
for target_user in localusers:
|
||||
self._start_polling_local(user, target_user)
|
||||
|
||||
deferreds = []
|
||||
remoteusers_by_domain = partition(remoteusers, lambda u: u.domain)
|
||||
for domain in remoteusers_by_domain:
|
||||
remoteusers = remoteusers_by_domain[domain]
|
||||
|
||||
deferreds.append(self._start_polling_remote(
|
||||
user, domain, remoteusers
|
||||
))
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
def _start_polling_local(self, user, target_user):
|
||||
target_localpart = target_user.localpart
|
||||
|
||||
if not self.is_presence_visible(observer_user=user,
|
||||
observed_user=target_user):
|
||||
return
|
||||
|
||||
if target_localpart not in self._local_pushmap:
|
||||
self._local_pushmap[target_localpart] = set()
|
||||
|
||||
self._local_pushmap[target_localpart].add(user)
|
||||
|
||||
self.push_update_to_clients(
|
||||
observer_user=user,
|
||||
observed_user=target_user,
|
||||
statuscache=self._get_or_offline_usercache(target_user),
|
||||
)
|
||||
|
||||
def _start_polling_remote(self, user, domain, remoteusers):
|
||||
for u in remoteusers:
|
||||
if u not in self._remote_recvmap:
|
||||
self._remote_recvmap[u] = set()
|
||||
|
||||
self._remote_recvmap[u].add(user)
|
||||
|
||||
return self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.presence",
|
||||
content={"poll": [u.to_string() for u in remoteusers]}
|
||||
)
|
||||
|
||||
def stop_polling_presence(self, user, target_user=None):
|
||||
logger.debug("Stop polling for presence from %s", user)
|
||||
|
||||
if not target_user or target_user.is_mine:
|
||||
self._stop_polling_local(user, target_user=target_user)
|
||||
|
||||
deferreds = []
|
||||
|
||||
if target_user:
|
||||
raise NotImplementedError("TODO: remove one user")
|
||||
|
||||
remoteusers = [u for u in self._remote_recvmap
|
||||
if user in self._remote_recvmap[u]]
|
||||
remoteusers_by_domain = partition(remoteusers, lambda u: u.domain)
|
||||
|
||||
for domain in remoteusers_by_domain:
|
||||
remoteusers = remoteusers_by_domain[domain]
|
||||
|
||||
deferreds.append(
|
||||
self._stop_polling_remote(user, domain, remoteusers)
|
||||
)
|
||||
|
||||
return defer.DeferredList(deferreds)
|
||||
|
||||
def _stop_polling_local(self, user, target_user):
|
||||
for localpart in self._local_pushmap.keys():
|
||||
if target_user and localpart != target_user.localpart:
|
||||
continue
|
||||
|
||||
if user in self._local_pushmap[localpart]:
|
||||
self._local_pushmap[localpart].remove(user)
|
||||
|
||||
if not self._local_pushmap[localpart]:
|
||||
del self._local_pushmap[localpart]
|
||||
|
||||
def _stop_polling_remote(self, user, domain, remoteusers):
|
||||
for u in remoteusers:
|
||||
self._remote_recvmap[u].remove(user)
|
||||
|
||||
if not self._remote_recvmap[u]:
|
||||
del self._remote_recvmap[u]
|
||||
|
||||
return self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.presence",
|
||||
content={"unpoll": [u.to_string() for u in remoteusers]}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_presence(self, user, statuscache):
|
||||
assert(user.is_mine)
|
||||
|
||||
logger.debug("Pushing presence update from %s", user)
|
||||
|
||||
localusers = set(self._local_pushmap.get(user.localpart, set()))
|
||||
remotedomains = set(self._remote_sendmap.get(user.localpart, set()))
|
||||
|
||||
# Reflect users' status changes back to themselves, so UIs look nice
|
||||
# and also user is informed of server-forced pushes
|
||||
localusers.add(user)
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_rooms_for_user(user)
|
||||
|
||||
for room_id in room_ids:
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers, remotedomains=remotedomains,
|
||||
ignore_user=user,
|
||||
)
|
||||
|
||||
if not localusers and not remotedomains:
|
||||
defer.returnValue(None)
|
||||
|
||||
yield self._send_presence_to_distribution(user,
|
||||
localusers=localusers, remotedomains=remotedomains,
|
||||
statuscache=statuscache
|
||||
)
|
||||
|
||||
def _send_presence(self, srcuser, destuser, statuscache):
|
||||
if destuser.is_mine:
|
||||
self.push_update_to_clients(
|
||||
observer_user=destuser,
|
||||
observed_user=srcuser,
|
||||
statuscache=statuscache)
|
||||
return defer.succeed(None)
|
||||
else:
|
||||
return self._push_presence_remote(srcuser, destuser.domain,
|
||||
state=statuscache.get_state()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_presence_to_distribution(self, srcuser, localusers=set(),
|
||||
remotedomains=set(), statuscache=None):
|
||||
|
||||
for u in localusers:
|
||||
logger.debug(" | push to local user %s", u)
|
||||
self.push_update_to_clients(
|
||||
observer_user=u,
|
||||
observed_user=srcuser,
|
||||
statuscache=statuscache,
|
||||
)
|
||||
|
||||
deferreds = []
|
||||
for domain in remotedomains:
|
||||
logger.debug(" | push to remote domain %s", domain)
|
||||
deferreds.append(self._push_presence_remote(srcuser, domain,
|
||||
state=statuscache.get_state())
|
||||
)
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_presence_remote(self, user, destination, state=None):
|
||||
if state is None:
|
||||
state = yield self.store.get_presence_state(user.localpart)
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", user, state
|
||||
)
|
||||
|
||||
yield self.federation.send_edu(
|
||||
destination=destination,
|
||||
edu_type="m.presence",
|
||||
content={
|
||||
"push": [
|
||||
dict(user_id=user.to_string(), **state),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def incoming_presence(self, origin, content):
|
||||
deferreds = []
|
||||
|
||||
for push in content.get("push", []):
|
||||
user = self.hs.parse_userid(push["user_id"])
|
||||
|
||||
logger.debug("Incoming presence update from %s", user)
|
||||
|
||||
observers = set(self._remote_recvmap.get(user, set()))
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_rooms_for_user(user)
|
||||
|
||||
for room_id in room_ids:
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=observers, ignore_user=user
|
||||
)
|
||||
|
||||
if not observers:
|
||||
break
|
||||
|
||||
state = dict(push)
|
||||
del state["user_id"]
|
||||
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
self._user_cachemap_latest_serial += 1
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
for observer_user in observers:
|
||||
self.push_update_to_clients(
|
||||
observer_user=observer_user,
|
||||
observed_user=user,
|
||||
statuscache=statuscache,
|
||||
)
|
||||
|
||||
if state["state"] == PresenceState.OFFLINE:
|
||||
del self._user_cachemap[user]
|
||||
|
||||
for poll in content.get("poll", []):
|
||||
user = self.hs.parse_userid(poll)
|
||||
|
||||
if not user.is_mine:
|
||||
continue
|
||||
|
||||
# TODO(paul) permissions checks
|
||||
|
||||
if not user in self._remote_sendmap:
|
||||
self._remote_sendmap[user] = set()
|
||||
|
||||
self._remote_sendmap[user].add(origin)
|
||||
|
||||
deferreds.append(self._push_presence_remote(user, origin))
|
||||
|
||||
for unpoll in content.get("unpoll", []):
|
||||
user = self.hs.parse_userid(unpoll)
|
||||
|
||||
if not user.is_mine:
|
||||
continue
|
||||
|
||||
if user in self._remote_sendmap:
|
||||
self._remote_sendmap[user].remove(origin)
|
||||
|
||||
if not self._remote_sendmap[user]:
|
||||
del self._remote_sendmap[user]
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
def push_update_to_clients(self, observer_user, observed_user,
|
||||
statuscache):
|
||||
self.notifier.on_new_user_event(
|
||||
observer_user.to_string(),
|
||||
event_data=statuscache.make_event(user=observed_user),
|
||||
stream_type=PresenceStreamData,
|
||||
store_id=statuscache.serial
|
||||
)
|
||||
|
||||
|
||||
class PresenceStreamData(StreamData):
|
||||
def __init__(self, hs):
|
||||
super(PresenceStreamData, self).__init__(hs)
|
||||
self.presence = hs.get_handlers().presence_handler
|
||||
|
||||
def get_rows(self, user_id, from_key, to_key, limit):
|
||||
cachemap = self.presence._user_cachemap
|
||||
|
||||
# TODO(paul): limit, and filter by visibility
|
||||
updates = [(k, cachemap[k]) for k in cachemap
|
||||
if from_key < cachemap[k].serial <= to_key]
|
||||
|
||||
if updates:
|
||||
latest_serial = max([x[1].serial for x in updates])
|
||||
data = [x[1].make_event(user=x[0]) for x in updates]
|
||||
return ((data, latest_serial))
|
||||
else:
|
||||
return (([], self.presence._user_cachemap_latest_serial))
|
||||
|
||||
def max_token(self):
|
||||
return self.presence._user_cachemap_latest_serial
|
||||
|
||||
PresenceStreamData.EVENT_TYPE = PresenceStreamData
|
||||
|
||||
|
||||
class UserPresenceCache(object):
|
||||
"""Store an observed user's state and status message.
|
||||
|
||||
Includes the update timestamp.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.state = {}
|
||||
self.serial = None
|
||||
|
||||
def update(self, state, serial):
|
||||
self.state.update(state)
|
||||
# Delete keys that are now 'None'
|
||||
for k in self.state.keys():
|
||||
if self.state[k] is None:
|
||||
del self.state[k]
|
||||
|
||||
self.serial = serial
|
||||
|
||||
if "status_msg" in state:
|
||||
self.status_msg = state["status_msg"]
|
||||
else:
|
||||
self.status_msg = None
|
||||
|
||||
def get_state(self):
|
||||
# clone it so caller can't break our cache
|
||||
return dict(self.state)
|
||||
|
||||
def make_event(self, user):
|
||||
content = self.get_state()
|
||||
content["user_id"] = user.to_string()
|
||||
|
||||
return {"type": "m.presence", "content": content}
|
169
synapse/handlers/profile.py
Normal file
169
synapse/handlers/profile.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
|
||||
from synapse.api.errors import CodeMessageException
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PREFIX = "/matrix/client/api/v1"
|
||||
|
||||
|
||||
class ProfileHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileHandler, self).__init__(hs)
|
||||
|
||||
self.client = hs.get_http_client()
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
self.distributor = distributor
|
||||
|
||||
distributor.observe("registered_user", self.registered_user)
|
||||
|
||||
distributor.observe(
|
||||
"collect_presencelike_data", self.collect_presencelike_data
|
||||
)
|
||||
|
||||
def registered_user(self, user):
|
||||
self.store.create_profile(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user, local_only=False):
|
||||
if target_user.is_mine:
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue(displayname)
|
||||
elif not local_only:
|
||||
# TODO(paul): This should use the server-server API to ask another
|
||||
# HS. For now we'll just have it use the http client to talk to the
|
||||
# other HS's REST client API
|
||||
path = PREFIX + "/profile/%s/displayname?local_only=1" % (
|
||||
target_user.to_string()
|
||||
)
|
||||
|
||||
try:
|
||||
result = yield self.client.get_json(
|
||||
destination=target_user.domain,
|
||||
path=path
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
logger.exception("Failed to get displayname")
|
||||
|
||||
raise
|
||||
except:
|
||||
logger.exception("Failed to get displayname")
|
||||
|
||||
defer.returnValue(result["displayname"])
|
||||
else:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_displayname(self, target_user, auth_user, new_displayname):
|
||||
"""target_user is the user whose displayname is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not target_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
yield self.store.set_profile_displayname(
|
||||
target_user.localpart, new_displayname
|
||||
)
|
||||
|
||||
yield self.distributor.fire(
|
||||
"changed_presencelike_data", target_user, {
|
||||
"displayname": new_displayname,
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_avatar_url(self, target_user, local_only=False):
|
||||
if target_user.is_mine:
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue(avatar_url)
|
||||
elif not local_only:
|
||||
# TODO(paul): This should use the server-server API to ask another
|
||||
# HS. For now we'll just have it use the http client to talk to the
|
||||
# other HS's REST client API
|
||||
destination = target_user.domain
|
||||
path = PREFIX + "/profile/%s/avatar_url?local_only=1" % (
|
||||
target_user.to_string(),
|
||||
)
|
||||
|
||||
try:
|
||||
result = yield self.client.get_json(
|
||||
destination=destination,
|
||||
path=path
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
logger.exception("Failed to get avatar_url")
|
||||
raise
|
||||
except:
|
||||
logger.exception("Failed to get avatar_url")
|
||||
|
||||
defer.returnValue(result["avatar_url"])
|
||||
else:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_avatar_url(self, target_user, auth_user, new_avatar_url):
|
||||
"""target_user is the user whose avatar_url is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not target_user.is_mine:
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
target_user.localpart, new_avatar_url
|
||||
)
|
||||
|
||||
yield self.distributor.fire(
|
||||
"changed_presencelike_data", target_user, {
|
||||
"avatar_url": new_avatar_url,
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def collect_presencelike_data(self, user, state):
|
||||
if not user.is_mine:
|
||||
defer.returnValue(None)
|
||||
|
||||
(displayname, avatar_url) = yield defer.gatherResults([
|
||||
self.store.get_profile_displayname(user.localpart),
|
||||
self.store.get_profile_avatar_url(user.localpart),
|
||||
])
|
||||
|
||||
state["displayname"] = displayname
|
||||
state["avatar_url"] = avatar_url
|
||||
|
||||
defer.returnValue(None)
|
100
synapse/handlers/register.py
Normal file
100
synapse/handlers/register.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains functions for registering clients."""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import SynapseError, RegistrationError
|
||||
from ._base import BaseHandler
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
import base64
|
||||
import bcrypt
|
||||
|
||||
|
||||
class RegistrationHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("registered_user")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, localpart=None, password=None):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be randomly generated.
|
||||
password (str) : The password to assign to this user so they can
|
||||
login again.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||
|
||||
if localpart:
|
||||
user = UserID(localpart, self.hs.hostname, True)
|
||||
user_id = user.to_string()
|
||||
|
||||
token = self._generate_token(user_id)
|
||||
yield self.store.register(user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash)
|
||||
|
||||
self.distributor.fire("registered_user", user)
|
||||
defer.returnValue((user_id, token))
|
||||
else:
|
||||
# autogen a random user ID
|
||||
attempts = 0
|
||||
user_id = None
|
||||
token = None
|
||||
while not user_id and not token:
|
||||
try:
|
||||
localpart = self._generate_user_id()
|
||||
user = UserID(localpart, self.hs.hostname, True)
|
||||
user_id = user.to_string()
|
||||
|
||||
token = self._generate_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash)
|
||||
|
||||
self.distributor.fire("registered_user", user)
|
||||
defer.returnValue((user_id, token))
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
user_id = None
|
||||
token = None
|
||||
attempts += 1
|
||||
if attempts > 5:
|
||||
raise RegistrationError(
|
||||
500, "Cannot generate user ID.")
|
||||
|
||||
def _generate_token(self, user_id):
|
||||
# urlsafe variant uses _ and - so use . as the separator and replace
|
||||
# all =s with .s so http clients don't quote =s when it is used as
|
||||
# query params.
|
||||
return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' +
|
||||
stringutils.random_string(18))
|
||||
|
||||
def _generate_user_id(self):
|
||||
return "-" + stringutils.random_string(18)
|
808
synapse/handlers/room.py
Normal file
808
synapse/handlers/room.py
Normal file
|
@ -0,0 +1,808 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains functions for performing events on rooms."""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID, RoomAlias, RoomID
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import RoomError, StoreError, SynapseError
|
||||
from synapse.api.events.room import (
|
||||
RoomTopicEvent, MessageEvent, InviteJoinEvent, RoomMemberEvent,
|
||||
RoomConfigEvent
|
||||
)
|
||||
from synapse.api.streams.event import EventStream, MessagesStreamData
|
||||
from synapse.util import stringutils
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(MessageHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_message(self, msg_id=None, room_id=None, sender_id=None,
|
||||
user_id=None):
|
||||
""" Retrieve a message.
|
||||
|
||||
Args:
|
||||
msg_id (str): The message ID to obtain.
|
||||
room_id (str): The room where the message resides.
|
||||
sender_id (str): The user ID of the user who sent the message.
|
||||
user_id (str): The user ID of the user making this request.
|
||||
Returns:
|
||||
The message, or None if no message exists.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
# Pull out the message from the db
|
||||
msg = yield self.store.get_message(room_id=room_id,
|
||||
msg_id=msg_id,
|
||||
user_id=sender_id)
|
||||
|
||||
if msg:
|
||||
defer.returnValue(msg)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_message(self, event=None, suppress_auth=False, stamp_event=True):
|
||||
""" Send a message.
|
||||
|
||||
Args:
|
||||
event : The message event to store.
|
||||
suppress_auth (bool) : True to suppress auth for this message. This
|
||||
is primarily so the home server can inject messages into rooms at
|
||||
will.
|
||||
stamp_event (bool) : True to stamp event content with server keys.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
if not suppress_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
|
||||
# store message in db
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
||||
event.destinations = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
feedback=False):
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
user_id (str): The user requesting messages.
|
||||
room_id (str): The room they want messages from.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
feedback (bool): True to get compressed feedback with the messages
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
"""
|
||||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
data_source = [MessagesStreamData(self.hs, room_id=room_id,
|
||||
feedback=feedback)]
|
||||
event_stream = EventStream(user_id, data_source)
|
||||
pagin_config = yield event_stream.fix_tokens(pagin_config)
|
||||
data_chunk = yield event_stream.get_chunk(config=pagin_config)
|
||||
defer.returnValue(data_chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room_data(self, event=None, stamp_event=True):
|
||||
""" Stores data for a room.
|
||||
|
||||
Args:
|
||||
event : The room path event
|
||||
stamp_event (bool) : True to stamp event content with server keys.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
yield self.auth.check(event, raises=True)
|
||||
|
||||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
||||
yield self.state_handler.handle_new_event(event)
|
||||
|
||||
# store in db
|
||||
store_id = yield self.store.store_room_data(
|
||||
room_id=event.room_id,
|
||||
etype=event.type,
|
||||
state_key=event.state_key,
|
||||
content=json.dumps(event.content)
|
||||
)
|
||||
|
||||
event.destinations = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
event_type=None, state_key="",
|
||||
public_room_rules=[],
|
||||
private_room_rules=["join"]):
|
||||
""" Get data from a room.
|
||||
|
||||
Args:
|
||||
event : The room path event
|
||||
public_room_rules : A list of membership states the user can be in,
|
||||
in order to read this data IN A PUBLIC ROOM. An empty list means
|
||||
'any state'.
|
||||
private_room_rules : A list of membership states the user can be
|
||||
in, in order to read this data IN A PRIVATE ROOM. An empty list
|
||||
means 'any state'.
|
||||
Returns:
|
||||
The path data content.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
if event_type == RoomTopicEvent.TYPE:
|
||||
# anyone invited/joined can read the topic
|
||||
private_room_rules = ["invite", "join"]
|
||||
|
||||
# does this room exist
|
||||
room = yield self.store.get_room(room_id)
|
||||
if not room:
|
||||
raise RoomError(403, "Room does not exist.")
|
||||
|
||||
# does this user exist in this room
|
||||
member = yield self.store.get_room_member(
|
||||
room_id=room_id,
|
||||
user_id="" if not user_id else user_id)
|
||||
|
||||
member_state = member.membership if member else None
|
||||
|
||||
if room.is_public and public_room_rules:
|
||||
# make sure the user meets public room rules
|
||||
if member_state not in public_room_rules:
|
||||
raise RoomError(403, "Member does not meet public room rules.")
|
||||
elif not room.is_public and private_room_rules:
|
||||
# make sure the user meets private room rules
|
||||
if member_state not in private_room_rules:
|
||||
raise RoomError(
|
||||
403, "Member does not meet private room rules.")
|
||||
|
||||
data = yield self.store.get_room_data(room_id, event_type, state_key)
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_feedback(self, room_id=None, msg_sender_id=None, msg_id=None,
|
||||
user_id=None, fb_sender_id=None, fb_type=None):
|
||||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
# Pull out the feedback from the db
|
||||
fb = yield self.store.get_feedback(
|
||||
room_id=room_id, msg_id=msg_id, msg_sender_id=msg_sender_id,
|
||||
fb_sender_id=fb_sender_id, fb_type=fb_type
|
||||
)
|
||||
|
||||
if fb:
|
||||
defer.returnValue(fb)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_feedback(self, event, stamp_event=True):
|
||||
if stamp_event:
|
||||
event.content["hsob_ts"] = int(self.clock.time_msec())
|
||||
|
||||
with (yield self.room_lock.lock(event.room_id)):
|
||||
yield self.auth.check(event, raises=True)
|
||||
|
||||
# store message in db
|
||||
store_id = yield self.store.persist_event(event)
|
||||
|
||||
event.destinations = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
|
||||
feedback=False):
|
||||
"""Retrieve a snapshot of all rooms the user is invited or has joined.
|
||||
|
||||
This snapshot may include messages for all rooms where the user is
|
||||
joined, depending on the pagination config.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user making the request.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config used to determine how many messages *PER ROOM* to return.
|
||||
feedback (bool): True to get feedback along with these messages.
|
||||
Returns:
|
||||
A list of dicts with "room_id" and "membership" keys for all rooms
|
||||
the user is currently invited or joined in on. Rooms where the user
|
||||
is joined on, may return a "messages" key with messages, depending
|
||||
on the specified PaginationConfig.
|
||||
"""
|
||||
room_list = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user_id,
|
||||
membership_list=[Membership.INVITE, Membership.JOIN]
|
||||
)
|
||||
for room_info in room_list:
|
||||
if room_info["membership"] != Membership.JOIN:
|
||||
continue
|
||||
try:
|
||||
event_chunk = yield self.get_messages(
|
||||
user_id=user_id,
|
||||
pagin_config=pagin_config,
|
||||
feedback=feedback,
|
||||
room_id=room_info["room_id"]
|
||||
)
|
||||
room_info["messages"] = event_chunk
|
||||
except:
|
||||
pass
|
||||
defer.returnValue(room_list)
|
||||
|
||||
|
||||
class RoomCreationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, user_id, room_id, config):
|
||||
""" Creates a new room.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user creating the new room.
|
||||
room_id (str): The proposed ID for the new room. Can be None, in
|
||||
which case one will be created for you.
|
||||
config (dict) : A dict of configuration options.
|
||||
Returns:
|
||||
The new room ID.
|
||||
Raises:
|
||||
SynapseError if the room ID was taken, couldn't be stored, or
|
||||
something went horribly wrong.
|
||||
"""
|
||||
|
||||
if "room_alias_name" in config:
|
||||
room_alias = RoomAlias.create_local(
|
||||
config["room_alias_name"],
|
||||
self.hs
|
||||
)
|
||||
mapping = yield self.store.get_association_from_room_alias(
|
||||
room_alias
|
||||
)
|
||||
|
||||
if mapping:
|
||||
raise SynapseError(400, "Room alias already taken")
|
||||
else:
|
||||
room_alias = None
|
||||
|
||||
if room_id:
|
||||
# Ensure room_id is the correct type
|
||||
room_id_obj = RoomID.from_string(room_id, self.hs)
|
||||
if not room_id_obj.is_mine:
|
||||
raise SynapseError(400, "Room id must be local")
|
||||
|
||||
yield self.store.store_room(
|
||||
room_id=room_id,
|
||||
room_creator_user_id=user_id,
|
||||
is_public=config["visibility"] == "public"
|
||||
)
|
||||
else:
|
||||
# autogen room IDs and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
attempts = 0
|
||||
room_id = None
|
||||
while attempts < 5:
|
||||
try:
|
||||
random_string = stringutils.random_string(18)
|
||||
gen_room_id = RoomID.create_local(random_string, self.hs)
|
||||
yield self.store.store_room(
|
||||
room_id=gen_room_id.to_string(),
|
||||
room_creator_user_id=user_id,
|
||||
is_public=config["visibility"] == "public"
|
||||
)
|
||||
room_id = gen_room_id.to_string()
|
||||
break
|
||||
except StoreError:
|
||||
attempts += 1
|
||||
if not room_id:
|
||||
raise StoreError(500, "Couldn't generate a room ID.")
|
||||
|
||||
config_event = self.event_factory.create_event(
|
||||
etype=RoomConfigEvent.TYPE,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
content=config,
|
||||
)
|
||||
|
||||
if room_alias:
|
||||
yield self.store.create_room_alias_association(
|
||||
room_id=room_id,
|
||||
room_alias=room_alias,
|
||||
servers=[self.hs.hostname],
|
||||
)
|
||||
|
||||
yield self.state_handler.handle_new_event(config_event)
|
||||
# store_id = persist...
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(config_event)
|
||||
# self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
content = {"membership": Membership.JOIN}
|
||||
join_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
target_user_id=user_id,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
membership=Membership.JOIN,
|
||||
content=content
|
||||
)
|
||||
|
||||
yield self.hs.get_handlers().room_member_handler.change_membership(
|
||||
join_event,
|
||||
broadcast_msg=True,
|
||||
do_auth=False
|
||||
)
|
||||
|
||||
result = {"room_id": room_id}
|
||||
if room_alias:
|
||||
result["room_alias"] = room_alias.to_string()
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
|
||||
class RoomMemberHandler(BaseHandler):
|
||||
# TODO(paul): This handler currently contains a messy conflation of
|
||||
# low-level API that works on UserID objects and so on, and REST-level
|
||||
# API that takes ID strings and returns pagination chunks. These concerns
|
||||
# ought to be separated out a lot better.
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberHandler, self).__init__(hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("user_joined_room")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members(self, room_id, membership=Membership.JOIN):
|
||||
hs = self.hs
|
||||
|
||||
memberships = yield self.store.get_room_members(
|
||||
room_id=room_id, membership=membership
|
||||
)
|
||||
|
||||
defer.returnValue([hs.parse_userid(m.user_id) for m in memberships])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_room_distributions_into(self, room_id, localusers=None,
|
||||
remotedomains=None, ignore_user=None):
|
||||
"""Fetch the distribution of a room, adding elements to either
|
||||
'localusers' or 'remotedomains', which should be a set() if supplied.
|
||||
If ignore_user is set, ignore that user.
|
||||
|
||||
This function returns nothing; its result is performed by the
|
||||
side-effect on the two passed sets. This allows easy accumulation of
|
||||
member lists of multiple rooms at once if required.
|
||||
"""
|
||||
members = yield self.get_room_members(room_id)
|
||||
for member in members:
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if member.is_mine:
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
if remotedomains is not None:
|
||||
remotedomains.add(member.domain)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
|
||||
limit=0, start_tok=None,
|
||||
end_tok=None):
|
||||
"""Retrieve a list of room members in the room.
|
||||
|
||||
Args:
|
||||
room_id (str): The room to get the member list for.
|
||||
user_id (str): The ID of the user making the request.
|
||||
limit (int): The max number of members to return.
|
||||
start_tok (str): Optional. The start token if known.
|
||||
end_tok (str): Optional. The end token if known.
|
||||
Returns:
|
||||
dict: A Pagination streamable dict.
|
||||
Raises:
|
||||
SynapseError if something goes wrong.
|
||||
"""
|
||||
yield self.auth.check_joined_room(room_id, user_id)
|
||||
|
||||
member_list = yield self.store.get_room_members(room_id=room_id)
|
||||
event_list = [
|
||||
entry.as_event(self.event_factory).get_dict()
|
||||
for entry in member_list
|
||||
]
|
||||
chunk_data = {
|
||||
"start": "START",
|
||||
"end": "END",
|
||||
"chunk": event_list
|
||||
}
|
||||
# TODO honor Pagination stream params
|
||||
# TODO snapshot this list to return on subsequent requests when
|
||||
# paginating
|
||||
defer.returnValue(chunk_data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_member(self, room_id, member_user_id, auth_user_id):
|
||||
"""Retrieve a room member from a room.
|
||||
|
||||
Args:
|
||||
room_id : The room the member is in.
|
||||
member_user_id : The member's user ID
|
||||
auth_user_id : The user ID of the user making this request.
|
||||
Returns:
|
||||
The room member, or None if this member does not exist.
|
||||
Raises:
|
||||
SynapseError if something goes wrong.
|
||||
"""
|
||||
yield self.auth.check_joined_room(room_id, auth_user_id)
|
||||
|
||||
member = yield self.store.get_room_member(user_id=member_user_id,
|
||||
room_id=room_id)
|
||||
defer.returnValue(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def change_membership(self, event=None, broadcast_msg=False, do_auth=True):
|
||||
""" Change the membership status of a user in a room.
|
||||
|
||||
Args:
|
||||
event (SynapseEvent): The membership event
|
||||
broadcast_msg (bool): True to inject a membership message into this
|
||||
room on success.
|
||||
Raises:
|
||||
SynapseError if there was a problem changing the membership.
|
||||
"""
|
||||
|
||||
#broadcast_msg = False
|
||||
|
||||
prev_state = yield self.store.get_room_member(
|
||||
event.target_user_id, event.room_id
|
||||
)
|
||||
|
||||
if prev_state and prev_state.membership == event.membership:
|
||||
# treat this event as a NOOP.
|
||||
if do_auth: # This is mainly to fix a unit test.
|
||||
yield self.auth.check(event, raises=True)
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
room_id = event.room_id
|
||||
|
||||
# If we're trying to join a room then we have to do this differently
|
||||
# if this HS is not currently in the room, i.e. we have to do the
|
||||
# invite/join dance.
|
||||
if event.membership == Membership.JOIN:
|
||||
yield self._do_join(
|
||||
event, do_auth=do_auth, broadcast_msg=broadcast_msg
|
||||
)
|
||||
else:
|
||||
# This is not a JOIN, so we can handle it normally.
|
||||
if do_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
|
||||
prev_state = yield self.store.get_room_member(
|
||||
event.target_user_id, event.room_id
|
||||
)
|
||||
if prev_state and prev_state.membership == event.membership:
|
||||
# double same action, treat this event as a NOOP.
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
yield self.state_handler.handle_new_event(event)
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
broadcast_msg=broadcast_msg,
|
||||
)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def join_room_alias(self, joinee, room_alias, do_auth=True, content={}):
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
|
||||
if not mapping:
|
||||
raise SynapseError(404, "No such room alias")
|
||||
|
||||
room_id = mapping["room_id"]
|
||||
hosts = mapping["servers"]
|
||||
if not hosts:
|
||||
raise SynapseError(404, "No known servers")
|
||||
|
||||
host = hosts[0]
|
||||
|
||||
content.update({"membership": Membership.JOIN})
|
||||
new_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
target_user_id=joinee.to_string(),
|
||||
room_id=room_id,
|
||||
user_id=joinee.to_string(),
|
||||
membership=Membership.JOIN,
|
||||
content=content,
|
||||
)
|
||||
|
||||
yield self._do_join(new_event, room_host=host, do_auth=True)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_join(self, event, room_host=None, do_auth=True, broadcast_msg=True):
|
||||
joinee = self.hs.parse_userid(event.target_user_id)
|
||||
# room_id = RoomID.from_string(event.room_id, self.hs)
|
||||
room_id = event.room_id
|
||||
|
||||
# If event doesn't include a display name, add one.
|
||||
yield self._fill_out_join_content(
|
||||
joinee, event.content
|
||||
)
|
||||
|
||||
# XXX: We don't do an auth check if we are doing an invite
|
||||
# join dance for now, since we're kinda implicitly checking
|
||||
# that we are allowed to join when we decide whether or not we
|
||||
# need to do the invite/join dance.
|
||||
|
||||
room = yield self.store.get_room(room_id)
|
||||
|
||||
if room:
|
||||
should_do_dance = False
|
||||
elif room_host:
|
||||
should_do_dance = True
|
||||
else:
|
||||
prev_state = yield self.store.get_room_member(
|
||||
joinee.to_string(), room_id
|
||||
)
|
||||
|
||||
if prev_state and prev_state.membership == Membership.INVITE:
|
||||
room = yield self.store.get_room(room_id)
|
||||
inviter = UserID.from_string(
|
||||
prev_state.sender, self.hs
|
||||
)
|
||||
|
||||
should_do_dance = not inviter.is_mine and not room
|
||||
room_host = inviter.domain
|
||||
else:
|
||||
should_do_dance = False
|
||||
|
||||
# We want to do the _do_update inside the room lock.
|
||||
if not should_do_dance:
|
||||
logger.debug("Doing normal join")
|
||||
|
||||
if do_auth:
|
||||
yield self.auth.check(event, raises=True)
|
||||
|
||||
yield self.state_handler.handle_new_event(event)
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
broadcast_msg=broadcast_msg,
|
||||
)
|
||||
|
||||
|
||||
if should_do_dance:
|
||||
yield self._do_invite_join_dance(
|
||||
room_id=room_id,
|
||||
joinee=event.user_id,
|
||||
target_host=room_host,
|
||||
content=event.content,
|
||||
)
|
||||
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
self.distributor.fire(
|
||||
"user_joined_room", user=user, room_id=room_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_join_content(self, user_id, content):
|
||||
# If event doesn't include a display name, add one.
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
if "displayname" not in content:
|
||||
try:
|
||||
display_name = yield profile_handler.get_displayname(
|
||||
user_id
|
||||
)
|
||||
|
||||
if display_name:
|
||||
content["displayname"] = display_name
|
||||
except:
|
||||
logger.exception("Failed to set display_name")
|
||||
|
||||
if "avatar_url" not in content:
|
||||
try:
|
||||
avatar_url = yield profile_handler.get_avatar_url(
|
||||
user_id
|
||||
)
|
||||
|
||||
if avatar_url:
|
||||
content["avatar_url"] = avatar_url
|
||||
except:
|
||||
logger.exception("Failed to set display_name")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _should_invite_join(self, room_id, prev_state, do_auth):
|
||||
logger.debug("_should_invite_join: room_id: %s", room_id)
|
||||
|
||||
# XXX: We don't do an auth check if we are doing an invite
|
||||
# join dance for now, since we're kinda implicitly checking
|
||||
# that we are allowed to join when we decide whether or not we
|
||||
# need to do the invite/join dance.
|
||||
|
||||
# Only do an invite join dance if a) we were invited,
|
||||
# b) the person inviting was from a differnt HS and c) we are
|
||||
# not currently in the room
|
||||
room_host = None
|
||||
if prev_state and prev_state.membership == Membership.INVITE:
|
||||
room = yield self.store.get_room(room_id)
|
||||
inviter = UserID.from_string(
|
||||
prev_state.sender, self.hs
|
||||
)
|
||||
|
||||
is_remote_invite_join = not inviter.is_mine and not room
|
||||
room_host = inviter.domain
|
||||
else:
|
||||
is_remote_invite_join = False
|
||||
|
||||
defer.returnValue((is_remote_invite_join, room_host))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
|
||||
"""Returns a list of roomids that the user has any of the given
|
||||
membership states in."""
|
||||
rooms = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user.to_string(), membership_list=membership_list
|
||||
)
|
||||
|
||||
defer.returnValue([r["room_id"] for r in rooms])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_local_membership_update(self, event, membership, broadcast_msg):
|
||||
# store membership
|
||||
store_id = yield self.store.store_room_member(
|
||||
user_id=event.target_user_id,
|
||||
sender=event.user_id,
|
||||
room_id=event.room_id,
|
||||
content=event.content,
|
||||
membership=membership
|
||||
)
|
||||
|
||||
# Send a PDU to all hosts who have joined the room.
|
||||
destinations = yield self.store.get_joined_hosts_for_room(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
# If we're inviting someone, then we should also send it to that
|
||||
# HS.
|
||||
if membership == Membership.INVITE:
|
||||
host = UserID.from_string(
|
||||
event.target_user_id, self.hs
|
||||
).domain
|
||||
destinations.append(host)
|
||||
|
||||
# If we are joining a remote HS, include that.
|
||||
if membership == Membership.JOIN:
|
||||
host = UserID.from_string(
|
||||
event.target_user_id, self.hs
|
||||
).domain
|
||||
destinations.append(host)
|
||||
|
||||
event.destinations = list(set(destinations))
|
||||
|
||||
yield self.hs.get_federation().handle_new_event(event)
|
||||
self.notifier.on_new_room_event(event, store_id)
|
||||
|
||||
if broadcast_msg:
|
||||
yield self._inject_membership_msg(
|
||||
source=event.user_id,
|
||||
target=event.target_user_id,
|
||||
room_id=event.room_id,
|
||||
membership=event.content["membership"]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_invite_join_dance(self, room_id, joinee, target_host, content):
|
||||
logger.debug("Doing remote join dance")
|
||||
|
||||
# do invite join dance
|
||||
federation = self.hs.get_federation()
|
||||
new_event = self.event_factory.create_event(
|
||||
etype=InviteJoinEvent.TYPE,
|
||||
target_host=target_host,
|
||||
room_id=room_id,
|
||||
user_id=joinee,
|
||||
content=content
|
||||
)
|
||||
|
||||
new_event.destinations = [target_host]
|
||||
|
||||
yield self.store.store_room(
|
||||
room_id, "", is_public=False
|
||||
)
|
||||
|
||||
#yield self.state_handler.handle_new_event(event)
|
||||
yield federation.handle_new_event(new_event)
|
||||
yield federation.get_state_for_room(
|
||||
target_host, room_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _inject_membership_msg(self, room_id=None, source=None, target=None,
|
||||
membership=None):
|
||||
# TODO this should be a different type of message, not m.text
|
||||
if membership == Membership.INVITE:
|
||||
body = "%s invited %s to the room." % (source, target)
|
||||
elif membership == Membership.JOIN:
|
||||
body = "%s joined the room." % (target)
|
||||
elif membership == Membership.LEAVE:
|
||||
body = "%s left the room." % (target)
|
||||
else:
|
||||
raise RoomError(500, "Unknown membership value %s" % membership)
|
||||
|
||||
membership_json = {
|
||||
"msgtype": u"m.text",
|
||||
"body": body,
|
||||
"membership_source": source,
|
||||
"membership_target": target,
|
||||
"membership": membership,
|
||||
}
|
||||
|
||||
msg_id = "m%s" % int(self.clock.time_msec())
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=MessageEvent.TYPE,
|
||||
room_id=room_id,
|
||||
user_id="_homeserver_",
|
||||
msg_id=msg_id,
|
||||
content=membership_json
|
||||
)
|
||||
|
||||
handler = self.hs.get_handlers().message_handler
|
||||
yield handler.send_message(event, suppress_auth=True)
|
||||
|
||||
|
||||
class RoomListHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_public_room_list(self):
|
||||
chunk = yield self.store.get_rooms(is_public=True, with_topics=True)
|
||||
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
|
14
synapse/http/__init__.py
Normal file
14
synapse/http/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
246
synapse/http/client.py
Normal file
246
synapse/http/client.py
Normal file
|
@ -0,0 +1,246 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web.client import _AgentBase, _URI, readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.http.endpoint import matrix_endpoint
|
||||
from synapse.util.async import sleep
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import CodeMessageException
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_destination_mappings = {
|
||||
"red": "localhost:8080",
|
||||
"blue": "localhost:8081",
|
||||
"green": "localhost:8082",
|
||||
}
|
||||
|
||||
|
||||
class HttpClient(object):
|
||||
""" Interface for talking json over http
|
||||
"""
|
||||
|
||||
def put_json(self, destination, path, data):
|
||||
""" Sends the specifed json data using PUT
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
data (dict): A dict containing the data that will be used as
|
||||
the request body. This will be encoded as JSON.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_json(self, destination, path, args=None):
|
||||
""" Get's some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
The result of the deferred is a tuple of `(code, response)`,
|
||||
where `response` is a dict representing the decoded JSON body.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MatrixHttpAgent(_AgentBase):
|
||||
|
||||
def __init__(self, reactor, pool=None):
|
||||
_AgentBase.__init__(self, reactor, pool)
|
||||
|
||||
def request(self, destination, endpoint, method, path, params, query,
|
||||
headers, body_producer):
|
||||
|
||||
host = b""
|
||||
port = 0
|
||||
fragment = b""
|
||||
|
||||
parsed_URI = _URI(b"http", destination, host, port, path, params,
|
||||
query, fragment)
|
||||
|
||||
# Set the connection pool key to be the destination.
|
||||
key = destination
|
||||
|
||||
return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
|
||||
headers, body_producer,
|
||||
parsed_URI.originForm)
|
||||
|
||||
|
||||
class TwistedHttpClient(HttpClient):
|
||||
""" Wrapper around the twisted HTTP client api.
|
||||
|
||||
Attributes:
|
||||
agent (twisted.web.client.Agent): The twisted Agent used to send the
|
||||
requests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent = MatrixHttpAgent(reactor)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def put_json(self, destination, path, data):
|
||||
if destination in _destination_mappings:
|
||||
destination = _destination_mappings[destination]
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"PUT",
|
||||
path.encode("ascii"),
|
||||
producer=_JsonProducer(data),
|
||||
headers_dict={"Content-Type": ["application/json"]}
|
||||
)
|
||||
|
||||
logger.debug("Getting resp body")
|
||||
body = yield readBody(response)
|
||||
logger.debug("Got resp body")
|
||||
|
||||
defer.returnValue((response.code, body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}):
|
||||
if destination in _destination_mappings:
|
||||
destination = _destination_mappings[destination]
|
||||
|
||||
logger.debug("get_json args: %s", args)
|
||||
query_bytes = urllib.urlencode(args, True)
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
query_bytes
|
||||
)
|
||||
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_request(self, destination, method, path_bytes, param_bytes=b"",
|
||||
query_bytes=b"", producer=None, headers_dict={}):
|
||||
""" Creates and sends a request to the given url
|
||||
"""
|
||||
headers_dict[b"User-Agent"] = [b"Synapse"]
|
||||
headers_dict[b"Host"] = [destination]
|
||||
|
||||
logger.debug("Sending request to %s: %s %s;%s?%s",
|
||||
destination, method, path_bytes, param_bytes, query_bytes)
|
||||
|
||||
logger.debug(
|
||||
"Types: %s",
|
||||
[
|
||||
type(destination), type(method), type(path_bytes),
|
||||
type(param_bytes),
|
||||
type(query_bytes)
|
||||
]
|
||||
)
|
||||
|
||||
retries_left = 5
|
||||
|
||||
# TODO: setup and pass in an ssl_context to enable TLS
|
||||
endpoint = matrix_endpoint(reactor, destination, timeout=10)
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = yield self.agent.request(
|
||||
destination,
|
||||
endpoint,
|
||||
method,
|
||||
path_bytes,
|
||||
param_bytes,
|
||||
query_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
|
||||
logger.debug("Got response to %s", method)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Got error in _create_request")
|
||||
_print_ex(e)
|
||||
|
||||
if retries_left:
|
||||
yield sleep(2 ** (5 - retries_left))
|
||||
retries_left -= 1
|
||||
else:
|
||||
raise
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
logger.error(
|
||||
"Got response %d %s", response.code, response.phrase
|
||||
)
|
||||
raise CodeMessageException(
|
||||
response.code, response.phrase
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
def _print_ex(e):
|
||||
if hasattr(e, "reasons") and e.reasons:
|
||||
for ex in e.reasons:
|
||||
_print_ex(ex)
|
||||
else:
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
class _JsonProducer(object):
|
||||
""" Used by the twisted http client to create the HTTP body from json
|
||||
"""
|
||||
def __init__(self, jsn):
|
||||
self.body = encode_canonical_json(jsn)
|
||||
self.length = len(self.body)
|
||||
|
||||
def startProducing(self, consumer):
|
||||
consumer.write(self.body)
|
||||
return defer.succeed(None)
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
171
synapse/http/endpoint.py
Normal file
171
synapse/http/endpoint.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import ConnectError
|
||||
from twisted.names import client, dns
|
||||
from twisted.names.error import DNSNameError
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import random
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def matrix_endpoint(reactor, destination, ssl_context_factory=None,
|
||||
timeout=None):
|
||||
"""Construct an endpoint for the given matrix destination.
|
||||
|
||||
Args:
|
||||
reactor: Twisted reactor.
|
||||
destination (bytes): The name of the server to connect to.
|
||||
ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
|
||||
which generates SSL contexts to use for TLS.
|
||||
timeout (int): connection timeout in seconds
|
||||
"""
|
||||
|
||||
domain_port = destination.split(":")
|
||||
domain = domain_port[0]
|
||||
port = int(domain_port[1]) if domain_port[1:] else None
|
||||
|
||||
endpoint_kw_args = {}
|
||||
|
||||
if timeout is not None:
|
||||
endpoint_kw_args.update(timeout=timeout)
|
||||
|
||||
if ssl_context_factory is None:
|
||||
transport_endpoint = TCP4ClientEndpoint
|
||||
default_port = 8080
|
||||
else:
|
||||
transport_endpoint = SSL4ClientEndpoint
|
||||
endpoint_kw_args.update(ssl_context_factory=ssl_context_factory)
|
||||
default_port = 443
|
||||
|
||||
if port is None:
|
||||
return SRVClientEndpoint(
|
||||
reactor, "matrix", domain, protocol="tcp",
|
||||
default_port=default_port, endpoint=transport_endpoint,
|
||||
endpoint_kw_args=endpoint_kw_args
|
||||
)
|
||||
else:
|
||||
return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
|
||||
|
||||
|
||||
class SRVClientEndpoint(object):
|
||||
"""An endpoint which looks up SRV records for a service.
|
||||
Cycles through the list of servers starting with each call to connect
|
||||
picking the next server.
|
||||
Implements twisted.internet.interfaces.IStreamClientEndpoint.
|
||||
"""
|
||||
|
||||
_Server = collections.namedtuple(
|
||||
"_Server", "priority weight host port"
|
||||
)
|
||||
|
||||
def __init__(self, reactor, service, domain, protocol="tcp",
|
||||
default_port=None, endpoint=TCP4ClientEndpoint,
|
||||
endpoint_kw_args={}):
|
||||
self.reactor = reactor
|
||||
self.service_name = "_%s._%s.%s" % (service, protocol, domain)
|
||||
|
||||
if default_port is not None:
|
||||
self.default_server = self._Server(
|
||||
host=domain,
|
||||
port=default_port,
|
||||
priority=0,
|
||||
weight=0
|
||||
)
|
||||
else:
|
||||
self.default_server = None
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.endpoint_kw_args = endpoint_kw_args
|
||||
|
||||
self.servers = None
|
||||
self.used_servers = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_servers(self):
|
||||
try:
|
||||
answers, auth, add = yield client.lookupService(self.service_name)
|
||||
except DNSNameError:
|
||||
answers = []
|
||||
|
||||
if (len(answers) == 1
|
||||
and answers[0].type == dns.SRV
|
||||
and answers[0].payload
|
||||
and answers[0].payload.target == dns.Name('.')):
|
||||
raise ConnectError("Service %s unavailable", self.service_name)
|
||||
|
||||
self.servers = []
|
||||
self.used_servers = []
|
||||
|
||||
for answer in answers:
|
||||
if answer.type != dns.SRV or not answer.payload:
|
||||
continue
|
||||
payload = answer.payload
|
||||
self.servers.append(self._Server(
|
||||
host=str(payload.target),
|
||||
port=int(payload.port),
|
||||
priority=int(payload.priority),
|
||||
weight=int(payload.weight)
|
||||
))
|
||||
|
||||
self.servers.sort()
|
||||
|
||||
def pick_server(self):
|
||||
if not self.servers:
|
||||
if self.used_servers:
|
||||
self.servers = self.used_servers
|
||||
self.used_servers = []
|
||||
self.servers.sort()
|
||||
elif self.default_server:
|
||||
return self.default_server
|
||||
else:
|
||||
raise ConnectError(
|
||||
"Not server available for %s", self.service_name
|
||||
)
|
||||
|
||||
min_priority = self.servers[0].priority
|
||||
weight_indexes = list(
|
||||
(index, server.weight + 1)
|
||||
for index, server in enumerate(self.servers)
|
||||
if server.priority == min_priority
|
||||
)
|
||||
|
||||
total_weight = sum(weight for index, weight in weight_indexes)
|
||||
target_weight = random.randint(0, total_weight)
|
||||
|
||||
for index, weight in weight_indexes:
|
||||
target_weight -= weight
|
||||
if target_weight <= 0:
|
||||
server = self.servers[index]
|
||||
del self.servers[index]
|
||||
self.used_servers.append(server)
|
||||
return server
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def connect(self, protocolFactory):
|
||||
if self.servers is None:
|
||||
yield self.fetch_servers()
|
||||
server = self.pick_server()
|
||||
logger.info("Connecting to %s:%s", server.host, server.port)
|
||||
endpoint = self.endpoint(
|
||||
self.reactor, server.host, server.port, **self.endpoint_kw_args
|
||||
)
|
||||
connection = yield endpoint.connect(protocolFactory)
|
||||
defer.returnValue(connection)
|
181
synapse/http/server.py
Normal file
181
synapse/http/server.py
Normal file
|
@ -0,0 +1,181 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from syutil.jsonutil import (
|
||||
encode_canonical_json, encode_pretty_printed_json
|
||||
)
|
||||
from synapse.api.errors import cs_exception, CodeMessageException
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web import server, resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HttpServer(object):
|
||||
""" Interface for registering callbacks on a HTTP server
|
||||
"""
|
||||
|
||||
def register_path(self, method, path_pattern, callback):
|
||||
""" Register a callback that get's fired if we receive a http request
|
||||
with the given method for a path that matches the given regex.
|
||||
|
||||
If the regex contains groups these get's passed to the calback via
|
||||
an unpacked tuple.
|
||||
|
||||
Args:
|
||||
method (str): The method to listen to.
|
||||
path_pattern (str): The regex used to match requests.
|
||||
callback (function): The function to fire if we receive a matched
|
||||
request. The first argument will be the request object and
|
||||
subsequent arguments will be any matched groups from the regex.
|
||||
This should return a tuple of (code, response).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# The actual HTTP server impl, using twisted http server
|
||||
class TwistedHttpServer(HttpServer, resource.Resource):
|
||||
""" This wraps the twisted HTTP server, and triggers the correct callbacks
|
||||
on the transport_layer.
|
||||
|
||||
Register callbacks via register_path()
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
|
||||
|
||||
def __init__(self):
|
||||
resource.Resource.__init__(self)
|
||||
|
||||
self.path_regexs = {}
|
||||
|
||||
def register_path(self, method, path_pattern, callback):
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
||||
def start_listening(self, port):
|
||||
""" Registers the http server with the twisted reactor.
|
||||
|
||||
Args:
|
||||
port (int): The port to listen on.
|
||||
|
||||
"""
|
||||
reactor.listenTCP(port, server.Site(self))
|
||||
|
||||
# Gets called by twisted
|
||||
def render(self, request):
|
||||
""" This get's called by twisted every time someone sends us a request.
|
||||
"""
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
""" This get's called by twisted every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
"""
|
||||
try:
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
m = path_entry.pattern.match(request.path)
|
||||
if m:
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
code, response = yield path_entry.callback(
|
||||
request,
|
||||
*m.groups()
|
||||
)
|
||||
|
||||
self._send_response(request, code, response)
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
self._send_response(
|
||||
request,
|
||||
400,
|
||||
{"error": "Unrecognized request"}
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
self._send_response(
|
||||
request,
|
||||
e.code,
|
||||
cs_exception(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._send_response(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"}
|
||||
)
|
||||
|
||||
def _send_response(self, request, code, response_json_object):
|
||||
|
||||
if not self._request_user_agent_is_curl(request):
|
||||
json_bytes = encode_canonical_json(response_json_object)
|
||||
else:
|
||||
json_bytes = encode_pretty_printed_json(response_json_object)
|
||||
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json_bytes(request, code, json_bytes, send_cors=True)
|
||||
|
||||
@staticmethod
|
||||
def _request_user_agent_is_curl(request):
|
||||
user_agents = request.requestHeaders.getRawHeaders(
|
||||
"User-Agent", default=[]
|
||||
)
|
||||
for user_agent in user_agents:
|
||||
if "curl" in user_agent:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def respond_with_json_bytes(request, code, json_bytes, send_cors=False):
|
||||
"""Sends encoded JSON in response to the given request.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request): The http request to respond to.
|
||||
code (int): The HTTP response code.
|
||||
json_bytes (bytes): The json bytes to use as the response body.
|
||||
send_cors (bool): Whether to send Cross-Origin Resource Sharing headers
|
||||
http://www.w3.org/TR/cors/
|
||||
Returns:
|
||||
twisted.web.server.NOT_DONE_YET"""
|
||||
|
||||
request.setResponseCode(code)
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
|
||||
if send_cors:
|
||||
request.setHeader("Access-Control-Allow-Origin", "*")
|
||||
request.setHeader("Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, OPTIONS")
|
||||
request.setHeader("Access-Control-Allow-Headers",
|
||||
"Origin, X-Requested-With, Content-Type, Accept")
|
||||
|
||||
request.write(json_bytes)
|
||||
request.finish()
|
||||
return NOT_DONE_YET
|
44
synapse/rest/__init__.py
Normal file
44
synapse/rest/__init__.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import (
|
||||
room, events, register, profile, public, presence, im, directory
|
||||
)
|
||||
|
||||
class RestServletFactory(object):
|
||||
|
||||
""" A factory for creating REST servlets.
|
||||
|
||||
These REST servlets represent the entire client-server REST API. Generally
|
||||
speaking, they serve as wrappers around events and the handlers that
|
||||
process them.
|
||||
|
||||
See synapse.api.events for information on synapse events.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
http_server = hs.get_http_server()
|
||||
|
||||
# TODO(erikj): There *must* be a better way of doing this.
|
||||
room.register_servlets(hs, http_server)
|
||||
events.register_servlets(hs, http_server)
|
||||
register.register_servlets(hs, http_server)
|
||||
profile.register_servlets(hs, http_server)
|
||||
public.register_servlets(hs, http_server)
|
||||
presence.register_servlets(hs, http_server)
|
||||
im.register_servlets(hs, http_server)
|
||||
directory.register_servlets(hs, http_server)
|
||||
|
||||
|
113
synapse/rest/base.py
Normal file
113
synapse/rest/base.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This module contains base REST classes for constructing REST servlets. """
|
||||
import re
|
||||
|
||||
|
||||
def client_path_pattern(path_regex):
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
Args:
|
||||
path_regex (str): The regex string to match. This should NOT have a ^
|
||||
as this will be prefixed.
|
||||
Returns:
|
||||
SRE_Pattern
|
||||
"""
|
||||
return re.compile("^/matrix/client/api/v1" + path_regex)
|
||||
|
||||
|
||||
class RestServletFactory(object):
|
||||
|
||||
""" A factory for creating REST servlets.
|
||||
|
||||
These REST servlets represent the entire client-server REST API. Generally
|
||||
speaking, they serve as wrappers around events and the handlers that
|
||||
process them.
|
||||
|
||||
See synapse.api.events for information on synapse events.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
http_server = hs.get_http_server()
|
||||
|
||||
# You get import errors if you try to import before the classes in this
|
||||
# file are defined, hence importing here instead.
|
||||
|
||||
import room
|
||||
room.register_servlets(hs, http_server)
|
||||
|
||||
import events
|
||||
events.register_servlets(hs, http_server)
|
||||
|
||||
import register
|
||||
register.register_servlets(hs, http_server)
|
||||
|
||||
import profile
|
||||
profile.register_servlets(hs, http_server)
|
||||
|
||||
import public
|
||||
public.register_servlets(hs, http_server)
|
||||
|
||||
import presence
|
||||
presence.register_servlets(hs, http_server)
|
||||
|
||||
import im
|
||||
im.register_servlets(hs, http_server)
|
||||
|
||||
import login
|
||||
login.register_servlets(hs, http_server)
|
||||
|
||||
|
||||
class RestServlet(object):
|
||||
|
||||
""" A Synapse REST Servlet.
|
||||
|
||||
An implementing class can either provide its own custom 'register' method,
|
||||
or use the automatic pattern handling provided by the base class.
|
||||
|
||||
To use this latter, the implementing class instead provides a `PATTERN`
|
||||
class attribute containing a pre-compiled regular expression. The automatic
|
||||
register method will then use this method to register any of the following
|
||||
instance methods associated with the corresponding HTTP method:
|
||||
|
||||
on_GET
|
||||
on_PUT
|
||||
on_POST
|
||||
on_DELETE
|
||||
on_OPTIONS
|
||||
|
||||
Automatically handles turning CodeMessageExceptions thrown by these methods
|
||||
into the appropriate HTTP response.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
def register(self, http_server):
|
||||
""" Register this servlet with the given HTTP server. """
|
||||
if hasattr(self, "PATTERN"):
|
||||
pattern = self.PATTERN
|
||||
|
||||
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
||||
if hasattr(self, "on_%s" % (method)):
|
||||
method_handler = getattr(self, "on_%s" % (method))
|
||||
http_server.register_path(method, pattern, method_handler)
|
||||
else:
|
||||
raise NotImplementedError("RestServlet must register something.")
|
82
synapse/rest/directory.py
Normal file
82
synapse/rest/directory.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import RoomAlias, RoomID
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ClientDirectoryServer(hs).register(http_server)
|
||||
|
||||
|
||||
class ClientDirectoryServer(RestServlet):
|
||||
PATTERN = client_path_pattern("/ds/room/(?P<room_alias>[^/]*)$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_alias):
|
||||
# TODO(erikj): Handle request
|
||||
local_only = "local_only" in request.args
|
||||
|
||||
room_alias = urllib.unquote(room_alias)
|
||||
room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
|
||||
|
||||
dir_handler = self.handlers.directory_handler
|
||||
res = yield dir_handler.get_association(
|
||||
room_alias_obj,
|
||||
local_only=local_only
|
||||
)
|
||||
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_alias):
|
||||
# TODO(erikj): Exceptions
|
||||
content = json.loads(request.content.read())
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
room_alias = urllib.unquote(room_alias)
|
||||
room_alias_obj = RoomAlias.from_string(room_alias, self.hs)
|
||||
|
||||
logger.debug("Got room name: %s", room_alias_obj.to_string())
|
||||
|
||||
room_id = content["room_id"]
|
||||
servers = content["servers"]
|
||||
|
||||
logger.debug("Got room_id: %s", room_id)
|
||||
logger.debug("Got servers: %s", servers)
|
||||
|
||||
# TODO(erikj): Check types.
|
||||
# TODO(erikj): Check that room exists
|
||||
|
||||
dir_handler = self.handlers.directory_handler
|
||||
|
||||
try:
|
||||
yield dir_handler.create_association(
|
||||
room_alias_obj, room_id, servers
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to create association")
|
||||
|
||||
defer.returnValue((200, {}))
|
50
synapse/rest/events.py
Normal file
50
synapse/rest/events.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module contains REST servlets to do with event streaming, /events."""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.streams import PaginationConfig
|
||||
from synapse.rest.base import RestServlet, client_path_pattern
|
||||
|
||||
|
||||
class EventStreamRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/events$")
|
||||
|
||||
DEFAULT_LONGPOLL_TIME_MS = 5000
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
|
||||
timeout=timeout)
|
||||
defer.returnValue((200, chunk))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
EventStreamRestServlet(hs).register(http_server)
|
39
synapse/rest/im.py
Normal file
39
synapse/rest/im.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.streams import PaginationConfig
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
|
||||
class ImSyncRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/im/sync$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
with_feedback = "feedback" in request.args
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
handler = self.handlers.message_handler
|
||||
content = yield handler.snapshot_all_rooms(
|
||||
user_id=user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
feedback=with_feedback)
|
||||
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ImSyncRestServlet(hs).register(http_server)
|
80
synapse/rest/login.py
Normal file
80
synapse/rest/login.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/login$")
|
||||
PASS_TYPE = "m.login.password"
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {"type": LoginRestServlet.PASS_TYPE})
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
login_submission = _parse_json(request)
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||
result = yield self.do_password_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
else:
|
||||
raise SynapseError(400, "Bad login type.")
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Missing JSON keys.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_password_login(self, login_submission):
|
||||
handler = self.handlers.login_handler
|
||||
token = yield handler.login(
|
||||
user=login_submission["user"],
|
||||
password=login_submission["password"])
|
||||
|
||||
result = {
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class LoginFallbackRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/login/fallback$")
|
||||
|
||||
def on_GET(self, request):
|
||||
# TODO(kegan): This should be returning some HTML which is capable of
|
||||
# hitting LoginRestServlet
|
||||
return (200, "")
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise SynapseError(400, "Content must be a JSON object.")
|
||||
return content
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Content not JSON.")
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
LoginRestServlet(hs).register(http_server)
|
134
synapse/rest/presence.py
Normal file
134
synapse/rest/presence.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This module contains REST servlets to do with presence: /presence/<paths>
|
||||
"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PresenceStatusRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
state = yield self.handlers.presence_handler.get_state(
|
||||
target_user=user, auth_user=auth_user)
|
||||
|
||||
defer.returnValue((200, state))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
state = {}
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
|
||||
state["state"] = content.pop("state")
|
||||
|
||||
if "status_msg" in content:
|
||||
state["status_msg"] = content.pop("status_msg")
|
||||
|
||||
if content:
|
||||
raise KeyError()
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse state"))
|
||||
|
||||
yield self.handlers.presence_handler.set_state(
|
||||
target_user=user, auth_user=auth_user, state=state)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
class PresenceListRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/presence_list/(?P<user_id>[^/]*)")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
if not user.is_mine:
|
||||
defer.returnValue((400, "User not hosted on this Home Server"))
|
||||
|
||||
if auth_user != user:
|
||||
defer.returnValue((400, "Cannot get another user's presence list"))
|
||||
|
||||
presence = yield self.handlers.presence_handler.get_presence_list(
|
||||
observer_user=user, accepted=True)
|
||||
|
||||
for p in presence:
|
||||
observed_user = p.pop("observed_user")
|
||||
p["user_id"] = observed_user.to_string()
|
||||
|
||||
defer.returnValue((200, presence))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
if not user.is_mine:
|
||||
defer.returnValue((400, "User not hosted on this Home Server"))
|
||||
|
||||
if auth_user != user:
|
||||
defer.returnValue((
|
||||
400, "Cannot modify another user's presence list"))
|
||||
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
except:
|
||||
logger.exception("JSON parse error")
|
||||
defer.returnValue((400, "Unable to parse content"))
|
||||
|
||||
deferreds = []
|
||||
|
||||
if "invite" in content:
|
||||
for u in content["invite"]:
|
||||
invited_user = self.hs.parse_userid(u)
|
||||
deferreds.append(self.handlers.presence_handler.send_invite(
|
||||
observer_user=user, observed_user=invited_user))
|
||||
|
||||
if "drop" in content:
|
||||
for u in content["drop"]:
|
||||
dropped_user = self.hs.parse_userid(u)
|
||||
deferreds.append(self.handlers.presence_handler.drop(
|
||||
observer_user=user, observed_user=dropped_user))
|
||||
|
||||
yield defer.DeferredList(deferreds)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PresenceStatusRestServlet(hs).register(http_server)
|
||||
PresenceListRestServlet(hs).register(http_server)
|
93
synapse/rest/profile.py
Normal file
93
synapse/rest/profile.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
||||
from twisted.internet import defer
|
||||
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class ProfileDisplaynameRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
user,
|
||||
local_only="local_only" in request.args
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"displayname": displayname}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
new_name = content["displayname"]
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_displayname(
|
||||
user, auth_user, new_name)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
def on_OPTIONS(self, request, user_id):
|
||||
return (200, {})
|
||||
|
||||
|
||||
class ProfileAvatarURLRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
user,
|
||||
local_only="local_only" in request.args
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"avatar_url": avatar_url}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
new_name = content["avatar_url"]
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_avatar_url(
|
||||
user, auth_user, new_name)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
def on_OPTIONS(self, request, user_id):
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ProfileDisplaynameRestServlet(hs).register(http_server)
|
||||
ProfileAvatarURLRestServlet(hs).register(http_server)
|
32
synapse/rest/public.py
Normal file
32
synapse/rest/public.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module contains REST servlets to do with public paths: /public"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
|
||||
class PublicRoomListRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/public/rooms$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
handler = self.handlers.room_list_handler
|
||||
data = yield handler.get_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PublicRoomListRestServlet(hs).register(http_server)
|
68
synapse/rest/register.py
Normal file
68
synapse/rest/register.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module contains REST servlets to do with registration: /register"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import urllib
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/register$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
desired_user_id = None
|
||||
password = None
|
||||
try:
|
||||
register_json = json.loads(request.content.read())
|
||||
if "password" in register_json:
|
||||
password = register_json["password"]
|
||||
|
||||
if type(register_json["user_id"]) == unicode:
|
||||
desired_user_id = register_json["user_id"]
|
||||
if urllib.quote(desired_user_id) != desired_user_id:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not " +
|
||||
"require URL encoding.")
|
||||
except ValueError:
|
||||
defer.returnValue((400, "No JSON object."))
|
||||
except KeyError:
|
||||
pass # user_id is optional
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.register(
|
||||
localpart=desired_user_id,
|
||||
password=password)
|
||||
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
defer.returnValue(
|
||||
(200, result)
|
||||
)
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
394
synapse/rest/room.py
Normal file
394
synapse/rest/room.py
Normal file
|
@ -0,0 +1,394 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
|
||||
from twisted.internet import defer
|
||||
|
||||
from base import RestServlet, client_path_pattern
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.api.events.room import (RoomTopicEvent, MessageEvent,
|
||||
RoomMemberEvent, FeedbackEvent)
|
||||
from synapse.api.constants import Feedback, Membership
|
||||
from synapse.api.streams import PaginationConfig
|
||||
from synapse.types import RoomAlias
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RoomCreateRestServlet(RestServlet):
|
||||
# No PATTERN; we have custom dispatch rules here
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms OR /rooms/<roomid>
|
||||
http_server.register_path("POST",
|
||||
client_path_pattern("/rooms$"),
|
||||
self.on_POST)
|
||||
http_server.register_path("PUT",
|
||||
client_path_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)$"),
|
||||
self.on_PUT)
|
||||
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
|
||||
http_server.register_path("OPTIONS",
|
||||
client_path_pattern("/rooms(?:/.*)?$"),
|
||||
self.on_OPTIONS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id):
|
||||
room_id = urllib.unquote(room_id)
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if not room_id:
|
||||
raise SynapseError(400, "PUT must specify a room ID")
|
||||
|
||||
room_config = self.get_room_config(request)
|
||||
info = yield self.make_room(room_config, auth_user, room_id)
|
||||
room_config.update(info)
|
||||
defer.returnValue((200, info))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_config = self.get_room_config(request)
|
||||
info = yield self.make_room(room_config, auth_user, None)
|
||||
room_config.update(info)
|
||||
defer.returnValue((200, info))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def make_room(self, room_config, auth_user, room_id):
|
||||
handler = self.handlers.room_creation_handler
|
||||
info = yield handler.create_room(
|
||||
user_id=auth_user.to_string(),
|
||||
room_id=room_id,
|
||||
config=room_config
|
||||
)
|
||||
defer.returnValue(info)
|
||||
|
||||
def get_room_config(self, request):
|
||||
try:
|
||||
user_supplied_config = json.loads(request.content.read())
|
||||
if "visibility" not in user_supplied_config:
|
||||
# default visibility
|
||||
user_supplied_config["visibility"] = "public"
|
||||
return user_supplied_config
|
||||
except (ValueError, TypeError):
|
||||
raise SynapseError(400, "Body must be JSON.",
|
||||
errcode=Codes.BAD_JSON)
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
class RoomTopicRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/topic$")
|
||||
|
||||
def get_event_type(self):
|
||||
return RoomTopicEvent.TYPE
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
user_id=user.to_string(),
|
||||
room_id=urllib.unquote(room_id),
|
||||
event_type=RoomTopicEvent.TYPE,
|
||||
state_key="",
|
||||
)
|
||||
|
||||
if not data:
|
||||
raise SynapseError(404, "Topic not found.", errcode=Codes.NOT_FOUND)
|
||||
defer.returnValue((200, json.loads(data.content)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=self.get_event_type(),
|
||||
content=content,
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.store_room_data(
|
||||
event=event
|
||||
)
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
|
||||
class JoinRoomAliasServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/join/(?P<room_alias>[^/]+)$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_alias):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if not user:
|
||||
defer.returnValue((403, "Unrecognized user"))
|
||||
|
||||
logger.debug("room_alias: %s", room_alias)
|
||||
|
||||
room_alias = RoomAlias.from_string(
|
||||
urllib.unquote(room_alias),
|
||||
self.hs
|
||||
)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
ret_dict = yield handler.join_room_alias(user, room_alias)
|
||||
|
||||
defer.returnValue((200, ret_dict))
|
||||
|
||||
|
||||
class RoomMemberRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/"
|
||||
+ "(?P<target_user_id>[^/]*)/state$")
|
||||
|
||||
def get_event_type(self):
|
||||
return RoomMemberEvent.TYPE
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, target_user_id):
|
||||
room_id = urllib.unquote(room_id)
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
member = yield handler.get_room_member(room_id, target_user_id,
|
||||
user.to_string())
|
||||
if not member:
|
||||
raise SynapseError(404, "Member not found.",
|
||||
errcode=Codes.NOT_FOUND)
|
||||
defer.returnValue((200, json.loads(member.content)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, roomid, target_user_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=self.get_event_type(),
|
||||
target_user_id=target_user_id,
|
||||
room_id=urllib.unquote(roomid),
|
||||
user_id=user.to_string(),
|
||||
membership=Membership.LEAVE,
|
||||
content={"membership": Membership.LEAVE}
|
||||
)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
yield handler.change_membership(event, broadcast_msg=True)
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, roomid, target_user_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
if "membership" not in content:
|
||||
raise SynapseError(400, "No membership key.",
|
||||
errcode=Codes.BAD_JSON)
|
||||
|
||||
valid_membership_values = [Membership.JOIN, Membership.INVITE]
|
||||
if (content["membership"] not in valid_membership_values):
|
||||
raise SynapseError(400, "Membership value must be %s." % (
|
||||
valid_membership_values,), errcode=Codes.BAD_JSON)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=self.get_event_type(),
|
||||
target_user_id=target_user_id,
|
||||
room_id=urllib.unquote(roomid),
|
||||
user_id=user.to_string(),
|
||||
membership=content["membership"],
|
||||
content=content
|
||||
)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
result = yield handler.change_membership(event, broadcast_msg=True)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class MessageRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/"
|
||||
+ "(?P<sender_id>[^/]*)/(?P<msg_id>[^/]*)$")
|
||||
|
||||
def get_event_type(self):
|
||||
return MessageEvent.TYPE
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, sender_id, msg_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
msg = yield msg_handler.get_message(room_id=urllib.unquote(room_id),
|
||||
sender_id=sender_id,
|
||||
msg_id=msg_id,
|
||||
user_id=user.to_string(),
|
||||
)
|
||||
|
||||
if not msg:
|
||||
raise SynapseError(404, "Message not found.",
|
||||
errcode=Codes.NOT_FOUND)
|
||||
|
||||
defer.returnValue((200, json.loads(msg.content)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, sender_id, msg_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if user.to_string() != sender_id:
|
||||
raise SynapseError(403, "Must send messages as yourself.",
|
||||
errcode=Codes.FORBIDDEN)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=self.get_event_type(),
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
msg_id=msg_id,
|
||||
content=content
|
||||
)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.send_message(event)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
|
||||
class FeedbackRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)/messages/" +
|
||||
"(?P<msg_sender_id>[^/]*)/(?P<msg_id>[^/]*)/feedback/" +
|
||||
"(?P<sender_id>[^/]*)/(?P<feedback_type>[^/]*)$"
|
||||
)
|
||||
|
||||
def get_event_type(self):
|
||||
return FeedbackEvent.TYPE
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, msg_sender_id, msg_id, fb_sender_id,
|
||||
feedback_type):
|
||||
user = yield (self.auth.get_user_by_req(request))
|
||||
|
||||
if feedback_type not in Feedback.LIST:
|
||||
raise SynapseError(400, "Bad feedback type.",
|
||||
errcode=Codes.BAD_JSON)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
feedback = yield msg_handler.get_feedback(
|
||||
room_id=urllib.unquote(room_id),
|
||||
msg_sender_id=msg_sender_id,
|
||||
msg_id=msg_id,
|
||||
user_id=user.to_string(),
|
||||
fb_sender_id=fb_sender_id,
|
||||
fb_type=feedback_type
|
||||
)
|
||||
|
||||
if not feedback:
|
||||
raise SynapseError(404, "Feedback not found.",
|
||||
errcode=Codes.NOT_FOUND)
|
||||
|
||||
defer.returnValue((200, json.loads(feedback.content)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, sender_id, msg_id, fb_sender_id,
|
||||
feedback_type):
|
||||
user = yield (self.auth.get_user_by_req(request))
|
||||
|
||||
if user.to_string() != fb_sender_id:
|
||||
raise SynapseError(403, "Must send feedback as yourself.",
|
||||
errcode=Codes.FORBIDDEN)
|
||||
|
||||
if feedback_type not in Feedback.LIST:
|
||||
raise SynapseError(400, "Bad feedback type.",
|
||||
errcode=Codes.BAD_JSON)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=self.get_event_type(),
|
||||
room_id=urllib.unquote(room_id),
|
||||
msg_sender_id=sender_id,
|
||||
msg_id=msg_id,
|
||||
user_id=user.to_string(), # user sending the feedback
|
||||
feedback_type=feedback_type,
|
||||
content=content
|
||||
)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.send_feedback(event)
|
||||
|
||||
defer.returnValue((200, ""))
|
||||
|
||||
|
||||
class RoomMemberListRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/list$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.room_member_handler
|
||||
members = yield handler.get_room_members_as_pagination_chunk(
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string())
|
||||
|
||||
defer.returnValue((200, members))
|
||||
|
||||
|
||||
class RoomMessageListRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/list$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
with_feedback = "feedback" in request.args
|
||||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_messages(
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
feedback=with_feedback)
|
||||
|
||||
defer.returnValue((200, msgs))
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise SynapseError(400, "Content must be a JSON object.",
|
||||
errcode=Codes.NOT_JSON)
|
||||
return content
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RoomTopicRestServlet(hs).register(http_server)
|
||||
RoomMemberRestServlet(hs).register(http_server)
|
||||
MessageRestServlet(hs).register(http_server)
|
||||
FeedbackRestServlet(hs).register(http_server)
|
||||
RoomCreateRestServlet(hs).register(http_server)
|
||||
RoomMemberListRestServlet(hs).register(http_server)
|
||||
RoomMessageListRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
176
synapse/server.py
Normal file
176
synapse/server.py
Normal file
|
@ -0,0 +1,176 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This file provides some classes for setting up (partially-populated)
|
||||
# homeservers; either as a full homeserver as a real application, or a small
|
||||
# partial one for unit test mocking.
|
||||
|
||||
# Imports required for the default HomeServer() implementation
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.federation.handler import FederationEventHandler
|
||||
from synapse.api.events.factory import EventFactory
|
||||
from synapse.api.notifier import Notifier
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.rest import RestServletFactory
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.lockutils import LockManager
|
||||
|
||||
|
||||
class BaseHomeServer(object):
|
||||
"""A basic homeserver object without lazy component builders.
|
||||
|
||||
This will need all of the components it requires to either be passed as
|
||||
constructor arguments, or the relevant methods overriding to create them.
|
||||
Typically this would only be used for unit tests.
|
||||
|
||||
For every dependency in the DEPENDENCIES list below, this class creates one
|
||||
method,
|
||||
def get_DEPENDENCY(self)
|
||||
which returns the value of that dependency. If no value has yet been set
|
||||
nor was provided to the constructor, it will attempt to call a lazy builder
|
||||
method called
|
||||
def build_DEPENDENCY(self)
|
||||
which must be implemented by the subclass. This code may call any of the
|
||||
required "get" methods on the instance to obtain the sub-dependencies that
|
||||
one requires.
|
||||
"""
|
||||
|
||||
DEPENDENCIES = [
|
||||
'clock',
|
||||
'http_server',
|
||||
'http_client',
|
||||
'db_pool',
|
||||
'persistence_service',
|
||||
'federation',
|
||||
'replication_layer',
|
||||
'datastore',
|
||||
'event_factory',
|
||||
'handlers',
|
||||
'auth',
|
||||
'rest_servlet_factory',
|
||||
'state_handler',
|
||||
'room_lock_manager',
|
||||
'notifier',
|
||||
'distributor',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
hostname : The hostname for the server.
|
||||
"""
|
||||
self.hostname = hostname
|
||||
self._building = {}
|
||||
|
||||
# Other kwargs are explicit dependencies
|
||||
for depname in kwargs:
|
||||
setattr(self, depname, kwargs[depname])
|
||||
|
||||
@classmethod
|
||||
def _make_dependency_method(cls, depname):
|
||||
def _get(self):
|
||||
if hasattr(self, depname):
|
||||
return getattr(self, depname)
|
||||
|
||||
if hasattr(self, "build_%s" % (depname)):
|
||||
# Prevent cyclic dependencies from deadlocking
|
||||
if depname in self._building:
|
||||
raise ValueError("Cyclic dependency while building %s" % (
|
||||
depname,
|
||||
))
|
||||
self._building[depname] = 1
|
||||
|
||||
builder = getattr(self, "build_%s" % (depname))
|
||||
dep = builder()
|
||||
setattr(self, depname, dep)
|
||||
|
||||
del self._building[depname]
|
||||
|
||||
return dep
|
||||
|
||||
raise NotImplementedError(
|
||||
"%s has no %s nor a builder for it" % (
|
||||
type(self).__name__, depname,
|
||||
)
|
||||
)
|
||||
|
||||
setattr(BaseHomeServer, "get_%s" % (depname), _get)
|
||||
|
||||
# Other utility methods
|
||||
def parse_userid(self, s):
|
||||
"""Parse the string given by 's' as a User ID and return a UserID
|
||||
object."""
|
||||
return UserID.from_string(s, hs=self)
|
||||
|
||||
# Build magic accessors for every dependency
|
||||
for depname in BaseHomeServer.DEPENDENCIES:
|
||||
BaseHomeServer._make_dependency_method(depname)
|
||||
|
||||
|
||||
class HomeServer(BaseHomeServer):
|
||||
"""A homeserver object that will construct most of its dependencies as
|
||||
required.
|
||||
|
||||
It still requires the following to be specified by the caller:
|
||||
http_server
|
||||
http_client
|
||||
db_pool
|
||||
"""
|
||||
|
||||
def build_clock(self):
|
||||
return Clock()
|
||||
|
||||
def build_replication_layer(self):
|
||||
return initialize_http_replication(self)
|
||||
|
||||
def build_federation(self):
|
||||
return FederationEventHandler(self)
|
||||
|
||||
def build_datastore(self):
|
||||
return DataStore(self)
|
||||
|
||||
def build_event_factory(self):
|
||||
return EventFactory()
|
||||
|
||||
def build_handlers(self):
|
||||
return Handlers(self)
|
||||
|
||||
def build_notifier(self):
|
||||
return Notifier(self)
|
||||
|
||||
def build_auth(self):
|
||||
return Auth(self)
|
||||
|
||||
def build_rest_servlet_factory(self):
|
||||
return RestServletFactory(self)
|
||||
|
||||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
def build_room_lock_manager(self):
|
||||
return LockManager()
|
||||
|
||||
def build_distributor(self):
|
||||
return Distributor()
|
||||
|
||||
def register_servlets(self):
|
||||
"""Simply building the ServletFactory is sufficient to have it
|
||||
register."""
|
||||
self.get_rest_servlet_factory()
|
223
synapse/state.py
Normal file
223
synapse/state.py
Normal file
|
@ -0,0 +1,223 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.federation.pdu_codec import encode_event_id
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_state_key_from_event(event):
|
||||
return event.state_key
|
||||
|
||||
|
||||
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
|
||||
|
||||
|
||||
class StateHandler(object):
|
||||
""" Repsonsible for doing state conflict resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self._replication = hs.get_replication_layer()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def handle_new_event(self, event):
|
||||
""" Given an event this works out if a) we have sufficient power level
|
||||
to update the state and b) works out what the prev_state should be.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolved with a boolean indicating if we succesfully
|
||||
updated the state.
|
||||
|
||||
Raised:
|
||||
AuthError
|
||||
"""
|
||||
# This needs to be done in a transaction.
|
||||
|
||||
if not hasattr(event, "state_key"):
|
||||
return
|
||||
|
||||
key = KeyStateTuple(
|
||||
event.room_id,
|
||||
event.type,
|
||||
_get_state_key_from_event(event)
|
||||
)
|
||||
|
||||
# Now I need to fill out the prev state and work out if it has auth
|
||||
# (w.r.t. to power levels)
|
||||
|
||||
results = yield self.store.get_latest_pdus_in_context(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
event.prev_events = [
|
||||
encode_event_id(p_id, origin) for p_id, origin, _ in results
|
||||
]
|
||||
event.prev_events = [
|
||||
e for e in event.prev_events if e != event.event_id
|
||||
]
|
||||
|
||||
if results:
|
||||
event.depth = max([int(v) for _, _, v in results]) + 1
|
||||
else:
|
||||
event.depth = 0
|
||||
|
||||
current_state = yield self.store.get_current_state(
|
||||
key.context, key.type, key.state_key
|
||||
)
|
||||
|
||||
if current_state:
|
||||
event.prev_state = encode_event_id(
|
||||
current_state.pdu_id, current_state.origin
|
||||
)
|
||||
|
||||
# TODO check current_state to see if the min power level is less
|
||||
# than the power level of the user
|
||||
# power_level = self._get_power_level_for_event(event)
|
||||
|
||||
yield self.store.update_current_state(
|
||||
pdu_id=event.event_id,
|
||||
origin=self.server_name,
|
||||
context=key.context,
|
||||
pdu_type=key.type,
|
||||
state_key=key.state_key
|
||||
)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def handle_new_state(self, new_pdu):
|
||||
""" Apply conflict resolution to `new_pdu`.
|
||||
|
||||
This should be called on every new state pdu, regardless of whether or
|
||||
not there is a conflict.
|
||||
|
||||
This function is safe against the race of it getting called with two
|
||||
`PDU`s trying to update the same state.
|
||||
"""
|
||||
|
||||
# This needs to be done in a transaction.
|
||||
|
||||
is_new = yield self._handle_new_state(new_pdu)
|
||||
|
||||
if is_new:
|
||||
yield self.store.update_current_state(
|
||||
pdu_id=new_pdu.pdu_id,
|
||||
origin=new_pdu.origin,
|
||||
context=new_pdu.context,
|
||||
pdu_type=new_pdu.pdu_type,
|
||||
state_key=new_pdu.state_key
|
||||
)
|
||||
|
||||
defer.returnValue(is_new)
|
||||
|
||||
def _get_power_level_for_event(self, event):
|
||||
# return self._persistence.get_power_level_for_user(event.room_id,
|
||||
# event.sender)
|
||||
return event.power_level
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _handle_new_state(self, new_pdu):
|
||||
tree = yield self.store.get_unresolved_state_tree(new_pdu)
|
||||
new_branch, current_branch = tree
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_state new=%s, current=%s",
|
||||
new_branch, current_branch
|
||||
)
|
||||
|
||||
if not current_branch:
|
||||
# There is no current state
|
||||
defer.returnValue(True)
|
||||
return
|
||||
|
||||
if new_branch[-1] == current_branch[-1]:
|
||||
# We have all the PDUs we need, so we can just do the conflict
|
||||
# resolution.
|
||||
|
||||
if len(current_branch) == 1:
|
||||
# This is a direct clobber so we can just...
|
||||
defer.returnValue(True)
|
||||
|
||||
conflict_res = [
|
||||
self._do_power_level_conflict_res,
|
||||
self._do_chain_length_conflict_res,
|
||||
self._do_hash_conflict_res,
|
||||
]
|
||||
|
||||
for algo in conflict_res:
|
||||
new_res, curr_res = algo(new_branch, current_branch)
|
||||
|
||||
if new_res < curr_res:
|
||||
defer.returnValue(False)
|
||||
elif new_res > curr_res:
|
||||
defer.returnValue(True)
|
||||
|
||||
raise Exception("Conflict resolution failed.")
|
||||
|
||||
else:
|
||||
# We need to ask for PDUs.
|
||||
missing_prev = max(
|
||||
new_branch[-1], current_branch[-1],
|
||||
key=lambda x: x.depth
|
||||
)
|
||||
|
||||
yield self._replication.get_pdu(
|
||||
destination=missing_prev.origin,
|
||||
pdu_origin=missing_prev.prev_state_origin,
|
||||
pdu_id=missing_prev.prev_state_id,
|
||||
outlier=True
|
||||
)
|
||||
|
||||
updated_current = yield self._handle_new_state(new_pdu)
|
||||
defer.returnValue(updated_current)
|
||||
|
||||
def _do_power_level_conflict_res(self, new_branch, current_branch):
|
||||
max_power_new = max(
|
||||
new_branch[:-1],
|
||||
key=lambda t: t.power_level
|
||||
).power_level
|
||||
|
||||
max_power_current = max(
|
||||
current_branch[:-1],
|
||||
key=lambda t: t.power_level
|
||||
).power_level
|
||||
|
||||
return (max_power_new, max_power_current)
|
||||
|
||||
def _do_chain_length_conflict_res(self, new_branch, current_branch):
|
||||
return (len(new_branch), len(current_branch))
|
||||
|
||||
def _do_hash_conflict_res(self, new_branch, current_branch):
|
||||
new_str = "".join([p.pdu_id + p.origin for p in new_branch])
|
||||
c_str = "".join([p.pdu_id + p.origin for p in current_branch])
|
||||
|
||||
return (
|
||||
hashlib.sha1(new_str).hexdigest(),
|
||||
hashlib.sha1(c_str).hexdigest()
|
||||
)
|
117
synapse/storage/__init__.py
Normal file
117
synapse/storage/__init__.py
Normal file
|
@ -0,0 +1,117 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, MessageEvent, RoomTopicEvent, FeedbackEvent,
|
||||
RoomConfigEvent
|
||||
)
|
||||
|
||||
from .directory import DirectoryStore
|
||||
from .feedback import FeedbackStore
|
||||
from .message import MessageStore
|
||||
from .presence import PresenceStore
|
||||
from .profile import ProfileStore
|
||||
from .registration import RegistrationStore
|
||||
from .room import RoomStore
|
||||
from .roommember import RoomMemberStore
|
||||
from .roomdata import RoomDataStore
|
||||
from .stream import StreamStore
|
||||
from .pdu import StatePduStore, PduStore
|
||||
from .transactions import TransactionStore
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
||||
PresenceStore, PduStore, StatePduStore, TransactionStore,
|
||||
DirectoryStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DataStore, self).__init__(hs)
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.hs = hs
|
||||
|
||||
def persist_event(self, event):
|
||||
if event.type == MessageEvent.TYPE:
|
||||
return self.store_message(
|
||||
user_id=event.user_id,
|
||||
room_id=event.room_id,
|
||||
msg_id=event.msg_id,
|
||||
content=json.dumps(event.content)
|
||||
)
|
||||
elif event.type == RoomMemberEvent.TYPE:
|
||||
return self.store_room_member(
|
||||
user_id=event.target_user_id,
|
||||
sender=event.user_id,
|
||||
room_id=event.room_id,
|
||||
content=event.content,
|
||||
membership=event.content["membership"]
|
||||
)
|
||||
elif event.type == FeedbackEvent.TYPE:
|
||||
return self.store_feedback(
|
||||
room_id=event.room_id,
|
||||
msg_id=event.msg_id,
|
||||
msg_sender_id=event.msg_sender_id,
|
||||
fb_sender_id=event.user_id,
|
||||
fb_type=event.feedback_type,
|
||||
content=json.dumps(event.content)
|
||||
)
|
||||
elif event.type == RoomTopicEvent.TYPE:
|
||||
return self.store_room_data(
|
||||
room_id=event.room_id,
|
||||
etype=event.type,
|
||||
state_key=event.state_key,
|
||||
content=json.dumps(event.content)
|
||||
)
|
||||
elif event.type == RoomConfigEvent.TYPE:
|
||||
if "visibility" in event.content:
|
||||
visibility = event.content["visibility"]
|
||||
return self.store_room_config(
|
||||
room_id=event.room_id,
|
||||
visibility=visibility
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Don't know how to persist type=%s" % event.type
|
||||
)
|
||||
|
||||
|
||||
def schema_path(schema):
|
||||
""" Get a filesystem path for the named database schema
|
||||
|
||||
Args:
|
||||
schema: Name of the database schema.
|
||||
Returns:
|
||||
A filesystem path pointing at a ".sql" file.
|
||||
|
||||
"""
|
||||
dir_path = os.path.dirname(__file__)
|
||||
schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
|
||||
return schemaPath
|
||||
|
||||
|
||||
def read_schema(schema):
|
||||
""" Read the named database schema.
|
||||
|
||||
Args:
|
||||
schema: Name of the datbase schema.
|
||||
Returns:
|
||||
A string containing the database schema.
|
||||
"""
|
||||
with open(schema_path(schema)) as schema_file:
|
||||
return schema_file.read()
|
405
synapse/storage/_base.py
Normal file
405
synapse/storage/_base.py
Normal file
|
@ -0,0 +1,405 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
import collections
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLBaseStore(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self._db_pool = hs.get_db_pool()
|
||||
|
||||
def cursor_to_dict(self, cursor):
|
||||
"""Converts a SQL cursor into an list of dicts.
|
||||
|
||||
Args:
|
||||
cursor : The DBAPI cursor which has executed a query.
|
||||
Returns:
|
||||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
col_headers = list(column[0] for column in cursor.description)
|
||||
results = list(
|
||||
dict(zip(col_headers, row)) for row in cursor.fetchall()
|
||||
)
|
||||
return results
|
||||
|
||||
def _execute(self, decoder, query, *args):
|
||||
"""Runs a single query for a result set.
|
||||
|
||||
Args:
|
||||
decoder - The function which can resolve the cursor results to
|
||||
something meaningful.
|
||||
query - The query string to execute
|
||||
*args - Query args.
|
||||
Returns:
|
||||
The result of decoder(results)
|
||||
"""
|
||||
logger.debug(
|
||||
"[SQL] %s Args=%s Func=%s", query, args, decoder.__name__
|
||||
)
|
||||
|
||||
def interaction(txn):
|
||||
cursor = txn.execute(query, args)
|
||||
return decoder(cursor)
|
||||
return self._db_pool.runInteraction(interaction)
|
||||
|
||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
def _simple_insert(self, table, values):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
values : dict of new column names and values for them
|
||||
"""
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in values),
|
||||
", ".join("?" for k in values)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, values.values())
|
||||
return txn.lastrowid
|
||||
return self._db_pool.runInteraction(func)
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
|
||||
allow_none : If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
return self._simple_selectupdate_one(
|
||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it."
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
retcol : string giving the name of the column to return
|
||||
"""
|
||||
ret = yield self._simple_select_one(
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcols=[retcol],
|
||||
allow_none=allow_none
|
||||
)
|
||||
|
||||
if ret:
|
||||
defer.returnValue(ret[retcol])
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
Args:
|
||||
table (str): table name
|
||||
keyvalues (dict): column names and values to select the rows with
|
||||
retcol (str): column whos value we wish to retrieve.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
||||
}
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return txn.fetchall()
|
||||
|
||||
res = yield self._db_pool.runInteraction(func)
|
||||
|
||||
defer.returnValue([r[0] for r in res])
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self._db_pool.runInteraction(func)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
retcols=None):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
updatevalues : dict giving column names and values to update
|
||||
retcols : optional list of column names to return
|
||||
|
||||
If present, retcols gives a list of column names on which to perform
|
||||
a SELECT statement *before* performing the UPDATE statement. The values
|
||||
of these will be returned in a dict.
|
||||
|
||||
These are performed within the same transaction, allowing an atomic
|
||||
get-and-set. This can be used to implement compare-and-set by putting
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
||||
retcols=retcols)
|
||||
|
||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||
retcols=None, allow_none=False):
|
||||
""" Combined SELECT then UPDATE."""
|
||||
if retcols:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
if updatevalues:
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
ret = None
|
||||
if retcols:
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
ret = dict(zip(retcols, row))
|
||||
|
||||
if updatevalues:
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return ret
|
||||
return self._db_pool.runInteraction(func)
|
||||
|
||||
def _simple_delete_one(self, table, keyvalues):
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, keyvalues.values())
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
return self._db_pool.runInteraction(func)
|
||||
|
||||
def _simple_max_id(self, table):
|
||||
"""Executes a SELECT query on the named table, expecting to return the
|
||||
max value for the column "id".
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
"""
|
||||
sql = "SELECT MAX(id) AS id FROM %s" % table
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql)
|
||||
max_id = self.cursor_to_dict(txn)[0]["id"]
|
||||
if max_id is None:
|
||||
return 0
|
||||
return max_id
|
||||
|
||||
return self._db_pool.runInteraction(func)
|
||||
|
||||
|
||||
class Table(object):
|
||||
""" A base class used to store information about a particular table.
|
||||
"""
|
||||
|
||||
table_name = None
|
||||
""" str: The name of the table """
|
||||
|
||||
fields = None
|
||||
""" list: The field names """
|
||||
|
||||
EntryType = None
|
||||
""" Type: A tuple type used to decode the results """
|
||||
|
||||
_select_where_clause = "SELECT %s FROM %s WHERE %s"
|
||||
_select_clause = "SELECT %s FROM %s"
|
||||
_insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
|
||||
|
||||
@classmethod
|
||||
def select_statement(cls, where_clause=None):
|
||||
"""
|
||||
Args:
|
||||
where_clause (str): The WHERE clause to use.
|
||||
|
||||
Returns:
|
||||
str: An SQL statement to select rows from the table with the given
|
||||
WHERE clause.
|
||||
"""
|
||||
if where_clause:
|
||||
return cls._select_where_clause % (
|
||||
", ".join(cls.fields),
|
||||
cls.table_name,
|
||||
where_clause
|
||||
)
|
||||
else:
|
||||
return cls._select_clause % (
|
||||
", ".join(cls.fields),
|
||||
cls.table_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def insert_statement(cls):
|
||||
return cls._insert_clause % (
|
||||
cls.table_name,
|
||||
", ".join(cls.fields),
|
||||
", ".join(["?"] * len(cls.fields)),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def decode_single_result(cls, results):
|
||||
""" Given an iterable of tuples, return a single instance of
|
||||
`EntryType` or None if the iterable is empty
|
||||
Args:
|
||||
results (list): The results list to convert to `EntryType`
|
||||
Returns:
|
||||
EntryType: An instance of `EntryType`
|
||||
"""
|
||||
results = list(results)
|
||||
if results:
|
||||
return cls.EntryType(*results[0])
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def decode_results(cls, results):
|
||||
""" Given an iterable of tuples, return a list of `EntryType`
|
||||
Args:
|
||||
results (list): The results list to convert to `EntryType`
|
||||
|
||||
Returns:
|
||||
list: A list of `EntryType`
|
||||
"""
|
||||
return [cls.EntryType(*row) for row in results]
|
||||
|
||||
@classmethod
|
||||
def get_fields_string(cls, prefix=None):
|
||||
if prefix:
|
||||
to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
|
||||
else:
|
||||
to_join = cls.fields
|
||||
|
||||
return ", ".join(to_join)
|
||||
|
||||
|
||||
class JoinHelper(object):
|
||||
""" Used to help do joins on tables by looking at the tables' fields and
|
||||
creating a list of unique fields to use with SELECTs and a namedtuple
|
||||
to dump the results into.
|
||||
|
||||
Attributes:
|
||||
taples (list): List of `Table` classes
|
||||
EntryType (type)
|
||||
"""
|
||||
|
||||
def __init__(self, *tables):
|
||||
self.tables = tables
|
||||
|
||||
res = []
|
||||
for table in self.tables:
|
||||
res += [f for f in table.fields if f not in res]
|
||||
|
||||
self.EntryType = collections.namedtuple("JoinHelperEntry", res)
|
||||
|
||||
def get_fields(self, **prefixes):
|
||||
"""Get a string representing a list of fields for use in SELECT
|
||||
statements with the given prefixes applied to each.
|
||||
|
||||
For example::
|
||||
|
||||
JoinHelper(PdusTable, StateTable).get_fields(
|
||||
PdusTable="pdus",
|
||||
StateTable="state"
|
||||
)
|
||||
"""
|
||||
res = []
|
||||
for field in self.EntryType._fields:
|
||||
for table in self.tables:
|
||||
if field in table.fields:
|
||||
res.append("%s.%s" % (prefixes[table.__name__], field))
|
||||
break
|
||||
|
||||
return ", ".join(res)
|
||||
|
||||
def decode_results(self, rows):
|
||||
return [self.EntryType(*row) for row in rows]
|
93
synapse/storage/directory.py
Normal file
93
synapse/storage/directory.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore
|
||||
from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
RoomAliasMapping = namedtuple(
|
||||
"RoomAliasMapping",
|
||||
("room_id", "room_alias", "servers",)
|
||||
)
|
||||
|
||||
|
||||
class DirectoryStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association_from_room_alias(self, room_alias):
|
||||
""" Get's the room_id and server list for a given room_alias
|
||||
|
||||
Args:
|
||||
room_alias (RoomAlias)
|
||||
|
||||
Returns:
|
||||
Deferred: results in namedtuple with keys "room_id" and
|
||||
"servers" or None if no association can be found
|
||||
"""
|
||||
room_id = yield self._simple_select_one_onecol(
|
||||
"room_aliases",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"room_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if not room_id:
|
||||
defer.returnValue(None)
|
||||
return
|
||||
|
||||
servers = yield self._simple_select_onecol(
|
||||
"room_alias_servers",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"server",
|
||||
)
|
||||
|
||||
if not servers:
|
||||
defer.returnValue(None)
|
||||
return
|
||||
|
||||
defer.returnValue(
|
||||
RoomAliasMapping(room_id, room_alias.to_string(), servers)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room_alias_association(self, room_alias, room_id, servers):
|
||||
""" Creates an associatin between a room alias and room_id/servers
|
||||
|
||||
Args:
|
||||
room_alias (RoomAlias)
|
||||
room_id (str)
|
||||
servers (list)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
yield self._simple_insert(
|
||||
"room_aliases",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
},
|
||||
)
|
||||
|
||||
for server in servers:
|
||||
# TODO(erikj): Fix this to bulk insert
|
||||
yield self._simple_insert(
|
||||
"room_alias_servers",
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
}
|
||||
)
|
74
synapse/storage/feedback.py
Normal file
74
synapse/storage/feedback.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore, Table
|
||||
from synapse.api.events.room import FeedbackEvent
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
||||
|
||||
class FeedbackStore(SQLBaseStore):
|
||||
|
||||
def store_feedback(self, room_id, msg_id, msg_sender_id,
|
||||
fb_sender_id, fb_type, content):
|
||||
return self._simple_insert(FeedbackTable.table_name, dict(
|
||||
room_id=room_id,
|
||||
msg_id=msg_id,
|
||||
msg_sender_id=msg_sender_id,
|
||||
fb_sender_id=fb_sender_id,
|
||||
fb_type=fb_type,
|
||||
content=content,
|
||||
))
|
||||
|
||||
def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None,
|
||||
fb_sender_id=None, fb_type=None):
|
||||
query = FeedbackTable.select_statement(
|
||||
"msg_sender_id = ? AND room_id = ? AND msg_id = ? " +
|
||||
"AND fb_sender_id = ? AND feedback_type = ? " +
|
||||
"ORDER BY id DESC LIMIT 1")
|
||||
return self._execute(
|
||||
FeedbackTable.decode_single_result,
|
||||
query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type,
|
||||
)
|
||||
|
||||
def get_max_feedback_id(self):
|
||||
return self._simple_max_id(FeedbackTable.table_name)
|
||||
|
||||
|
||||
class FeedbackTable(Table):
|
||||
table_name = "feedback"
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"content",
|
||||
"feedback_type",
|
||||
"fb_sender_id",
|
||||
"msg_id",
|
||||
"room_id",
|
||||
"msg_sender_id"
|
||||
]
|
||||
|
||||
class EntryType(collections.namedtuple("FeedbackEntry", fields)):
|
||||
|
||||
def as_event(self, event_factory):
|
||||
return event_factory.create_event(
|
||||
etype=FeedbackEvent.TYPE,
|
||||
room_id=self.room_id,
|
||||
msg_id=self.msg_id,
|
||||
msg_sender_id=self.msg_sender_id,
|
||||
user_id=self.fb_sender_id,
|
||||
feedback_type=self.feedback_type,
|
||||
content=json.loads(self.content),
|
||||
)
|
80
synapse/storage/message.py
Normal file
80
synapse/storage/message.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore, Table
|
||||
from synapse.api.events.room import MessageEvent
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
||||
|
||||
class MessageStore(SQLBaseStore):
|
||||
|
||||
def get_message(self, user_id, room_id, msg_id):
|
||||
"""Get a message from the store.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user who sent the message.
|
||||
room_id (str): The room the message was sent in.
|
||||
msg_id (str): The unique ID for this user/room combo.
|
||||
"""
|
||||
query = MessagesTable.select_statement(
|
||||
"user_id = ? AND room_id = ? AND msg_id = ? " +
|
||||
"ORDER BY id DESC LIMIT 1")
|
||||
return self._execute(
|
||||
MessagesTable.decode_single_result,
|
||||
query, user_id, room_id, msg_id,
|
||||
)
|
||||
|
||||
def store_message(self, user_id, room_id, msg_id, content):
|
||||
"""Store a message in the store.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user who sent the message.
|
||||
room_id (str): The room the message was sent in.
|
||||
msg_id (str): The unique ID for this user/room combo.
|
||||
content (str): The content of the message (JSON)
|
||||
"""
|
||||
return self._simple_insert(MessagesTable.table_name, dict(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
msg_id=msg_id,
|
||||
content=content,
|
||||
))
|
||||
|
||||
def get_max_message_id(self):
|
||||
return self._simple_max_id(MessagesTable.table_name)
|
||||
|
||||
|
||||
class MessagesTable(Table):
|
||||
table_name = "messages"
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"user_id",
|
||||
"room_id",
|
||||
"msg_id",
|
||||
"content"
|
||||
]
|
||||
|
||||
class EntryType(collections.namedtuple("MessageEntry", fields)):
|
||||
|
||||
def as_event(self, event_factory):
|
||||
return event_factory.create_event(
|
||||
etype=MessageEvent.TYPE,
|
||||
room_id=self.room_id,
|
||||
user_id=self.user_id,
|
||||
msg_id=self.msg_id,
|
||||
content=json.loads(self.content),
|
||||
)
|
993
synapse/storage/pdu.py
Normal file
993
synapse/storage/pdu.py
Normal file
|
@ -0,0 +1,993 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore, Table, JoinHelper
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PduStore(SQLBaseStore):
|
||||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
|
||||
def get_pdu(self, pdu_id, origin):
|
||||
"""Given a pdu_id and origin, get a PDU.
|
||||
|
||||
Args:
|
||||
txn
|
||||
pdu_id (str)
|
||||
origin (str)
|
||||
|
||||
Returns:
|
||||
PduTuple: If the pdu does not exist in the database, returns None
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_pdu_tuple, pdu_id, origin
|
||||
)
|
||||
|
||||
def _get_pdu_tuple(self, txn, pdu_id, origin):
|
||||
res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
|
||||
return res[0] if res else None
|
||||
|
||||
def _get_pdu_tuples(self, txn, pdu_id_tuples):
|
||||
results = []
|
||||
for pdu_id, origin in pdu_id_tuples:
|
||||
txn.execute(
|
||||
PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
|
||||
(pdu_id, origin)
|
||||
)
|
||||
|
||||
edges = [
|
||||
(r.prev_pdu_id, r.prev_origin)
|
||||
for r in PduEdgesTable.decode_results(txn.fetchall())
|
||||
]
|
||||
|
||||
query = (
|
||||
"SELECT %(fields)s FROM %(pdus)s as p "
|
||||
"LEFT JOIN %(state)s as s "
|
||||
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
|
||||
"WHERE p.pdu_id = ? AND p.origin = ? "
|
||||
) % {
|
||||
"fields": _pdu_state_joiner.get_fields(
|
||||
PdusTable="p", StatePdusTable="s"),
|
||||
"pdus": PdusTable.table_name,
|
||||
"state": StatePdusTable.table_name,
|
||||
}
|
||||
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
results.append(PduTuple(PduEntry(*row), edges))
|
||||
|
||||
return results
|
||||
|
||||
def get_current_state_for_context(self, context):
|
||||
"""Get a list of PDUs that represent the current state for a given
|
||||
context
|
||||
|
||||
Args:
|
||||
context (str)
|
||||
|
||||
Returns:
|
||||
list: A list of PduTuples
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_current_state_for_context,
|
||||
context
|
||||
)
|
||||
|
||||
def _get_current_state_for_context(self, txn, context):
|
||||
query = (
|
||||
"SELECT pdu_id, origin FROM %s WHERE context = ?"
|
||||
% CurrentStateTable.table_name
|
||||
)
|
||||
|
||||
logger.debug("get_current_state %s, Args=%s", query, context)
|
||||
txn.execute(query, (context,))
|
||||
|
||||
res = txn.fetchall()
|
||||
|
||||
logger.debug("get_current_state %d results", len(res))
|
||||
|
||||
return self._get_pdu_tuples(txn, res)
|
||||
|
||||
def persist_pdu(self, prev_pdus, **cols):
|
||||
"""Inserts a (non-state) PDU into the database.
|
||||
|
||||
Args:
|
||||
txn,
|
||||
prev_pdus (list)
|
||||
**cols: The columns to insert into the PdusTable.
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._persist_pdu, prev_pdus, cols
|
||||
)
|
||||
|
||||
def _persist_pdu(self, txn, prev_pdus, cols):
|
||||
entry = PdusTable.EntryType(
|
||||
**{k: cols.get(k, None) for k in PdusTable.fields}
|
||||
)
|
||||
|
||||
txn.execute(PdusTable.insert_statement(), entry)
|
||||
|
||||
self._handle_prev_pdus(
|
||||
txn, entry.outlier, entry.pdu_id, entry.origin,
|
||||
prev_pdus, entry.context
|
||||
)
|
||||
|
||||
def mark_pdu_as_processed(self, pdu_id, pdu_origin):
|
||||
"""Mark a received PDU as processed.
|
||||
|
||||
Args:
|
||||
txn
|
||||
pdu_id (str)
|
||||
pdu_origin (str)
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._mark_as_processed, pdu_id, pdu_origin
|
||||
)
|
||||
|
||||
def _mark_as_processed(self, txn, pdu_id, pdu_origin):
|
||||
txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
|
||||
|
||||
def get_all_pdus_from_context(self, context):
|
||||
"""Get a list of all PDUs for a given context."""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_all_pdus_from_context, context,
|
||||
)
|
||||
|
||||
def _get_all_pdus_from_context(self, txn, context):
|
||||
query = (
|
||||
"SELECT pdu_id, origin FROM %s "
|
||||
"WHERE context = ?"
|
||||
) % PdusTable.table_name
|
||||
|
||||
txn.execute(query, (context,))
|
||||
|
||||
return self._get_pdu_tuples(txn, txn.fetchall())
|
||||
|
||||
def get_pagination(self, context, pdu_list, limit):
|
||||
"""Get a list of Pdus for a given topic that occured before (and
|
||||
including) the pdus in pdu_list. Return a list of max size `limit`.
|
||||
|
||||
Args:
|
||||
txn
|
||||
context (str)
|
||||
pdu_list (list)
|
||||
limit (int)
|
||||
|
||||
Return:
|
||||
list: A list of PduTuples
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_paginate, context, pdu_list, limit
|
||||
)
|
||||
|
||||
def _get_paginate(self, txn, context, pdu_list, limit):
|
||||
logger.debug(
|
||||
"paginate: %s, %s, %s",
|
||||
context, repr(pdu_list), limit
|
||||
)
|
||||
|
||||
# We seed the pdu_results with the things from the pdu_list.
|
||||
pdu_results = pdu_list
|
||||
|
||||
front = pdu_list
|
||||
|
||||
query = (
|
||||
"SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
|
||||
"WHERE context = ? AND pdu_id = ? AND origin = ? "
|
||||
"LIMIT ?"
|
||||
) % {
|
||||
"edges_table": PduEdgesTable.table_name,
|
||||
}
|
||||
|
||||
# We iterate through all pdu_ids in `front` to select their previous
|
||||
# pdus. These are dumped in `new_front`. We continue until we reach the
|
||||
# limit *or* new_front is empty (i.e., we've run out of things to
|
||||
# select
|
||||
while front and len(pdu_results) < limit:
|
||||
|
||||
new_front = []
|
||||
for pdu_id, origin in front:
|
||||
logger.debug(
|
||||
"_paginate_interaction: i=%s, o=%s",
|
||||
pdu_id, origin
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
query,
|
||||
(context, pdu_id, origin, limit - len(pdu_results))
|
||||
)
|
||||
|
||||
for row in txn.fetchall():
|
||||
logger.debug(
|
||||
"_paginate_interaction: got i=%s, o=%s",
|
||||
*row
|
||||
)
|
||||
new_front.append(row)
|
||||
|
||||
front = new_front
|
||||
pdu_results += new_front
|
||||
|
||||
# We also want to update the `prev_pdus` attributes before returning.
|
||||
return self._get_pdu_tuples(txn, pdu_results)
|
||||
|
||||
def get_min_depth_for_context(self, context):
|
||||
"""Get the current minimum depth for a context
|
||||
|
||||
Args:
|
||||
txn
|
||||
context (str)
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_min_depth_for_context, context
|
||||
)
|
||||
|
||||
def _get_min_depth_for_context(self, txn, context):
|
||||
return self._get_min_depth_interaction(txn, context)
|
||||
|
||||
def _get_min_depth_interaction(self, txn, context):
|
||||
txn.execute(
|
||||
"SELECT min_depth FROM %s WHERE context = ?"
|
||||
% ContextDepthTable.table_name,
|
||||
(context,)
|
||||
)
|
||||
|
||||
row = txn.fetchone()
|
||||
|
||||
return row[0] if row else None
|
||||
|
||||
def update_min_depth_for_context(self, context, depth):
|
||||
"""Update the minimum `depth` of the given context, which is the line
|
||||
where we stop paginating backwards on.
|
||||
|
||||
Args:
|
||||
context (str)
|
||||
depth (int)
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._update_min_depth_for_context, context, depth
|
||||
)
|
||||
|
||||
def _update_min_depth_for_context(self, txn, context, depth):
|
||||
min_depth = self._get_min_depth_interaction(txn, context)
|
||||
|
||||
do_insert = depth < min_depth if min_depth else True
|
||||
|
||||
if do_insert:
|
||||
txn.execute(
|
||||
"INSERT OR REPLACE INTO %s (context, min_depth) "
|
||||
"VALUES (?,?)" % ContextDepthTable.table_name,
|
||||
(context, depth)
|
||||
)
|
||||
|
||||
def get_latest_pdus_in_context(self, context):
|
||||
"""Get's a list of the most current pdus for a given context. This is
|
||||
used when we are sending a Pdu and need to fill out the `prev_pdus`
|
||||
key
|
||||
|
||||
Args:
|
||||
txn
|
||||
context
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_latest_pdus_in_context, context
|
||||
)
|
||||
|
||||
def _get_latest_pdus_in_context(self, txn, context):
|
||||
query = (
|
||||
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
|
||||
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
|
||||
"AND f.origin = p.origin "
|
||||
"WHERE f.context = ?"
|
||||
) % {
|
||||
"pdus": PdusTable.table_name,
|
||||
"forward": PduForwardExtremitiesTable.table_name,
|
||||
}
|
||||
|
||||
logger.debug("get_prev query: %s", query)
|
||||
|
||||
txn.execute(
|
||||
query,
|
||||
(context, )
|
||||
)
|
||||
|
||||
results = txn.fetchall()
|
||||
|
||||
return [(row[0], row[1], row[2]) for row in results]
|
||||
|
||||
def get_oldest_pdus_in_context(self, context):
|
||||
"""Get a list of Pdus that we paginated beyond yet (and haven't seen).
|
||||
This list is used when we want to paginate backwards and is the list we
|
||||
send to the remote server.
|
||||
|
||||
Args:
|
||||
txn
|
||||
context (str)
|
||||
|
||||
Returns:
|
||||
list: A list of PduIdTuple.
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_oldest_pdus_in_context, context
|
||||
)
|
||||
|
||||
def _get_oldest_pdus_in_context(self, txn, context):
|
||||
txn.execute(
|
||||
"SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
|
||||
% {"back": PduBackwardExtremitiesTable.table_name, },
|
||||
(context,)
|
||||
)
|
||||
return [PduIdTuple(i, o) for i, o in txn.fetchall()]
|
||||
|
||||
def is_pdu_new(self, pdu_id, origin, context, depth):
|
||||
"""For a given Pdu, try and figure out if it's 'new', i.e., if it's
|
||||
not something we got randomly from the past, for example when we
|
||||
request the current state of the room that will probably return a bunch
|
||||
of pdus from before we joined.
|
||||
|
||||
Args:
|
||||
txn
|
||||
pdu_id (str)
|
||||
origin (str)
|
||||
context (str)
|
||||
depth (int)
|
||||
|
||||
Returns:
|
||||
bool
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._is_pdu_new,
|
||||
pdu_id=pdu_id,
|
||||
origin=origin,
|
||||
context=context,
|
||||
depth=depth
|
||||
)
|
||||
|
||||
def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
|
||||
# If depth > min depth in back table, then we classify it as new.
|
||||
# OR if there is nothing in the back table, then it kinda needs to
|
||||
# be a new thing.
|
||||
query = (
|
||||
"SELECT min(p.depth) FROM %(edges)s as e "
|
||||
"INNER JOIN %(back)s as b "
|
||||
"ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
|
||||
"INNER JOIN %(pdus)s as p "
|
||||
"ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
|
||||
"WHERE p.context = ?"
|
||||
) % {
|
||||
"pdus": PdusTable.table_name,
|
||||
"edges": PduEdgesTable.table_name,
|
||||
"back": PduBackwardExtremitiesTable.table_name,
|
||||
}
|
||||
|
||||
txn.execute(query, (context,))
|
||||
|
||||
min_depth, = txn.fetchone()
|
||||
|
||||
if not min_depth or depth > int(min_depth):
|
||||
logger.debug(
|
||||
"is_new true: id=%s, o=%s, d=%s min_depth=%s",
|
||||
pdu_id, origin, depth, min_depth
|
||||
)
|
||||
return True
|
||||
|
||||
# If this pdu is in the forwards table, then it also is a new one
|
||||
query = (
|
||||
"SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
|
||||
) % {
|
||||
"forward": PduForwardExtremitiesTable.table_name,
|
||||
}
|
||||
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
|
||||
# Did we get anything?
|
||||
if txn.fetchall():
|
||||
logger.debug(
|
||||
"is_new true: id=%s, o=%s, d=%s was forward",
|
||||
pdu_id, origin, depth
|
||||
)
|
||||
return True
|
||||
|
||||
logger.debug(
|
||||
"is_new false: id=%s, o=%s, d=%s",
|
||||
pdu_id, origin, depth
|
||||
)
|
||||
|
||||
# FINE THEN. It's probably old.
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
@log_function
|
||||
def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
|
||||
context):
|
||||
txn.executemany(
|
||||
PduEdgesTable.insert_statement(),
|
||||
[(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
|
||||
)
|
||||
|
||||
# Update the extremities table if this is not an outlier.
|
||||
if not outlier:
|
||||
|
||||
# First, we delete the new one from the forwards extremities table.
|
||||
query = (
|
||||
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
|
||||
% PduForwardExtremitiesTable.table_name
|
||||
)
|
||||
txn.executemany(query, prev_pdus)
|
||||
|
||||
# We only insert as a forward extremety the new pdu if there are no
|
||||
# other pdus that reference it as a prev pdu
|
||||
query = (
|
||||
"INSERT INTO %(table)s (pdu_id, origin, context) "
|
||||
"SELECT ?, ?, ? WHERE NOT EXISTS ("
|
||||
"SELECT 1 FROM %(pdu_edges)s WHERE "
|
||||
"prev_pdu_id = ? AND prev_origin = ?"
|
||||
")"
|
||||
) % {
|
||||
"table": PduForwardExtremitiesTable.table_name,
|
||||
"pdu_edges": PduEdgesTable.table_name
|
||||
}
|
||||
|
||||
logger.debug("query: %s", query)
|
||||
|
||||
txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
|
||||
|
||||
# Insert all the prev_pdus as a backwards thing, they'll get
|
||||
# deleted in a second if they're incorrect anyway.
|
||||
txn.executemany(
|
||||
PduBackwardExtremitiesTable.insert_statement(),
|
||||
[(i, o, context) for i, o in prev_pdus]
|
||||
)
|
||||
|
||||
# Also delete from the backwards extremities table all ones that
|
||||
# reference pdus that we have already seen
|
||||
query = (
|
||||
"DELETE FROM %(pdu_back)s WHERE EXISTS ("
|
||||
"SELECT 1 FROM %(pdus)s AS pdus "
|
||||
"WHERE "
|
||||
"%(pdu_back)s.pdu_id = pdus.pdu_id "
|
||||
"AND %(pdu_back)s.origin = pdus.origin "
|
||||
"AND not pdus.outlier "
|
||||
")"
|
||||
) % {
|
||||
"pdu_back": PduBackwardExtremitiesTable.table_name,
|
||||
"pdus": PdusTable.table_name,
|
||||
}
|
||||
txn.execute(query)
|
||||
|
||||
|
||||
class StatePduStore(SQLBaseStore):
|
||||
"""A collection of queries for handling state PDUs.
|
||||
"""
|
||||
|
||||
def persist_state(self, prev_pdus, **cols):
|
||||
"""Inserts a state PDU into the database
|
||||
|
||||
Args:
|
||||
txn,
|
||||
prev_pdus (list)
|
||||
**cols: The columns to insert into the PdusTable and StatePdusTable
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._persist_state, prev_pdus, cols
|
||||
)
|
||||
|
||||
def _persist_state(self, txn, prev_pdus, cols):
|
||||
pdu_entry = PdusTable.EntryType(
|
||||
**{k: cols.get(k, None) for k in PdusTable.fields}
|
||||
)
|
||||
state_entry = StatePdusTable.EntryType(
|
||||
**{k: cols.get(k, None) for k in StatePdusTable.fields}
|
||||
)
|
||||
|
||||
logger.debug("Inserting pdu: %s", repr(pdu_entry))
|
||||
logger.debug("Inserting state: %s", repr(state_entry))
|
||||
|
||||
txn.execute(PdusTable.insert_statement(), pdu_entry)
|
||||
txn.execute(StatePdusTable.insert_statement(), state_entry)
|
||||
|
||||
self._handle_prev_pdus(
|
||||
txn,
|
||||
pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
|
||||
pdu_entry.context
|
||||
)
|
||||
|
||||
def get_unresolved_state_tree(self, new_state_pdu):
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_unresolved_state_tree, new_state_pdu
|
||||
)
|
||||
|
||||
@log_function
|
||||
def _get_unresolved_state_tree(self, txn, new_pdu):
|
||||
current = self._get_current_interaction(
|
||||
txn,
|
||||
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
|
||||
)
|
||||
|
||||
ReturnType = namedtuple(
|
||||
"StateReturnType", ["new_branch", "current_branch"]
|
||||
)
|
||||
return_value = ReturnType([new_pdu], [])
|
||||
|
||||
if not current:
|
||||
logger.debug("get_unresolved_state_tree No current state.")
|
||||
return return_value
|
||||
|
||||
return_value.current_branch.append(current)
|
||||
|
||||
enum_branches = self._enumerate_state_branches(
|
||||
txn, new_pdu, current
|
||||
)
|
||||
|
||||
for branch, prev_state, state in enum_branches:
|
||||
if state:
|
||||
return_value[branch].append(state)
|
||||
else:
|
||||
break
|
||||
|
||||
return return_value
|
||||
|
||||
def update_current_state(self, pdu_id, origin, context, pdu_type,
|
||||
state_key):
|
||||
return self._db_pool.runInteraction(
|
||||
self._update_current_state,
|
||||
pdu_id, origin, context, pdu_type, state_key
|
||||
)
|
||||
|
||||
def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
|
||||
state_key):
|
||||
query = (
|
||||
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
|
||||
) % {
|
||||
"curr": CurrentStateTable.table_name,
|
||||
"fields": CurrentStateTable.get_fields_string(),
|
||||
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
|
||||
}
|
||||
|
||||
query_args = CurrentStateTable.EntryType(
|
||||
pdu_id=pdu_id,
|
||||
origin=origin,
|
||||
context=context,
|
||||
pdu_type=pdu_type,
|
||||
state_key=state_key
|
||||
)
|
||||
|
||||
txn.execute(query, query_args)
|
||||
|
||||
def get_current_state(self, context, pdu_type, state_key):
|
||||
"""For a given context, pdu_type, state_key 3-tuple, return what is
|
||||
currently considered the current state.
|
||||
|
||||
Args:
|
||||
txn
|
||||
context (str)
|
||||
pdu_type (str)
|
||||
state_key (str)
|
||||
|
||||
Returns:
|
||||
PduEntry
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_current_state, context, pdu_type, state_key
|
||||
)
|
||||
|
||||
def _get_current_state(self, txn, context, pdu_type, state_key):
|
||||
return self._get_current_interaction(txn, context, pdu_type, state_key)
|
||||
|
||||
def _get_current_interaction(self, txn, context, pdu_type, state_key):
|
||||
logger.debug(
|
||||
"_get_current_interaction %s %s %s",
|
||||
context, pdu_type, state_key
|
||||
)
|
||||
|
||||
fields = _pdu_state_joiner.get_fields(
|
||||
PdusTable="p", StatePdusTable="s")
|
||||
|
||||
current_query = (
|
||||
"SELECT %(fields)s FROM %(state)s as s "
|
||||
"INNER JOIN %(pdus)s as p "
|
||||
"ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
|
||||
"INNER JOIN %(curr)s as c "
|
||||
"ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
|
||||
"WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
|
||||
) % {
|
||||
"fields": fields,
|
||||
"curr": CurrentStateTable.table_name,
|
||||
"state": StatePdusTable.table_name,
|
||||
"pdus": PdusTable.table_name,
|
||||
}
|
||||
|
||||
txn.execute(
|
||||
current_query,
|
||||
(context, pdu_type, state_key)
|
||||
)
|
||||
|
||||
row = txn.fetchone()
|
||||
|
||||
result = PduEntry(*row) if row else None
|
||||
|
||||
if not result:
|
||||
logger.debug("_get_current_interaction not found")
|
||||
else:
|
||||
logger.debug(
|
||||
"_get_current_interaction found %s %s",
|
||||
result.pdu_id, result.origin
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_next_missing_pdu(self, new_pdu):
|
||||
"""When we get a new state pdu we need to check whether we need to do
|
||||
any conflict resolution, if we do then we need to check if we need
|
||||
to go back and request some more state pdus that we haven't seen yet.
|
||||
|
||||
Args:
|
||||
txn
|
||||
new_pdu
|
||||
|
||||
Returns:
|
||||
PduIdTuple: A pdu that we are missing, or None if we have all the
|
||||
pdus required to do the conflict resolution.
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_next_missing_pdu, new_pdu
|
||||
)
|
||||
|
||||
def _get_next_missing_pdu(self, txn, new_pdu):
|
||||
logger.debug(
|
||||
"get_next_missing_pdu %s %s",
|
||||
new_pdu.pdu_id, new_pdu.origin
|
||||
)
|
||||
|
||||
current = self._get_current_interaction(
|
||||
txn,
|
||||
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
|
||||
)
|
||||
|
||||
if (not current or not current.prev_state_id
|
||||
or not current.prev_state_origin):
|
||||
return None
|
||||
|
||||
# Oh look, it's a straight clobber, so wooooo almost no-op.
|
||||
if (new_pdu.prev_state_id == current.pdu_id
|
||||
and new_pdu.prev_state_origin == current.origin):
|
||||
return None
|
||||
|
||||
enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
|
||||
for branch, prev_state, state in enum_branches:
|
||||
if not state:
|
||||
return PduIdTuple(
|
||||
prev_state.prev_state_id,
|
||||
prev_state.prev_state_origin
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def handle_new_state(self, new_pdu):
|
||||
"""Actually perform conflict resolution on the new_pdu on the
|
||||
assumption we have all the pdus required to perform it.
|
||||
|
||||
Args:
|
||||
new_pdu
|
||||
|
||||
Returns:
|
||||
bool: True if the new_pdu clobbered the current state, False if not
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._handle_new_state, new_pdu
|
||||
)
|
||||
|
||||
def _handle_new_state(self, txn, new_pdu):
|
||||
logger.debug(
|
||||
"handle_new_state %s %s",
|
||||
new_pdu.pdu_id, new_pdu.origin
|
||||
)
|
||||
|
||||
current = self._get_current_interaction(
|
||||
txn,
|
||||
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
|
||||
)
|
||||
|
||||
is_current = False
|
||||
|
||||
if (not current or not current.prev_state_id
|
||||
or not current.prev_state_origin):
|
||||
# Oh, we don't have any state for this yet.
|
||||
is_current = True
|
||||
elif (current.pdu_id == new_pdu.prev_state_id
|
||||
and current.origin == new_pdu.prev_state_origin):
|
||||
# Oh! A direct clobber. Just do it.
|
||||
is_current = True
|
||||
else:
|
||||
##
|
||||
# Ok, now loop through until we get to a common ancestor.
|
||||
max_new = int(new_pdu.power_level)
|
||||
max_current = int(current.power_level)
|
||||
|
||||
enum_branches = self._enumerate_state_branches(
|
||||
txn, new_pdu, current
|
||||
)
|
||||
for branch, prev_state, state in enum_branches:
|
||||
if not state:
|
||||
raise RuntimeError(
|
||||
"Could not find state_pdu %s %s" %
|
||||
(
|
||||
prev_state.prev_state_id,
|
||||
prev_state.prev_state_origin
|
||||
)
|
||||
)
|
||||
|
||||
if branch == 0:
|
||||
max_new = max(int(state.depth), max_new)
|
||||
else:
|
||||
max_current = max(int(state.depth), max_current)
|
||||
|
||||
is_current = max_new > max_current
|
||||
|
||||
if is_current:
|
||||
logger.debug("handle_new_state make current")
|
||||
|
||||
# Right, this is a new thing, so woo, just insert it.
|
||||
txn.execute(
|
||||
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
|
||||
% {
|
||||
"curr": CurrentStateTable.table_name,
|
||||
"fields": CurrentStateTable.get_fields_string(),
|
||||
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
|
||||
},
|
||||
CurrentStateTable.EntryType(
|
||||
*(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug("handle_new_state not current")
|
||||
|
||||
logger.debug("handle_new_state done")
|
||||
|
||||
return is_current
|
||||
|
||||
@classmethod
|
||||
@log_function
|
||||
def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
|
||||
branch_a = pdu_a
|
||||
branch_b = pdu_b
|
||||
|
||||
get_query = (
|
||||
"SELECT %(fields)s FROM %(pdus)s as p "
|
||||
"LEFT JOIN %(state)s as s "
|
||||
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
|
||||
"WHERE p.pdu_id = ? AND p.origin = ? "
|
||||
) % {
|
||||
"fields": _pdu_state_joiner.get_fields(
|
||||
PdusTable="p", StatePdusTable="s"),
|
||||
"pdus": PdusTable.table_name,
|
||||
"state": StatePdusTable.table_name,
|
||||
}
|
||||
|
||||
while True:
|
||||
if (branch_a.pdu_id == branch_b.pdu_id
|
||||
and branch_a.origin == branch_b.origin):
|
||||
# Woo! We found a common ancestor
|
||||
logger.debug("_enumerate_state_branches Found common ancestor")
|
||||
break
|
||||
|
||||
do_branch_a = (
|
||||
hasattr(branch_a, "prev_state_id") and
|
||||
branch_a.prev_state_id
|
||||
)
|
||||
|
||||
do_branch_b = (
|
||||
hasattr(branch_b, "prev_state_id") and
|
||||
branch_b.prev_state_id
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"do_branch_a=%s, do_branch_b=%s",
|
||||
do_branch_a, do_branch_b
|
||||
)
|
||||
|
||||
if do_branch_a and do_branch_b:
|
||||
do_branch_a = int(branch_a.depth) > int(branch_b.depth)
|
||||
|
||||
if do_branch_a:
|
||||
pdu_tuple = PduIdTuple(
|
||||
branch_a.prev_state_id,
|
||||
branch_a.prev_state_origin
|
||||
)
|
||||
|
||||
logger.debug("getting branch_a prev %s", pdu_tuple)
|
||||
txn.execute(get_query, pdu_tuple)
|
||||
|
||||
prev_branch = branch_a
|
||||
|
||||
res = txn.fetchone()
|
||||
branch_a = PduEntry(*res) if res else None
|
||||
|
||||
logger.debug("branch_a=%s", branch_a)
|
||||
|
||||
yield (0, prev_branch, branch_a)
|
||||
|
||||
if not branch_a:
|
||||
break
|
||||
elif do_branch_b:
|
||||
pdu_tuple = PduIdTuple(
|
||||
branch_b.prev_state_id,
|
||||
branch_b.prev_state_origin
|
||||
)
|
||||
txn.execute(get_query, pdu_tuple)
|
||||
|
||||
logger.debug("getting branch_b prev %s", pdu_tuple)
|
||||
|
||||
prev_branch = branch_b
|
||||
|
||||
res = txn.fetchone()
|
||||
branch_b = PduEntry(*res) if res else None
|
||||
|
||||
logger.debug("branch_b=%s", branch_b)
|
||||
|
||||
yield (1, prev_branch, branch_b)
|
||||
|
||||
if not branch_b:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
class PdusTable(Table):
|
||||
table_name = "pdus"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"context",
|
||||
"pdu_type",
|
||||
"ts",
|
||||
"depth",
|
||||
"is_state",
|
||||
"content_json",
|
||||
"unrecognized_keys",
|
||||
"outlier",
|
||||
"have_processed",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("PdusEntry", fields)
|
||||
|
||||
|
||||
class PduDestinationsTable(Table):
|
||||
table_name = "pdu_destinations"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"destination",
|
||||
"delivered_ts",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("PduDestinationsEntry", fields)
|
||||
|
||||
|
||||
class PduEdgesTable(Table):
|
||||
table_name = "pdu_edges"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"prev_pdu_id",
|
||||
"prev_origin",
|
||||
"context"
|
||||
]
|
||||
|
||||
EntryType = namedtuple("PduEdgesEntry", fields)
|
||||
|
||||
|
||||
class PduForwardExtremitiesTable(Table):
|
||||
table_name = "pdu_forward_extremities"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"context",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
|
||||
|
||||
|
||||
class PduBackwardExtremitiesTable(Table):
|
||||
table_name = "pdu_backward_extremities"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"context",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
|
||||
|
||||
|
||||
class ContextDepthTable(Table):
|
||||
table_name = "context_depth"
|
||||
|
||||
fields = [
|
||||
"context",
|
||||
"min_depth",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("ContextDepthEntry", fields)
|
||||
|
||||
|
||||
class StatePdusTable(Table):
|
||||
table_name = "state_pdus"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"context",
|
||||
"pdu_type",
|
||||
"state_key",
|
||||
"power_level",
|
||||
"prev_state_id",
|
||||
"prev_state_origin",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("StatePdusEntry", fields)
|
||||
|
||||
|
||||
class CurrentStateTable(Table):
|
||||
table_name = "current_state"
|
||||
|
||||
fields = [
|
||||
"pdu_id",
|
||||
"origin",
|
||||
"context",
|
||||
"pdu_type",
|
||||
"state_key",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("CurrentStateEntry", fields)
|
||||
|
||||
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
|
||||
|
||||
|
||||
# TODO: These should probably be put somewhere more sensible
|
||||
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
|
||||
|
||||
PduEntry = _pdu_state_joiner.EntryType
|
||||
""" We are always interested in the join of the PdusTable and StatePdusTable,
|
||||
rather than just the PdusTable.
|
||||
|
||||
This does not include a prev_pdus key.
|
||||
"""
|
||||
|
||||
PduTuple = namedtuple(
|
||||
"PduTuple",
|
||||
("pdu_entry", "prev_pdu_list")
|
||||
)
|
||||
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
|
||||
the `prev_pdus` key of a PDU.
|
||||
"""
|
103
synapse/storage/presence.py
Normal file
103
synapse/storage/presence.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
def create_presence(self, user_localpart):
|
||||
return self._simple_insert(
|
||||
table="presence",
|
||||
values={"user_id": user_localpart},
|
||||
)
|
||||
|
||||
def has_presence_state(self, user_localpart):
|
||||
return self._simple_select_one(
|
||||
table="presence",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["user_id"],
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def get_presence_state(self, user_localpart):
|
||||
return self._simple_select_one(
|
||||
table="presence",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["state", "status_msg"],
|
||||
)
|
||||
|
||||
def set_presence_state(self, user_localpart, new_state):
|
||||
return self._simple_update_one(
|
||||
table="presence",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"state": new_state["state"],
|
||||
"status_msg": new_state["status_msg"]},
|
||||
retcols=["state"],
|
||||
)
|
||||
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self._simple_insert(
|
||||
table="presence_allow_inbound",
|
||||
values={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
)
|
||||
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self._simple_delete_one(
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
)
|
||||
|
||||
def is_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self._simple_select_one(
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||
return self._simple_insert(
|
||||
table="presence_list",
|
||||
values={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
"accepted": False},
|
||||
)
|
||||
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
return self._simple_update_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
updatevalues={"accepted": True},
|
||||
)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
keyvalues = {"user_id": observer_localpart}
|
||||
if accepted is not None:
|
||||
keyvalues["accepted"] = accepted
|
||||
|
||||
return self._simple_select_list(
|
||||
table="presence_list",
|
||||
keyvalues=keyvalues,
|
||||
retcols=["observed_user_id", "accepted"],
|
||||
)
|
||||
|
||||
def del_presence_list(self, observer_localpart, observed_userid):
|
||||
return self._simple_delete_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
)
|
51
synapse/storage/profile.py
Normal file
51
synapse/storage/profile.py
Normal file
|
@ -0,0 +1,51 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class ProfileStore(SQLBaseStore):
|
||||
def create_profile(self, user_localpart):
|
||||
return self._simple_insert(
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart},
|
||||
)
|
||||
|
||||
def get_profile_displayname(self, user_localpart):
|
||||
return self._simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
)
|
||||
|
||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||
return self._simple_update_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"displayname": new_displayname},
|
||||
)
|
||||
|
||||
def get_profile_avatar_url(self, user_localpart):
|
||||
return self._simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
)
|
||||
|
||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||
return self._simple_update_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
)
|
113
synapse/storage/registration.py
Normal file
113
synapse/storage/registration.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegistrationStore, self).__init__(hs)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_access_token_to_user(self, user_id, token):
|
||||
"""Adds an access token for the given user.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
token (str): The new access token to add.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
"""
|
||||
row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
|
||||
if not row:
|
||||
raise StoreError(400, "Bad user ID supplied.")
|
||||
row_id = row["id"]
|
||||
yield self._simple_insert(
|
||||
"access_tokens",
|
||||
{
|
||||
"user_id": row_id,
|
||||
"token": token
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id, token, password_hash):
|
||||
"""Attempts to register an account.
|
||||
|
||||
Args:
|
||||
user_id (str): The desired user ID to register.
|
||||
token (str): The desired access token to use for this user.
|
||||
password_hash (str): Optional. The password hash for this user.
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
"""
|
||||
yield self._db_pool.runInteraction(self._register, user_id, token,
|
||||
password_hash)
|
||||
|
||||
def _register(self, txn, user_id, token, password_hash):
|
||||
now = int(self.clock.time())
|
||||
|
||||
try:
|
||||
txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
|
||||
"VALUES (?,?,?)",
|
||||
[user_id, password_hash, now])
|
||||
except IntegrityError:
|
||||
raise StoreError(400, "User ID already taken.")
|
||||
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
txn.execute("INSERT INTO access_tokens(user_id, token) " +
|
||||
"VALUES (?,?)", [txn.lastrowid, token])
|
||||
|
||||
def get_user_by_id(self, user_id):
|
||||
query = ("SELECT users.name, users.password_hash FROM users "
|
||||
"WHERE users.name = ?")
|
||||
return self._execute(
|
||||
self.cursor_to_dict,
|
||||
query, user_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
token (str): The access token of a user.
|
||||
Returns:
|
||||
str: The user ID of the user.
|
||||
Raises:
|
||||
StoreError if no user was found.
|
||||
"""
|
||||
user_id = yield self._db_pool.runInteraction(self._query_for_auth,
|
||||
token)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
txn.execute("SELECT users.name FROM access_tokens LEFT JOIN users" +
|
||||
" ON users.id = access_tokens.user_id WHERE token = ?",
|
||||
[token])
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
|
||||
raise StoreError(404, "Token not found.")
|
129
synapse/storage/room.py
Normal file
129
synapse/storage/room.py
Normal file
|
@ -0,0 +1,129 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.events.room import RoomTopicEvent
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RoomStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
||||
"""Stores a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The desired room ID, can be None.
|
||||
room_creator_user_id (str): The user ID of the room creator.
|
||||
is_public (bool): True to indicate that this room should appear in
|
||||
public room lists.
|
||||
Raises:
|
||||
StoreError if the room could not be stored.
|
||||
"""
|
||||
try:
|
||||
yield self._simple_insert(RoomsTable.table_name, dict(
|
||||
room_id=room_id,
|
||||
creator=room_creator_user_id,
|
||||
is_public=is_public
|
||||
))
|
||||
except IntegrityError:
|
||||
raise StoreError(409, "Room ID in use.")
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
|
||||
def store_room_config(self, room_id, visibility):
|
||||
return self._simple_update_one(
|
||||
table=RoomsTable.table_name,
|
||||
keyvalues={"room_id": room_id},
|
||||
updatevalues={"is_public": visibility}
|
||||
)
|
||||
|
||||
def get_room(self, room_id):
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room to retrieve.
|
||||
Returns:
|
||||
A namedtuple containing the room information, or an empty list.
|
||||
"""
|
||||
query = RoomsTable.select_statement("room_id=?")
|
||||
return self._execute(
|
||||
RoomsTable.decode_single_result, query, room_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms(self, is_public, with_topics):
|
||||
"""Retrieve a list of all public rooms.
|
||||
|
||||
Args:
|
||||
is_public (bool): True if the rooms returned should be public.
|
||||
with_topics (bool): True to include the current topic for the room
|
||||
in the response.
|
||||
Returns:
|
||||
A list of room dicts containing at least a "room_id" key, and a
|
||||
"topic" key if one is set and with_topic=True.
|
||||
"""
|
||||
room_data_type = RoomTopicEvent.TYPE
|
||||
public = 1 if is_public else 0
|
||||
|
||||
latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE "
|
||||
+ "room_data.type = ? GROUP BY room_id")
|
||||
|
||||
query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN "
|
||||
+ "room_data ON rooms.room_id = room_data.room_id WHERE "
|
||||
+ "(room_data.id IN (" + latest_topic + ") "
|
||||
+ "OR room_data.id IS NULL) AND rooms.is_public = ?")
|
||||
|
||||
res = yield self._execute(
|
||||
self.cursor_to_dict, query, room_data_type, public
|
||||
)
|
||||
|
||||
# return only the keys the specification expects
|
||||
ret_keys = ["room_id", "topic"]
|
||||
|
||||
# extract topic from the json (icky) FIXME
|
||||
for i, room_row in enumerate(res):
|
||||
try:
|
||||
content_json = json.loads(room_row["content"])
|
||||
room_row["topic"] = content_json["topic"]
|
||||
except:
|
||||
pass # no topic set
|
||||
# filter the dict based on ret_keys
|
||||
res[i] = {k: v for k, v in room_row.iteritems() if k in ret_keys}
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
|
||||
class RoomsTable(Table):
|
||||
table_name = "rooms"
|
||||
|
||||
fields = [
|
||||
"room_id",
|
||||
"is_public",
|
||||
"creator"
|
||||
]
|
||||
|
||||
EntryType = collections.namedtuple("RoomEntry", fields)
|
84
synapse/storage/roomdata.py
Normal file
84
synapse/storage/roomdata.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore, Table
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
||||
|
||||
class RoomDataStore(SQLBaseStore):
|
||||
|
||||
"""Provides various CRUD operations for Room Events. """
|
||||
|
||||
def get_room_data(self, room_id, etype, state_key=""):
|
||||
"""Retrieve the data stored under this type and state_key.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
etype (str)
|
||||
state_key (str)
|
||||
Returns:
|
||||
namedtuple: Or None if nothing exists at this path.
|
||||
"""
|
||||
query = RoomDataTable.select_statement(
|
||||
"room_id = ? AND type = ? AND state_key = ? "
|
||||
"ORDER BY id DESC LIMIT 1"
|
||||
)
|
||||
return self._execute(
|
||||
RoomDataTable.decode_single_result,
|
||||
query, room_id, etype, state_key,
|
||||
)
|
||||
|
||||
def store_room_data(self, room_id, etype, state_key="", content=None):
|
||||
"""Stores room specific data.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
etype (str)
|
||||
state_key (str)
|
||||
data (str)- The data to store for this path in JSON.
|
||||
Returns:
|
||||
The store ID for this data.
|
||||
"""
|
||||
return self._simple_insert(RoomDataTable.table_name, dict(
|
||||
etype=etype,
|
||||
state_key=state_key,
|
||||
room_id=room_id,
|
||||
content=content,
|
||||
))
|
||||
|
||||
def get_max_room_data_id(self):
|
||||
return self._simple_max_id(RoomDataTable.table_name)
|
||||
|
||||
|
||||
class RoomDataTable(Table):
|
||||
table_name = "room_data"
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"room_id",
|
||||
"type",
|
||||
"state_key",
|
||||
"content"
|
||||
]
|
||||
|
||||
class EntryType(collections.namedtuple("RoomDataEntry", fields)):
|
||||
|
||||
def as_event(self, event_factory):
|
||||
return event_factory.create_event(
|
||||
etype=self.type,
|
||||
room_id=self.room_id,
|
||||
content=json.loads(self.content),
|
||||
)
|
171
synapse/storage/roommember.py
Normal file
171
synapse/storage/roommember.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.events.room import RoomMemberEvent
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
def get_room_member(self, user_id, room_id):
|
||||
"""Retrieve the current state of a room member.
|
||||
|
||||
Args:
|
||||
user_id (str): The member's user ID.
|
||||
room_id (str): The room the member is in.
|
||||
Returns:
|
||||
namedtuple: The room member from the database, or None if this
|
||||
member does not exist.
|
||||
"""
|
||||
query = RoomMemberTable.select_statement(
|
||||
"room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
|
||||
return self._execute(
|
||||
RoomMemberTable.decode_single_result,
|
||||
query, room_id, user_id,
|
||||
)
|
||||
|
||||
def store_room_member(self, user_id, sender, room_id, membership, content):
|
||||
"""Store a room member in the database.
|
||||
|
||||
Args:
|
||||
user_id (str): The member's user ID.
|
||||
room_id (str): The room in relation to the member.
|
||||
membership (synapse.api.constants.Membership): The new membership
|
||||
state.
|
||||
content (dict): The content of the membership (JSON).
|
||||
"""
|
||||
content_json = json.dumps(content)
|
||||
return self._simple_insert(RoomMemberTable.table_name, dict(
|
||||
user_id=user_id,
|
||||
sender=sender,
|
||||
room_id=room_id,
|
||||
membership=membership,
|
||||
content=content_json,
|
||||
))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members(self, room_id, membership=None):
|
||||
"""Retrieve the current room member list for a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The room to get the list of members.
|
||||
membership (synapse.api.constants.Membership): The filter to apply
|
||||
to this list, or None to return all members with some state
|
||||
associated with this room.
|
||||
Returns:
|
||||
list of namedtuples representing the members in this room.
|
||||
"""
|
||||
query = RoomMemberTable.select_statement(
|
||||
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
|
||||
+ " WHERE room_id = ? GROUP BY user_id)"
|
||||
)
|
||||
res = yield self._execute(
|
||||
RoomMemberTable.decode_results, query, room_id,
|
||||
)
|
||||
# strip memberships which don't match
|
||||
if membership:
|
||||
res = [entry for entry in res if entry.membership == membership]
|
||||
defer.returnValue(res)
|
||||
|
||||
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
|
||||
""" Get all the rooms for this user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID.
|
||||
membership_list (list): A list of synapse.api.constants.Membership
|
||||
values which the user must be in.
|
||||
Returns:
|
||||
A list of dicts with "room_id" and "membership" keys.
|
||||
"""
|
||||
if not membership_list:
|
||||
return defer.succeed(None)
|
||||
|
||||
args = [user_id]
|
||||
membership_placeholder = ["membership=?"] * len(membership_list)
|
||||
where_membership = "(" + " OR ".join(membership_placeholder) + ")"
|
||||
for membership in membership_list:
|
||||
args.append(membership)
|
||||
|
||||
query = ("SELECT room_id, membership FROM room_memberships"
|
||||
+ " WHERE user_id=? AND " + where_membership
|
||||
+ " GROUP BY room_id ORDER BY id DESC")
|
||||
return self._execute(
|
||||
self.cursor_to_dict, query, *args
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_hosts_for_room(self, room_id):
|
||||
query = RoomMemberTable.select_statement(
|
||||
"id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
|
||||
+ " WHERE room_id = ? GROUP BY user_id)"
|
||||
)
|
||||
|
||||
res = yield self._execute(
|
||||
RoomMemberTable.decode_results, query, room_id,
|
||||
)
|
||||
|
||||
def host_from_user_id_string(user_id):
|
||||
domain = UserID.from_string(entry.user_id, self.hs).domain
|
||||
return domain
|
||||
|
||||
# strip memberships which don't match
|
||||
hosts = [
|
||||
host_from_user_id_string(entry.user_id)
|
||||
for entry in res
|
||||
if entry.membership == Membership.JOIN
|
||||
]
|
||||
|
||||
logger.debug("Returning hosts: %s from results: %s", hosts, res)
|
||||
|
||||
defer.returnValue(hosts)
|
||||
|
||||
def get_max_room_member_id(self):
|
||||
return self._simple_max_id(RoomMemberTable.table_name)
|
||||
|
||||
|
||||
class RoomMemberTable(Table):
|
||||
table_name = "room_memberships"
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"user_id",
|
||||
"sender",
|
||||
"room_id",
|
||||
"membership",
|
||||
"content"
|
||||
]
|
||||
|
||||
class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
|
||||
|
||||
def as_event(self, event_factory):
|
||||
return event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
room_id=self.room_id,
|
||||
target_user_id=self.user_id,
|
||||
user_id=self.sender,
|
||||
content=json.loads(self.content),
|
||||
)
|
31
synapse/storage/schema/edge_pdus.sql
Normal file
31
synapse/storage/schema/edge_pdus.sql
Normal file
|
@ -0,0 +1,31 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS context_edge_pdus(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS origin_edge_pdus(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
|
||||
CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
|
54
synapse/storage/schema/im.sql
Normal file
54
synapse/storage/schema/im.sql
Normal file
|
@ -0,0 +1,54 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS rooms(
|
||||
room_id TEXT PRIMARY KEY NOT NULL,
|
||||
is_public INTEGER,
|
||||
creator TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS room_memberships(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server
|
||||
sender TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
membership TEXT NOT NULL,
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT,
|
||||
room_id TEXT,
|
||||
msg_id TEXT,
|
||||
content TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS feedback(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
content TEXT,
|
||||
feedback_type TEXT,
|
||||
fb_sender_id TEXT,
|
||||
msg_id TEXT,
|
||||
room_id TEXT,
|
||||
msg_sender_id TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS room_data(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
state_key TEXT NOT NULL,
|
||||
content TEXT
|
||||
);
|
106
synapse/storage/schema/pdu.sql
Normal file
106
synapse/storage/schema/pdu.sql
Normal file
|
@ -0,0 +1,106 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
-- Stores pdus and their content
|
||||
CREATE TABLE IF NOT EXISTS pdus(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
pdu_type TEXT,
|
||||
ts INTEGER,
|
||||
depth INTEGER DEFAULT 0 NOT NULL,
|
||||
is_state BOOL,
|
||||
content_json TEXT,
|
||||
unrecognized_keys TEXT,
|
||||
outlier BOOL NOT NULL,
|
||||
have_processed BOOL,
|
||||
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
|
||||
);
|
||||
|
||||
-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
|
||||
CREATE TABLE IF NOT EXISTS state_pdus(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
pdu_type TEXT,
|
||||
state_key TEXT,
|
||||
power_level TEXT,
|
||||
prev_state_id TEXT,
|
||||
prev_state_origin TEXT,
|
||||
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
|
||||
CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS current_state(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
pdu_type TEXT,
|
||||
state_key TEXT,
|
||||
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
|
||||
CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
-- Stores where each pdu we want to send should be sent and the delivery status.
|
||||
create TABLE IF NOT EXISTS pdu_destinations(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
destination TEXT,
|
||||
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
context TEXT,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_edges(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
prev_pdu_id TEXT,
|
||||
prev_origin TEXT,
|
||||
context TEXT,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS context_depth(
|
||||
context TEXT,
|
||||
min_depth INTEGER,
|
||||
CONSTRAINT uniqueness UNIQUE (context)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
|
||||
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
|
||||
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
|
||||
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
|
37
synapse/storage/schema/presence.sql
Normal file
37
synapse/storage/schema/presence.sql
Normal file
|
@ -0,0 +1,37 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS presence(
|
||||
user_id INTEGER NOT NULL,
|
||||
state INTEGER,
|
||||
status_msg TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
-- For each of /my/ users which possibly-remote users are allowed to see their
|
||||
-- presence state
|
||||
CREATE TABLE IF NOT EXISTS presence_allow_inbound(
|
||||
observed_user_id INTEGER NOT NULL,
|
||||
observer_user_id TEXT, -- a UserID,
|
||||
FOREIGN KEY(observed_user_id) REFERENCES users(id)
|
||||
);
|
||||
|
||||
-- For each of /my/ users (watcher), which possibly-remote users are they
|
||||
-- watching?
|
||||
CREATE TABLE IF NOT EXISTS presence_list(
|
||||
user_id INTEGER NOT NULL,
|
||||
observed_user_id TEXT, -- a UserID,
|
||||
accepted BOOLEAN,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
20
synapse/storage/schema/profiles.sql
Normal file
20
synapse/storage/schema/profiles.sql
Normal file
|
@ -0,0 +1,20 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS profiles(
|
||||
user_id INTEGER NOT NULL,
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
);
|
12
synapse/storage/schema/room_aliases.sql
Normal file
12
synapse/storage/schema/room_aliases.sql
Normal file
|
@ -0,0 +1,12 @@
|
|||
CREATE TABLE IF NOT EXISTS room_aliases(
|
||||
room_alias TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS room_alias_servers(
|
||||
room_alias TEXT NOT NULL,
|
||||
server TEXT NOT NULL
|
||||
);
|
||||
|
||||
|
||||
|
61
synapse/storage/schema/transactions.sql
Normal file
61
synapse/storage/schema/transactions.sql
Normal file
|
@ -0,0 +1,61 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
-- Stores what transaction ids we have received and what our response was
|
||||
CREATE TABLE IF NOT EXISTS received_transactions(
|
||||
transaction_id TEXT,
|
||||
origin TEXT,
|
||||
ts INTEGER,
|
||||
response_code INTEGER,
|
||||
response_json TEXT,
|
||||
has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
|
||||
CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
|
||||
CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
|
||||
|
||||
|
||||
-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
|
||||
-- since referenced the transaction in another outgoing transaction
|
||||
CREATE TABLE IF NOT EXISTS sent_transactions(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering
|
||||
transaction_id TEXT,
|
||||
destination TEXT,
|
||||
response_code INTEGER DEFAULT 0,
|
||||
response_json TEXT,
|
||||
ts INTEGER
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
|
||||
CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
|
||||
destination
|
||||
);
|
||||
-- So that we can do an efficient look up of all transactions that have yet to be successfully
|
||||
-- sent.
|
||||
CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
|
||||
|
||||
|
||||
-- For sent transactions only.
|
||||
CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
|
||||
transaction_id INTEGER,
|
||||
destination TEXT,
|
||||
pdu_id TEXT,
|
||||
pdu_origin TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
|
||||
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
|
||||
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
|
||||
|
31
synapse/storage/schema/users.sql
Normal file
31
synapse/storage/schema/users.sql
Normal file
|
@ -0,0 +1,31 @@
|
|||
/* Copyright 2014 matrix.org
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
password_hash TEXT,
|
||||
creation_ts INTEGER,
|
||||
UNIQUE(name) ON CONFLICT ROLLBACK
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS access_tokens(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
device_id TEXT,
|
||||
token TEXT NOT NULL,
|
||||
last_used INTEGER,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id),
|
||||
UNIQUE(token) ON CONFLICT ROLLBACK
|
||||
);
|
282
synapse/storage/stream.py
Normal file
282
synapse/storage/stream.py
Normal file
|
@ -0,0 +1,282 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from .message import MessagesTable
|
||||
from .feedback import FeedbackTable
|
||||
from .roomdata import RoomDataTable
|
||||
from .roommember import RoomMemberTable
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamStore(SQLBaseStore):
|
||||
|
||||
def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0,
|
||||
with_feedback=False):
|
||||
"""Get all messages for this user between the given keys.
|
||||
|
||||
Args:
|
||||
user_id (str): The user who is requesting messages.
|
||||
from_key (int): The ID to start returning results from (exclusive).
|
||||
to_key (int): The ID to stop returning results (exclusive).
|
||||
room_id (str): Gets messages only for this room. Can be None, in
|
||||
which case all room messages will be returned.
|
||||
Returns:
|
||||
A tuple of rows (list of namedtuples), new_id(int)
|
||||
"""
|
||||
if with_feedback and room_id: # with fb MUST specify a room ID
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_message_rows_with_feedback,
|
||||
user_id, from_key, to_key, room_id, limit
|
||||
)
|
||||
else:
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_message_rows,
|
||||
user_id, from_key, to_key, room_id, limit
|
||||
)
|
||||
|
||||
def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
|
||||
limit):
|
||||
# work out which rooms this user is joined in on and join them with
|
||||
# the room id on the messages table, bounded by the specified pkeys
|
||||
|
||||
# get all messages where the *current* membership state is 'join' for
|
||||
# this user in that room.
|
||||
query = ("SELECT messages.* FROM messages WHERE ? IN"
|
||||
+ " (SELECT membership from room_memberships WHERE user_id=?"
|
||||
+ " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)")
|
||||
query_args = ["join", user_id]
|
||||
|
||||
if room_id:
|
||||
query += " AND messages.room_id=?"
|
||||
query_args.append(room_id)
|
||||
|
||||
(query, query_args) = self._append_stream_operations(
|
||||
"messages", query, query_args, from_pkey, to_pkey, limit=limit
|
||||
)
|
||||
|
||||
logger.debug("[SQL] %s : %s", query, query_args)
|
||||
cursor = txn.execute(query, query_args)
|
||||
return self._as_events(cursor, MessagesTable, from_pkey)
|
||||
|
||||
def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey,
|
||||
room_id, limit):
|
||||
# this col represents the compressed feedback JSON as per spec
|
||||
compressed_feedback_col = (
|
||||
"'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id"
|
||||
+ " || '\",\"feedback_type\":\"' || f.feedback_type"
|
||||
+ " || '\",\"content\":' || f.content || '}') || ']'"
|
||||
)
|
||||
|
||||
global_msg_id_join = ("f.room_id = messages.room_id"
|
||||
+ " and f.msg_id = messages.msg_id"
|
||||
+ " and messages.user_id = f.msg_sender_id")
|
||||
|
||||
select_query = (
|
||||
"SELECT messages.*, f.content AS fb_content, f.fb_sender_id"
|
||||
+ ", " + compressed_feedback_col + " AS compressed_fb"
|
||||
+ " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join)
|
||||
|
||||
current_membership_sub_query = (
|
||||
"(SELECT membership from room_memberships rm"
|
||||
+ " WHERE user_id=? AND room_id = rm.room_id"
|
||||
+ " ORDER BY id DESC LIMIT 1)")
|
||||
|
||||
where = (" WHERE ? IN " + current_membership_sub_query
|
||||
+ " AND messages.room_id=?")
|
||||
|
||||
query = select_query + where
|
||||
query_args = ["join", user_id, room_id]
|
||||
|
||||
(query, query_args) = self._append_stream_operations(
|
||||
"messages", query, query_args, from_pkey, to_pkey,
|
||||
limit=limit, group_by=" GROUP BY messages.id "
|
||||
)
|
||||
|
||||
logger.debug("[SQL] %s : %s", query, query_args)
|
||||
cursor = txn.execute(query, query_args)
|
||||
|
||||
# convert the result set into events
|
||||
entries = self.cursor_to_dict(cursor)
|
||||
events = []
|
||||
for entry in entries:
|
||||
# TODO we should spec the cursor > event mapping somewhere else.
|
||||
event = {}
|
||||
straight_mappings = ["msg_id", "user_id", "room_id"]
|
||||
for key in straight_mappings:
|
||||
event[key] = entry[key]
|
||||
event["content"] = json.loads(entry["content"])
|
||||
if entry["compressed_fb"]:
|
||||
event["feedback"] = json.loads(entry["compressed_fb"])
|
||||
events.append(event)
|
||||
|
||||
latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"]
|
||||
|
||||
return (events, latest_pkey)
|
||||
|
||||
def get_room_member_stream(self, user_id, from_key, to_key):
|
||||
"""Get all room membership events for this user between the given keys.
|
||||
|
||||
Args:
|
||||
user_id (str): The user who is requesting membership events.
|
||||
from_key (int): The ID to start returning results from (exclusive).
|
||||
to_key (int): The ID to stop returning results (exclusive).
|
||||
Returns:
|
||||
A tuple of rows (list of namedtuples), new_id(int)
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_room_member_rows, user_id, from_key, to_key
|
||||
)
|
||||
|
||||
def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey):
|
||||
# get all room membership events for rooms which the user is
|
||||
# *currently* joined in on, or all invite events for this user.
|
||||
current_membership_sub_query = (
|
||||
"(SELECT membership FROM room_memberships"
|
||||
+ " WHERE user_id=? AND room_id = rm.room_id"
|
||||
+ " ORDER BY id DESC LIMIT 1)")
|
||||
|
||||
query = ("SELECT rm.* FROM room_memberships rm "
|
||||
# all membership events for rooms you've currently joined.
|
||||
+ " WHERE (? IN " + current_membership_sub_query
|
||||
# all invite membership events for this user
|
||||
+ " OR rm.membership=? AND user_id=?)"
|
||||
+ " AND rm.id > ?")
|
||||
query_args = ["join", user_id, "invite", user_id, from_pkey]
|
||||
|
||||
if to_pkey != -1:
|
||||
query += " AND rm.id < ?"
|
||||
query_args.append(to_pkey)
|
||||
|
||||
cursor = txn.execute(query, query_args)
|
||||
return self._as_events(cursor, RoomMemberTable, from_pkey)
|
||||
|
||||
def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0):
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_feedback_rows,
|
||||
user_id, from_key, to_key, room_id, limit
|
||||
)
|
||||
|
||||
def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
|
||||
limit):
|
||||
# work out which rooms this user is joined in on and join them with
|
||||
# the room id on the feedback table, bounded by the specified pkeys
|
||||
|
||||
# get all messages where the *current* membership state is 'join' for
|
||||
# this user in that room.
|
||||
query = (
|
||||
"SELECT feedback.* FROM feedback WHERE ? IN "
|
||||
+ "(SELECT membership from room_memberships WHERE user_id=?"
|
||||
+ " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)")
|
||||
query_args = ["join", user_id]
|
||||
|
||||
if room_id:
|
||||
query += " AND feedback.room_id=?"
|
||||
query_args.append(room_id)
|
||||
|
||||
(query, query_args) = self._append_stream_operations(
|
||||
"feedback", query, query_args, from_pkey, to_pkey, limit=limit
|
||||
)
|
||||
|
||||
logger.debug("[SQL] %s : %s", query, query_args)
|
||||
cursor = txn.execute(query, query_args)
|
||||
return self._as_events(cursor, FeedbackTable, from_pkey)
|
||||
|
||||
def get_room_data_stream(self, user_id, from_key, to_key, room_id,
|
||||
limit=0):
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_room_data_rows,
|
||||
user_id, from_key, to_key, room_id, limit
|
||||
)
|
||||
|
||||
def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
|
||||
limit):
|
||||
# work out which rooms this user is joined in on and join them with
|
||||
# the room id on the feedback table, bounded by the specified pkeys
|
||||
|
||||
# get all messages where the *current* membership state is 'join' for
|
||||
# this user in that room.
|
||||
query = (
|
||||
"SELECT room_data.* FROM room_data WHERE ? IN "
|
||||
+ "(SELECT membership from room_memberships WHERE user_id=?"
|
||||
+ " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)")
|
||||
query_args = ["join", user_id]
|
||||
|
||||
if room_id:
|
||||
query += " AND room_data.room_id=?"
|
||||
query_args.append(room_id)
|
||||
|
||||
(query, query_args) = self._append_stream_operations(
|
||||
"room_data", query, query_args, from_pkey, to_pkey, limit=limit
|
||||
)
|
||||
|
||||
logger.debug("[SQL] %s : %s", query, query_args)
|
||||
cursor = txn.execute(query, query_args)
|
||||
return self._as_events(cursor, RoomDataTable, from_pkey)
|
||||
|
||||
def _append_stream_operations(self, table_name, query, query_args,
|
||||
from_pkey, to_pkey, limit=None,
|
||||
group_by=""):
|
||||
LATEST_ROW = -1
|
||||
order_by = ""
|
||||
if to_pkey > from_pkey:
|
||||
if from_pkey != LATEST_ROW:
|
||||
# e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9
|
||||
query += (" AND %s.id > ? AND %s.id < ?" %
|
||||
(table_name, table_name))
|
||||
query_args.append(from_pkey)
|
||||
query_args.append(to_pkey)
|
||||
else:
|
||||
# e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC
|
||||
query += " AND %s.id > ? " % table_name
|
||||
order_by = "ORDER BY id DESC"
|
||||
query_args.append(to_pkey)
|
||||
elif from_pkey > to_pkey:
|
||||
if to_pkey != LATEST_ROW:
|
||||
# from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC
|
||||
query += (" AND %s.id > ? AND %s.id < ? " %
|
||||
(table_name, table_name))
|
||||
order_by = "ORDER BY id DESC"
|
||||
query_args.append(to_pkey)
|
||||
query_args.append(from_pkey)
|
||||
else:
|
||||
# from=5 to=-1 >> from 5 to now >> id>5
|
||||
query += " AND %s.id > ?" % table_name
|
||||
query_args.append(from_pkey)
|
||||
|
||||
query += group_by + order_by
|
||||
|
||||
if limit and limit > 0:
|
||||
query += " LIMIT ?"
|
||||
query_args.append(str(limit))
|
||||
|
||||
return (query, query_args)
|
||||
|
||||
def _as_events(self, cursor, table, from_pkey):
|
||||
data_entries = table.decode_results(cursor)
|
||||
last_pkey = from_pkey
|
||||
if data_entries:
|
||||
last_pkey = data_entries[-1].id
|
||||
|
||||
events = [
|
||||
entry.as_event(self.event_factory).get_dict()
|
||||
for entry in data_entries
|
||||
]
|
||||
|
||||
return (events, last_pkey)
|
287
synapse/storage/transactions.py
Normal file
287
synapse/storage/transactions.py
Normal file
|
@ -0,0 +1,287 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ._base import SQLBaseStore, Table
|
||||
from .pdu import PdusTable
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransactionStore(SQLBaseStore):
|
||||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
"""For an incoming transaction from a given origin, check if we have
|
||||
already responded to it. If so, return the response code and response
|
||||
body (as a dict).
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
origin(str)
|
||||
|
||||
Returns:
|
||||
tuple: None if we have not previously responded to
|
||||
this transaction or a 2-tuple of (int, dict)
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_received_txn_response, transaction_id, origin
|
||||
)
|
||||
|
||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||
where_clause = "transaction_id = ? AND origin = ?"
|
||||
query = ReceivedTransactionsTable.select_statement(where_clause)
|
||||
|
||||
txn.execute(query, (transaction_id, origin))
|
||||
|
||||
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
if results and results[0].response_code:
|
||||
return (results[0].response_code, results[0].response_json)
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_received_txn_response(self, transaction_id, origin, code,
|
||||
response_dict):
|
||||
"""Persist the response we returened for an incoming transaction, and
|
||||
should return for subsequent transactions with the same transaction_id
|
||||
and origin.
|
||||
|
||||
Args:
|
||||
txn
|
||||
transaction_id (str)
|
||||
origin (str)
|
||||
code (int)
|
||||
response_json (str)
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._set_received_txn_response,
|
||||
transaction_id, origin, code, response_dict
|
||||
)
|
||||
|
||||
def _set_received_txn_response(self, txn, transaction_id, origin, code,
|
||||
response_json):
|
||||
query = (
|
||||
"UPDATE %s "
|
||||
"SET response_code = ?, response_json = ? "
|
||||
"WHERE transaction_id = ? AND origin = ?"
|
||||
) % ReceivedTransactionsTable.table_name
|
||||
|
||||
txn.execute(query, (code, response_json, transaction_id, origin))
|
||||
|
||||
def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
|
||||
"""Persists an outgoing transaction and calculates the values for the
|
||||
previous transaction id list.
|
||||
|
||||
This should be called before sending the transaction so that it has the
|
||||
correct value for the `prev_ids` key.
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
destination (str)
|
||||
ts (int)
|
||||
pdu_list (list)
|
||||
|
||||
Returns:
|
||||
list: A list of previous transaction ids.
|
||||
"""
|
||||
|
||||
return self._db_pool.runInteraction(
|
||||
self._prep_send_transaction,
|
||||
transaction_id, destination, ts, pdu_list
|
||||
)
|
||||
|
||||
def _prep_send_transaction(self, txn, transaction_id, destination, ts,
|
||||
pdu_list):
|
||||
|
||||
# First we find out what the prev_txs should be.
|
||||
# Since we know that we are only sending one transaction at a time,
|
||||
# we can simply take the last one.
|
||||
query = "%s ORDER BY id DESC LIMIT 1" % (
|
||||
SentTransactions.select_statement("destination = ?"),
|
||||
)
|
||||
|
||||
results = txn.execute(query, (destination,))
|
||||
results = SentTransactions.decode_results(results)
|
||||
|
||||
prev_txns = [r.transaction_id for r in results]
|
||||
|
||||
# Actually add the new transaction to the sent_transactions table.
|
||||
|
||||
query = SentTransactions.insert_statement()
|
||||
txn.execute(query, SentTransactions.EntryType(
|
||||
None,
|
||||
transaction_id=transaction_id,
|
||||
destination=destination,
|
||||
ts=ts,
|
||||
response_code=0,
|
||||
response_json=None
|
||||
))
|
||||
|
||||
# Update the tx id -> pdu id mapping
|
||||
|
||||
values = [
|
||||
(transaction_id, destination, pdu[0], pdu[1])
|
||||
for pdu in pdu_list
|
||||
]
|
||||
|
||||
logger.debug("Inserting: %s", repr(values))
|
||||
|
||||
query = TransactionsToPduTable.insert_statement()
|
||||
txn.executemany(query, values)
|
||||
|
||||
return prev_txns
|
||||
|
||||
def delivered_txn(self, transaction_id, destination, code, response_dict):
|
||||
"""Persists the response for an outgoing transaction.
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
destination (str)
|
||||
code (int)
|
||||
response_json (str)
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._delivered_txn,
|
||||
transaction_id, destination, code, response_dict
|
||||
)
|
||||
|
||||
def _delivered_txn(cls, txn, transaction_id, destination,
|
||||
code, response_json):
|
||||
query = (
|
||||
"UPDATE %s "
|
||||
"SET response_code = ?, response_json = ? "
|
||||
"WHERE transaction_id = ? AND destination = ?"
|
||||
) % SentTransactions.table_name
|
||||
|
||||
txn.execute(query, (code, response_json, transaction_id, destination))
|
||||
|
||||
def get_transactions_after(self, transaction_id, destination):
|
||||
"""Get all transactions after a given local transaction_id.
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
destination (str)
|
||||
|
||||
Returns:
|
||||
list: A list of `ReceivedTransactionsTable.EntryType`
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_transactions_after, transaction_id, destination
|
||||
)
|
||||
|
||||
def _get_transactions_after(cls, txn, transaction_id, destination):
|
||||
where = (
|
||||
"destination = ? AND id > (select id FROM %s WHERE "
|
||||
"transaction_id = ? AND destination = ?)"
|
||||
) % (
|
||||
SentTransactions.table_name
|
||||
)
|
||||
query = SentTransactions.select_statement(where)
|
||||
|
||||
txn.execute(query, (destination, transaction_id, destination))
|
||||
|
||||
return ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
def get_pdus_after_transaction(self, transaction_id, destination):
|
||||
"""For a given local transaction_id that we sent to a given destination
|
||||
home server, return a list of PDUs that were sent to that destination
|
||||
after it.
|
||||
|
||||
Args:
|
||||
txn
|
||||
transaction_id (str)
|
||||
destination (str)
|
||||
|
||||
Returns
|
||||
list: A list of PduTuple
|
||||
"""
|
||||
return self._db_pool.runInteraction(
|
||||
self._get_pdus_after_transaction,
|
||||
transaction_id, destination
|
||||
)
|
||||
|
||||
def _get_pdus_after_transaction(self, txn, transaction_id, destination):
|
||||
|
||||
# Query that first get's all transaction_ids with an id greater than
|
||||
# the one given from the `sent_transactions` table. Then JOIN on this
|
||||
# from the `tx->pdu` table to get a list of (pdu_id, origin) that
|
||||
# specify the pdus that were sent in those transactions.
|
||||
query = (
|
||||
"SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
|
||||
"INNER JOIN %(sent_tx)s as st "
|
||||
"ON tp.transaction_id = st.transaction_id "
|
||||
"AND tp.destination = st.destination "
|
||||
"WHERE st.id > ("
|
||||
"SELECT id FROM %(sent_tx)s "
|
||||
"WHERE transaction_id = ? AND destination = ?"
|
||||
) % {
|
||||
"tx_pdu": TransactionsToPduTable.table_name,
|
||||
"sent_tx": SentTransactions.table_name,
|
||||
}
|
||||
|
||||
txn.execute(query, (transaction_id, destination))
|
||||
|
||||
pdus = PdusTable.decode_results(txn.fetchall())
|
||||
|
||||
return self._get_pdu_tuples(txn, pdus)
|
||||
|
||||
|
||||
class ReceivedTransactionsTable(Table):
|
||||
table_name = "received_transactions"
|
||||
|
||||
fields = [
|
||||
"transaction_id",
|
||||
"origin",
|
||||
"ts",
|
||||
"response_code",
|
||||
"response_json",
|
||||
"has_been_referenced",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
|
||||
|
||||
|
||||
class SentTransactions(Table):
|
||||
table_name = "sent_transactions"
|
||||
|
||||
fields = [
|
||||
"id",
|
||||
"transaction_id",
|
||||
"destination",
|
||||
"ts",
|
||||
"response_code",
|
||||
"response_json",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("SentTransactionsEntry", fields)
|
||||
|
||||
|
||||
class TransactionsToPduTable(Table):
|
||||
table_name = "transaction_id_to_pdu"
|
||||
|
||||
fields = [
|
||||
"transaction_id",
|
||||
"destination",
|
||||
"pdu_id",
|
||||
"pdu_origin",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("TransactionsToPduEntry", fields)
|
79
synapse/types.py
Normal file
79
synapse/types.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class DomainSpecificString(
|
||||
namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine"))
|
||||
):
|
||||
"""Common base class among ID/name strings that have a local part and a
|
||||
domain name, prefixed with a sigil.
|
||||
|
||||
Has the fields:
|
||||
|
||||
'localpart' : The local part of the name (without the leading sigil)
|
||||
'domain' : The domain part of the name
|
||||
'is_mine' : Boolean indicating if the domain name is recognised by the
|
||||
HomeServer as being its own
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, s, hs):
|
||||
"""Parse the string given by 's' into a structure object."""
|
||||
if s[0] != cls.SIGIL:
|
||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||
cls.__name__, cls.SIGIL,
|
||||
))
|
||||
|
||||
parts = s[1:].split(':', 1)
|
||||
if len(parts) != 2:
|
||||
raise SynapseError(
|
||||
400, "Expected %s of the form '%slocalname:domain'" % (
|
||||
cls.__name__, cls.SIGIL,
|
||||
)
|
||||
)
|
||||
|
||||
domain = parts[1]
|
||||
|
||||
# This code will need changing if we want to support multiple domain
|
||||
# names on one HS
|
||||
is_mine = domain == hs.hostname
|
||||
return cls(localpart=parts[0], domain=domain, is_mine=is_mine)
|
||||
|
||||
def to_string(self):
|
||||
"""Return a string encoding the fields of the structure object."""
|
||||
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
||||
|
||||
@classmethod
|
||||
def create_local(cls, localpart, hs):
|
||||
"""Create a structure on the local domain"""
|
||||
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
|
||||
|
||||
|
||||
class UserID(DomainSpecificString):
|
||||
"""Structure representing a user ID."""
|
||||
SIGIL = "@"
|
||||
|
||||
|
||||
class RoomAlias(DomainSpecificString):
|
||||
"""Structure representing a room name."""
|
||||
SIGIL = "#"
|
||||
|
||||
|
||||
class RoomID(DomainSpecificString):
|
||||
"""Structure representing a room id. """
|
||||
SIGIL = "!"
|
40
synapse/util/__init__.py
Normal file
40
synapse/util/__init__.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
import time
|
||||
|
||||
|
||||
class Clock(object):
|
||||
"""A small utility that obtains current time-of-day so that time may be
|
||||
mocked during unit-tests.
|
||||
|
||||
TODO(paul): Also move the sleep() functionallity into it
|
||||
"""
|
||||
|
||||
def time(self):
|
||||
"""Returns the current system time in seconds since epoch."""
|
||||
return time.time()
|
||||
|
||||
def time_msec(self):
|
||||
"""Returns the current system time in miliseconds since epoch."""
|
||||
return self.time() * 1000
|
||||
|
||||
def call_later(self, delay, callback):
|
||||
return reactor.callLater(delay, callback)
|
||||
|
||||
def cancel_call_later(self, timer):
|
||||
timer.cancel()
|
22
synapse/util/async.py
Normal file
22
synapse/util/async.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
|
||||
|
||||
def sleep(seconds):
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(seconds, d.callback, seconds)
|
||||
return d
|
108
synapse/util/distributor.py
Normal file
108
synapse/util/distributor.py
Normal file
|
@ -0,0 +1,108 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Distributor(object):
|
||||
"""A central dispatch point for loosely-connected pieces of code to
|
||||
register, observe, and fire signals.
|
||||
|
||||
Signals are named simply by strings.
|
||||
|
||||
TODO(paul): It would be nice to give signals stronger object identities,
|
||||
so we can attach metadata, docstrings, detect typoes, etc... But this
|
||||
model will do for today.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.signals = {}
|
||||
self.pre_registration = {}
|
||||
|
||||
def declare(self, name):
|
||||
if name in self.signals:
|
||||
raise KeyError("%r already has a signal named %s" % (self, name))
|
||||
|
||||
self.signals[name] = Signal(name)
|
||||
|
||||
if name in self.pre_registration:
|
||||
signal = self.signals[name]
|
||||
for observer in self.pre_registration[name]:
|
||||
signal.observe(observer)
|
||||
|
||||
def observe(self, name, observer):
|
||||
if name in self.signals:
|
||||
self.signals[name].observe(observer)
|
||||
else:
|
||||
# TODO: Avoid strong ordering dependency by allowing people to
|
||||
# pre-register observations on signals that don't exist yet.
|
||||
if name not in self.pre_registration:
|
||||
self.pre_registration[name] = []
|
||||
self.pre_registration[name].append(observer)
|
||||
|
||||
def fire(self, name, *args, **kwargs):
|
||||
if name not in self.signals:
|
||||
raise KeyError("%r does not have a signal named %s" % (self, name))
|
||||
|
||||
return self.signals[name].fire(*args, **kwargs)
|
||||
|
||||
|
||||
class Signal(object):
|
||||
"""A Signal is a dispatch point that stores a list of callables as
|
||||
observers of it.
|
||||
|
||||
Signals can be "fired", meaning that every callable observing it is
|
||||
invoked. Firing a signal does not change its state; it can be fired again
|
||||
at any later point. Firing a signal passes any arguments from the fire
|
||||
method into all of the observers.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.observers = []
|
||||
|
||||
def observe(self, observer):
|
||||
"""Adds a new callable to the observer list which will be invoked by
|
||||
the 'fire' method.
|
||||
|
||||
Each observer callable may return a Deferred."""
|
||||
self.observers.append(observer)
|
||||
|
||||
def fire(self, *args, **kwargs):
|
||||
"""Invokes every callable in the observer list, passing in the args and
|
||||
kwargs. Exceptions thrown by observers are logged but ignored. It is
|
||||
not an error to fire a signal with no observers.
|
||||
|
||||
Returns a Deferred that will complete when all the observers have
|
||||
completed."""
|
||||
deferreds = []
|
||||
for observer in self.observers:
|
||||
d = defer.maybeDeferred(observer, *args, **kwargs)
|
||||
|
||||
def eb(failure):
|
||||
logger.warning(
|
||||
"%s signal observer %s failed: %r",
|
||||
self.name, observer, failure,
|
||||
exc_info=(
|
||||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
deferreds.append(d.addErrback(eb))
|
||||
|
||||
return defer.DeferredList(deferreds)
|
98
synapse/util/jsonobject.py
Normal file
98
synapse/util/jsonobject.py
Normal file
|
@ -0,0 +1,98 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
|
||||
class JsonEncodedObject(object):
|
||||
""" A common base class for defining protocol units that are represented
|
||||
as JSON.
|
||||
|
||||
Attributes:
|
||||
unrecognized_keys (dict): A dict containing all the key/value pairs we
|
||||
don't recognize.
|
||||
"""
|
||||
|
||||
valid_keys = [] # keys we will store
|
||||
"""A list of strings that represent keys we know about
|
||||
and can handle. If we have values for these keys they will be
|
||||
included in the `dictionary` instance variable.
|
||||
"""
|
||||
|
||||
internal_keys = [] # keys to ignore while building dict
|
||||
"""A list of strings that should *not* be encoded into JSON.
|
||||
"""
|
||||
|
||||
required_keys = []
|
||||
"""A list of strings that we require to exist. If they are not given upon
|
||||
construction it raises an exception.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
""" Takes the dict of `kwargs` and loads all keys that are *valid*
|
||||
(i.e., are included in the `valid_keys` list) into the dictionary`
|
||||
instance variable.
|
||||
|
||||
Any keys that aren't recognized are added to the `unrecognized_keys`
|
||||
attribute.
|
||||
|
||||
Args:
|
||||
**kwargs: Attributes associated with this protocol unit.
|
||||
"""
|
||||
for required_key in self.required_keys:
|
||||
if required_key not in kwargs:
|
||||
raise RuntimeError("Key %s is required" % required_key)
|
||||
|
||||
self.unrecognized_keys = {} # Keys we were given not listed as valid
|
||||
for k, v in kwargs.items():
|
||||
if k in self.valid_keys or k in self.internal_keys:
|
||||
self.__dict__[k] = v
|
||||
else:
|
||||
self.unrecognized_keys[k] = v
|
||||
|
||||
def get_dict(self):
|
||||
""" Converts this protocol unit into a :py:class:`dict`, ready to be
|
||||
encoded as JSON.
|
||||
|
||||
The keys it encodes are: `valid_keys` - `internal_keys`
|
||||
|
||||
Returns
|
||||
dict
|
||||
"""
|
||||
d = {
|
||||
k: _encode(v) for (k, v) in self.__dict__.items()
|
||||
if k in self.valid_keys and k not in self.internal_keys
|
||||
}
|
||||
d.update(self.unrecognized_keys)
|
||||
return copy.deepcopy(d)
|
||||
|
||||
def get_full_dict(self):
|
||||
d = {
|
||||
k: v for (k, v) in self.__dict__.items()
|
||||
if k in self.valid_keys or k in self.internal_keys
|
||||
}
|
||||
d.update(self.unrecognized_keys)
|
||||
return copy.deepcopy(d)
|
||||
|
||||
def __str__(self):
|
||||
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
|
||||
|
||||
def _encode(obj):
|
||||
if type(obj) is list:
|
||||
return [_encode(o) for o in obj]
|
||||
|
||||
if isinstance(obj, JsonEncodedObject):
|
||||
return obj.get_dict()
|
||||
|
||||
return obj
|
67
synapse/util/lockutils.py
Normal file
67
synapse/util/lockutils.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Lock(object):
|
||||
|
||||
def __init__(self, deferred):
|
||||
self._deferred = deferred
|
||||
self.released = False
|
||||
|
||||
def release(self):
|
||||
self.released = True
|
||||
self._deferred.callback(None)
|
||||
|
||||
def __del__(self):
|
||||
if not self.released:
|
||||
logger.critical("Lock was destructed but never released!")
|
||||
self.release()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.release()
|
||||
|
||||
|
||||
class LockManager(object):
|
||||
""" Utility class that allows us to lock based on a `key` """
|
||||
|
||||
def __init__(self):
|
||||
self._lock_deferreds = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def lock(self, key):
|
||||
""" Allows us to block until it is our turn.
|
||||
Args:
|
||||
key (str)
|
||||
Returns:
|
||||
Lock
|
||||
"""
|
||||
new_deferred = defer.Deferred()
|
||||
old_deferred = self._lock_deferreds.get(key)
|
||||
self._lock_deferreds[key] = new_deferred
|
||||
|
||||
if old_deferred:
|
||||
yield old_deferred
|
||||
|
||||
defer.returnValue(Lock(new_deferred))
|
65
synapse/util/logutils.py
Normal file
65
synapse/util/logutils.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from inspect import getcallargs
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
def log_function(f):
|
||||
""" Function decorator that logs every call to that function.
|
||||
"""
|
||||
func_name = f.__name__
|
||||
lineno = f.func_code.co_firstlineno
|
||||
pathname = f.func_code.co_filename
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
name = f.__module__
|
||||
logger = logging.getLogger(name)
|
||||
level = logging.DEBUG
|
||||
|
||||
if logger.isEnabledFor(level):
|
||||
bound_args = getcallargs(f, *args, **kwargs)
|
||||
|
||||
def format(value):
|
||||
r = str(value)
|
||||
if len(r) > 50:
|
||||
r = r[:50] + "..."
|
||||
return r
|
||||
|
||||
func_args = [
|
||||
"%s=%s" % (k, format(v)) for k, v in bound_args.items()
|
||||
]
|
||||
|
||||
msg_args = {
|
||||
"func_name": func_name,
|
||||
"args": ", ".join(func_args)
|
||||
}
|
||||
|
||||
record = logging.LogRecord(
|
||||
name=name,
|
||||
level=level,
|
||||
pathname=pathname,
|
||||
lineno=lineno,
|
||||
msg="Invoked '%(func_name)s' with args: %(args)s",
|
||||
args=msg_args,
|
||||
exc_info=None
|
||||
)
|
||||
|
||||
logger.handle(record)
|
||||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapped
|
24
synapse/util/stringutils.py
Normal file
24
synapse/util/stringutils.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
def origin_from_ucid(ucid):
|
||||
return ucid.split("@", 1)[1]
|
||||
|
||||
|
||||
def random_string(length):
|
||||
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
Loading…
Add table
Add a link
Reference in a new issue