# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Set, Tuple

from twisted.trial import unittest

from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.storage.databases.main.events import _LinkMap
from synapse.types import create_requester

from tests.unittest import HomeserverTestCase


class EventChainStoreTestCase(HomeserverTestCase):
    """Tests that persisting events populates the chain cover index
    (`event_auth_chains` / `event_auth_chain_links`) as expected.
    """

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastores().main
        # Next stream ordering that `persist` will assign; bumped once per event.
        self._next_stream_ordering = 1

    def test_simple(self):
        """Test that the example in `docs/auth_chain_difference_algorithm.md`
        works.
        """
        bob = "@creator:test"
        alice = "@alice:test"
        room_id = "!room:test"

        # Ensure that we have a rooms entry so that we generate the chain index.
        self._store_room(room_id)

        create = self._build_event(
            room_id,
            EventTypes.Create,
            state_key="",
            sender=bob,
            tag="create",
            auth_events=[],
        )
        bob_join = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=bob,
            sender=bob,
            tag="bob_join",
            auth_events=[create],
        )
        power = self._build_event(
            room_id,
            EventTypes.PowerLevels,
            state_key="",
            sender=bob,
            tag="power",
            auth_events=[create, bob_join],
        )
        alice_invite = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=bob,
            tag="alice_invite",
            auth_events=[create, bob_join, power],
        )
        alice_join = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=alice,
            tag="alice_join",
            auth_events=[create, alice_invite, power],
        )
        power_2 = self._build_event(
            room_id,
            EventTypes.PowerLevels,
            state_key="",
            sender=bob,
            tag="power_2",
            auth_events=[create, bob_join, power],
        )
        bob_join_2 = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=bob,
            sender=bob,
            tag="bob_join_2",
            auth_events=[create, bob_join, power],
        )
        alice_join2 = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=alice,
            tag="alice_join2",
            auth_events=[create, alice_join, power_2],
        )

        events = [
            create,
            bob_join,
            power,
            alice_invite,
            alice_join,
            bob_join_2,
            power_2,
            alice_join2,
        ]

        expected_links = [
            (bob_join, create),
            (power, create),
            (power, bob_join),
            (alice_invite, create),
            (alice_invite, power),
            (alice_invite, bob_join),
            (bob_join_2, power),
            (alice_join2, power_2),
        ]

        self.persist(events)
        chain_map, link_map = self.fetch_chains(events)

        # Check that the expected links and only the expected links have been
        # added.
        self._assert_links(expected_links, chain_map, link_map)

        # Test that everything can reach the create event, but the create event
        # can't reach anything.
        for event in events[1:]:
            self.assertTrue(
                link_map.exists_path_from(
                    chain_map[event.event_id], chain_map[create.event_id]
                ),
            )

            self.assertFalse(
                link_map.exists_path_from(
                    chain_map[create.event_id],
                    chain_map[event.event_id],
                ),
            )

    def test_out_of_order_events(self):
        """Test that we handle persisting events that we don't have the full
        auth chain for yet (which should only happen for out of band memberships).
        """
        bob = "@creator:test"
        alice = "@alice:test"
        room_id = "!room:test"

        # Ensure that we have a rooms entry so that we generate the chain index.
        self._store_room(room_id)

        # First persist the base room.
        create = self._build_event(
            room_id,
            EventTypes.Create,
            state_key="",
            sender=bob,
            tag="create",
            auth_events=[],
        )
        bob_join = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=bob,
            sender=bob,
            tag="bob_join",
            auth_events=[create],
        )
        power = self._build_event(
            room_id,
            EventTypes.PowerLevels,
            state_key="",
            sender=bob,
            tag="power",
            auth_events=[create, bob_join],
        )

        self.persist([create, bob_join, power])

        # Now persist an invite and a couple of memberships out of order.
        alice_invite = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=bob,
            tag="alice_invite",
            auth_events=[create, bob_join, power],
        )
        alice_join = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=alice,
            tag="alice_join",
            auth_events=[create, alice_invite, power],
        )
        alice_join2 = self._build_event(
            room_id,
            EventTypes.Member,
            state_key=alice,
            sender=alice,
            tag="alice_join2",
            auth_events=[create, alice_join, power],
        )

        # Each of these is persisted before the event(s) in its auth chain.
        self.persist([alice_join])
        self.persist([alice_join2])
        self.persist([alice_invite])

        # The end result should be sane.
        events = [create, bob_join, power, alice_invite, alice_join]

        chain_map, link_map = self.fetch_chains(events)

        expected_links = [
            (bob_join, create),
            (power, create),
            (power, bob_join),
            (alice_invite, create),
            (alice_invite, power),
            (alice_invite, bob_join),
        ]

        # Check that the expected links and only the expected links have been
        # added.
        self._assert_links(expected_links, chain_map, link_map)

    def _store_room(self, room_id: str) -> None:
        """Insert a V6 `rooms` row for `room_id` so that the chain index is
        generated when events in it are persisted.
        """
        self.get_success(
            self.store.store_room(
                room_id=room_id,
                room_creator_user_id="",
                is_public=True,
                room_version=RoomVersions.V6,
            )
        )

    def _build_event(
        self,
        room_id: str,
        event_type: str,
        *,
        state_key: str,
        sender: str,
        tag: str,
        auth_events: List[EventBase],
    ) -> EventBase:
        """Build (but do not persist) a V6 event with no prev events.

        Args:
            room_id: The room the event belongs to.
            event_type: The event type.
            state_key: The event's state key.
            sender: The user ID of the sender.
            tag: Stored as `content["tag"]`, so that each event built by the
                tests has distinct, self-describing content.
            auth_events: The events whose IDs are used as this event's auth
                event IDs, in the given order.

        Returns:
            The built event.
        """
        builder = self.hs.get_event_builder_factory().for_room_version(
            RoomVersions.V6,
            {
                "type": event_type,
                "state_key": state_key,
                "sender": sender,
                "room_id": room_id,
                "content": {"tag": tag},
            },
        )
        return self.get_success(
            builder.build(
                prev_event_ids=[],
                auth_event_ids=[event.event_id for event in auth_events],
            )
        )

    def _assert_links(
        self,
        expected_links: List[Tuple[EventBase, EventBase]],
        chain_map: Dict[str, Tuple[int, int]],
        link_map: _LinkMap,
    ) -> None:
        """Assert that `link_map` holds exactly the expected links.

        The number of additions in `link_map` must equal
        `len(expected_links)`. Each `(start, end)` pair must appear as a link
        between the chain positions of `start` and `end` in `chain_map`.
        """
        self.assertEqual(len(expected_links), len(list(link_map.get_additions())))

        for start, end in expected_links:
            start_id, start_seq = chain_map[start.event_id]
            end_id, end_seq = chain_map[end.event_id]

            self.assertIn(
                (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
            )

    def persist(
        self,
        events: List[EventBase],
    ) -> None:
        """Persist the given events directly and compute their chain cover index.

        Each event is assigned the next stream ordering from
        `self._next_stream_ordering`. In a single transaction, the events are
        written via `_store_event_txn` (with an empty `EventContext`), and then
        `_persist_event_auth_chain_txn` is run over them.

        This does not check the generated links; use `fetch_chains` to inspect
        the result.
        """

        persist_events_store = self.hs.get_datastores().persist_events

        for e in events:
            e.internal_metadata.stream_ordering = self._next_stream_ordering
            self._next_stream_ordering += 1

        def _persist(txn):
            # We need to persist the events to the events and state_events
            # tables.
            persist_events_store._store_event_txn(
                txn,
                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
            )

            # Actually call the function that calculates the auth chain stuff.
            persist_events_store._persist_event_auth_chain_txn(txn, events)

        self.get_success(
            persist_events_store.db_pool.runInteraction(
                "_persist",
                _persist,
            )
        )

    def fetch_chains(
        self, events: List[EventBase]
    ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
        """Load the persisted chain cover index for the given events.

        Returns:
            A tuple of:
              * a map from event ID to `(chain ID, sequence number)`, for those
                of `events` that have a row in `event_auth_chains`; and
              * a `_LinkMap` populated with every row of
                `event_auth_chain_links` whose origin chain is one of those
                chains.

        Also asserts that none of the persisted links were redundant.
        """

        # Fetch the map from event ID -> (chain ID, sequence number)
        rows = self.get_success(
            self.store.db_pool.simple_select_many_batch(
                table="event_auth_chains",
                column="event_id",
                iterable=[e.event_id for e in events],
                retcols=("event_id", "chain_id", "sequence_number"),
                keyvalues={},
            )
        )

        chain_map = {
            row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
        }

        # Fetch all the links and pass them to the _LinkMap.
        rows = self.get_success(
            self.store.db_pool.simple_select_many_batch(
                table="event_auth_chain_links",
                column="origin_chain_id",
                iterable=[chain_id for chain_id, _ in chain_map.values()],
                retcols=(
                    "origin_chain_id",
                    "origin_sequence_number",
                    "target_chain_id",
                    "target_sequence_number",
                ),
                keyvalues={},
            )
        )

        link_map = _LinkMap()
        for row in rows:
            added = link_map.add_link(
                (row["origin_chain_id"], row["origin_sequence_number"]),
                (row["target_chain_id"], row["target_sequence_number"]),
            )

            # We shouldn't have persisted any redundant links
            self.assertTrue(added)

        return chain_map, link_map


class LinkMapTestCase(unittest.TestCase):
    """Unit tests for the in-memory `_LinkMap` structure."""

    def test_simple(self):
        """Basic tests for the LinkMap."""
        links = _LinkMap()

        # Seed with a pre-existing link (chain 1, seq 1) -> (chain 2, seq 1).
        # It is added with `new=False`, so it must not show up as an addition.
        links.add_link((1, 1), (2, 1), new=False)
        self.assertCountEqual(links.get_links_between(1, 2), [(1, 1)])
        self.assertCountEqual(links.get_links_from((1, 1)), [(2, 1)])
        self.assertCountEqual(links.get_additions(), [])

        # Reachability: a later point on chain 1 reaches (2, 1) through the
        # seeded link but not (2, 2), and a chain can be walked down to lower
        # sequence numbers but never up.
        self.assertTrue(links.exists_path_from((1, 5), (2, 1)))
        self.assertFalse(links.exists_path_from((1, 5), (2, 2)))
        self.assertTrue(links.exists_path_from((1, 5), (1, 1)))
        self.assertFalse(links.exists_path_from((1, 1), (1, 5)))

        # (1, 4) -> (2, 1) is already implied via (1, 1), so it is rejected
        # and the stored links are unchanged.
        self.assertFalse(links.add_link((1, 4), (2, 1)))
        self.assertCountEqual(links.get_links_between(1, 2), [(1, 1)])

        # A link to a later point on chain 2 is not redundant and is kept.
        self.assertTrue(links.add_link((1, 3), (2, 3)))
        self.assertCountEqual(links.get_links_between(1, 2), [(1, 1), (3, 3)])

        # A link in the opposite direction is tracked separately and leaves
        # the 1 -> 2 links alone.
        self.assertTrue(links.add_link((2, 5), (1, 3)))
        self.assertCountEqual(links.get_links_between(2, 1), [(5, 3)])
        self.assertCountEqual(links.get_links_between(1, 2), [(1, 1), (3, 3)])

        # Only the two links added after seeding are reported as additions.
        expected_additions = [(1, 3, 2, 3), (2, 5, 1, 3)]
        self.assertCountEqual(links.get_additions(), expected_additions)


class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
    """Tests for the `chain_cover` background update, which calculates the
    chain cover index for rooms that don't have one.
    """

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        self.store = hs.get_datastores().main
        self.user_id = self.register_user("foo", "pass")
        self.token = self.login("foo", "pass")
        self.requester = create_requester(self.user_id)

    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
        """Insert a room without a chain cover index.

        The room is marked with `has_auth_chain_index = False`, and two state
        events sharing the same prev events are sent to fork the DAG.

        Note: this deletes *every* row of `event_auth_chains` and
        `event_auth_chain_links`, not just the rows for the new room.

        Returns:
            The room ID, and a list with the state (as a set of event IDs)
            recorded in the context of each of the two forked events.
        """
        room_id = self.helper.create_room_as(self.user_id, tok=self.token)

        # Mark the room as not having a chain cover index
        self.get_success(
            self.store.db_pool.simple_update(
                table="rooms",
                keyvalues={"room_id": room_id},
                updatevalues={"has_auth_chain_index": False},
                desc="test",
            )
        )

        # Create a fork in the DAG: both events use the same prev events.
        event_handler = self.hs.get_event_creation_handler()
        latest_event_ids = self.get_success(
            self.store.get_prev_events_for_room(room_id)
        )

        states: List[Set[str]] = []
        for _ in range(2):
            event, context = self.get_success(
                event_handler.create_event(
                    self.requester,
                    {
                        "type": "some_state_type",
                        "state_key": "",
                        "content": {},
                        "room_id": room_id,
                        "sender": self.user_id,
                    },
                    prev_event_ids=latest_event_ids,
                )
            )
            self.get_success(
                event_handler.handle_new_client_event(self.requester, event, context)
            )
            states.append(
                set(self.get_success(context.get_current_state_ids()).values())
            )

        # Delete the chain cover info (for all rooms).
        def _delete_tables(txn):
            txn.execute("DELETE FROM event_auth_chains")
            txn.execute("DELETE FROM event_auth_chain_links")

        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))

        return room_id, states

    def _send_test_state(self, room_id: str, count: int = 150) -> None:
        """Send `count` `m.test` state events to `room_id`. This makes the
        room big enough that the background update needs several iterations.
        """
        for i in range(count):
            self.helper.send_state(
                room_id, event_type="m.test", body={"index": i}, tok=self.token
            )

    def _queue_chain_cover_update(self) -> None:
        """Insert a pending `chain_cover` background update."""
        self.get_success(
            self.store.db_pool.simple_insert(
                "background_updates",
                {"update_name": "chain_cover", "progress_json": "{}"},
            )
        )

        # Reset the updater's `_all_done` flag; otherwise it may not pick up
        # the update inserted directly into the table above.
        self.store.db_pool.updates._all_done = False

    def _run_background_updates_stepwise(self) -> int:
        """Run background updates one step at a time until they have all
        completed, and return the number of steps taken.
        """
        updates = self.store.db_pool.updates
        iterations = 0
        while not self.get_success(updates.has_completed_background_updates()):
            iterations += 1
            self.get_success(updates.do_next_background_update(False), by=0.1)
        return iterations

    def _assert_chain_cover_usable(self, room_id: str, states: List[Set[str]]) -> None:
        """Assert that `room_id` is flagged as having a chain cover index, and
        that the auth chain difference of `states` can be calculated with it.
        """
        # Test that the `has_auth_chain_index` has been set
        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))

        # Test that calculating the auth chain difference using the newly
        # calculated chain cover works.
        self.get_success(
            self.store.db_pool.runInteraction(
                "test",
                self.store._get_auth_chain_difference_using_cover_index_txn,
                room_id,
                states,
            )
        )

    def test_background_update_single_room(self):
        """Test that the background update calculates a usable chain cover for
        a single historic room.
        """
        room_id, states = self._generate_room()

        self._queue_chain_cover_update()
        self.wait_for_background_updates()

        self._assert_chain_cover_usable(room_id, states)

    def test_background_update_multiple_rooms(self):
        """Test that the background update calculates a usable chain cover for
        every one of several historic rooms.
        """
        room_id1, states1 = self._generate_room()
        room_id2, states2 = self._generate_room()
        room_id3, states3 = self._generate_room()

        self._queue_chain_cover_update()
        self.wait_for_background_updates()

        for room_id, states in (
            (room_id1, states1),
            (room_id2, states2),
            (room_id3, states3),
        ):
            self._assert_chain_cover_usable(room_id, states)

    def test_background_update_single_large_room(self):
        """Test that the background update calculates a usable chain cover for
        a historic room large enough to need multiple iterations.
        """
        room_id, states = self._generate_room()

        # Add a bunch of state so that it takes multiple iterations of the
        # background update to process the room.
        self._send_test_state(room_id)

        self._queue_chain_cover_update()
        iterations = self._run_background_updates_stepwise()

        # Ensure that we did actually take multiple iterations to process the
        # room.
        self.assertGreater(iterations, 1)

        self._assert_chain_cover_usable(room_id, states)

    def test_background_update_multiple_large_room(self):
        """Test that the background update calculates a usable chain cover for
        several historic rooms that each need multiple iterations.
        """
        room_id1, states1 = self._generate_room()
        room_id2, states2 = self._generate_room()

        # Add a bunch of state so that it takes multiple iterations of the
        # background update to process each room.
        self._send_test_state(room_id1)
        self._send_test_state(room_id2)

        self._queue_chain_cover_update()
        iterations = self._run_background_updates_stepwise()

        # Ensure that we did actually take multiple iterations to process the
        # rooms.
        self.assertGreater(iterations, 1)

        self._assert_chain_cover_usable(room_id1, states1)
        self._assert_chain_cover_usable(room_id2, states2)