anonymousland-synapse/synapse/federation/sender/transaction_manager.py

173 lines
5.8 KiB
Python
Raw Normal View History

2019-03-13 16:02:56 -04:00
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
2020-08-12 09:03:08 -04:00
from typing import TYPE_CHECKING, List, Tuple
2019-03-13 16:02:56 -04:00
from canonicaljson import json
2019-03-13 16:02:56 -04:00
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
2019-03-13 16:02:56 -04:00
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.logging.opentracing import (
extract_text_map,
set_tag,
start_active_span_follows_from,
tags,
whitelisted_homeserver,
)
2019-03-13 16:02:56 -04:00
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
import synapse.server
2019-03-13 16:02:56 -04:00
logger = logging.getLogger(__name__)
class TransactionManager:
    """Helper class which handles building and sending transactions

    shared between PerDestinationQueue objects
    """

    def __init__(self, hs: "synapse.server.HomeServer"):
        """Capture the homeserver collaborators needed to send transactions.

        Args:
            hs: the homeserver; supplies the hostname, clock, datastore and
                federation transport client used by this class.
        """
        self._server_name = hs.hostname
        self.clock = hs.get_clock()  # nb must be called this for @measure_func
        self._store = hs.get_datastore()
        self._transaction_actions = TransactionActions(self._store)
        self._transport_layer = hs.get_federation_transport_client()

        # HACK to get unique tx id: seed the counter from the current time and
        # increment it once per transaction built.
        self._next_txn_id = int(self.clock.time_msec())

    @measure_func("_send_new_transaction")
    async def send_new_transaction(
        self,
        destination: str,
        pending_pdus: List[Tuple[EventBase, int]],
        pending_edus: List[Edu],
    ) -> bool:
        """Build a transaction from the given PDUs and EDUs and send it.

        Args:
            destination: the server name the transaction is addressed to.
            pending_pdus: (event, order) pairs. NB: this list is sorted in
                place by its order field.
            pending_edus: the EDUs to include in the transaction.

        Returns:
            True if the remote answered with a 200 (per-PDU errors reported in
            the response body are logged but do not affect the result); False
            if the request failed with an HTTP error code that is not
            re-raised.

        Raises:
            HttpResponseException: if the remote responded with 401, 404, 429
                or any code >= 500.
        """
        # Make a transaction-sending opentracing span. This span follows on from
        # all the edus in that transaction. This needs to be done since there is
        # no active span here, so if the edus were not received by the remote the
        # span would have no causality and it would be forgotten.
        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)

        for edu in pending_edus:
            context = edu.get_context()
            if context:
                span_contexts.append(extract_text_map(json.loads(context)))
            # NOTE(review): the context is stripped when
            # whitelisted_homeserver(destination) is truthy - confirm that this
            # polarity is the intended one.
            if keep_destination:
                edu.strip_context()

        with start_active_span_follows_from("send_transaction", span_contexts):
            # Sort based on the order field
            pending_pdus.sort(key=lambda t: t[1])
            pdus = [x[0] for x in pending_pdus]
            edus = pending_edus

            success = True

            logger.debug("TX [%s] _attempt_new_transaction", destination)

            txn_id = str(self._next_txn_id)

            logger.debug(
                "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
                destination,
                txn_id,
                len(pdus),
                len(edus),
            )

            transaction = Transaction.create_new(
                origin_server_ts=int(self.clock.time_msec()),
                transaction_id=txn_id,
                origin=self._server_name,
                destination=destination,
                pdus=pdus,
                edus=edus,
            )

            # The id has been consumed by the transaction above; move on so the
            # next transaction gets a fresh one.
            self._next_txn_id += 1

            logger.info(
                "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
                destination,
                txn_id,
                transaction.transaction_id,
                len(pdus),
                len(edus),
            )

            # Actually send the transaction

            # FIXME (erikj): This is a bit of a hack to make the Pdu age
            # keys work
            def json_data_cb():
                # Replace each PDU's "age_ts" with a relative "age" under
                # "unsigned", computed at the time this callback is invoked.
                data = transaction.get_dict()
                now = int(self.clock.time_msec())
                if "pdus" in data:
                    for p in data["pdus"]:
                        if "age_ts" in p:
                            unsigned = p.setdefault("unsigned", {})
                            unsigned["age"] = now - int(p["age_ts"])
                            del p["age_ts"]
                return data

            try:
                response = await self._transport_layer.send_transaction(
                    transaction, json_data_cb
                )
                code = 200
            except HttpResponseException as e:
                code = e.code

                # These codes are propagated to the caller; any other HTTP
                # error code falls through and is reported as a failed send.
                if code in (401, 404, 429) or code >= 500:
                    logger.info(
                        "TX [%s] {%s} got %d response", destination, txn_id, code
                    )
                    raise

            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)

            if code == 200:
                # `response` is only bound on this path (the request succeeded).
                for e_id, r in response.get("pdus", {}).items():
                    if "error" in r:
                        logger.warning(
                            "TX [%s] {%s} Remote returned error for %s: %s",
                            destination,
                            txn_id,
                            e_id,
                            r,
                        )
            else:
                for p in pdus:
                    logger.warning(
                        "TX [%s] {%s} Failed to send event %s",
                        destination,
                        txn_id,
                        p.event_id,
                    )
                success = False

            set_tag(tags.ERROR, not success)
            return success