Add foreign key constraint to event_forward_extremities. (#15751)

This commit is contained in:
Erik Johnston 2023-07-05 10:43:19 +01:00 committed by GitHub
parent c303eca8cc
commit 95a96b21eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 699 additions and 11 deletions

1
changelog.d/15751.misc Normal file
View File

@ -0,0 +1 @@
Add foreign key constraint to `event_forward_extremities`.

View File

@ -61,6 +61,7 @@ from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpda
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import EventPushActionsStore from synapse.storage.databases.main.event_push_actions import EventPushActionsStore
from synapse.storage.databases.main.events_bg_updates import ( from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore, EventsBackgroundUpdatesStore,
@ -239,6 +240,7 @@ class Store(
PresenceBackgroundUpdateStore, PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore, ReceiptsBackgroundUpdateStore,
RelationsWorkerStore, RelationsWorkerStore,
EventFederationWorkerStore,
): ):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from enum import IntEnum from enum import Enum, IntEnum
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@ -24,12 +25,16 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence,
Tuple,
Type, Type,
) )
import attr import attr
from pydantic import BaseModel
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock, json_encoder from synapse.util import Clock, json_encoder
@ -48,6 +53,78 @@ DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]] MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
class Constraint(metaclass=abc.ABCMeta):
"""Base class representing different constraints.
Used by `register_background_validate_constraint_and_delete_rows`.
"""
@abc.abstractmethod
def make_check_clause(self, table: str) -> str:
"""Returns an SQL expression that checks the row passes the constraint."""
pass
@abc.abstractmethod
def make_constraint_clause_postgres(self) -> str:
"""Returns an SQL clause for creating the constraint.
Only used on Postgres DBs
"""
pass
@attr.s(auto_attribs=True)
class ForeignKeyConstraint(Constraint):
"""A foreign key constraint.
Attributes:
referenced_table: The "parent" table name.
columns: The list of mappings of columns from table to referenced table
"""
referenced_table: str
columns: Sequence[Tuple[str, str]]
def make_check_clause(self, table: str) -> str:
join_clause = " AND ".join(
f"{col1} = {table}.{col2}" for col1, col2 in self.columns
)
return f"EXISTS (SELECT 1 FROM {self.referenced_table} WHERE {join_clause})"
def make_constraint_clause_postgres(self) -> str:
column1_list = ", ".join(col1 for col1, col2 in self.columns)
column2_list = ", ".join(col2 for col1, col2 in self.columns)
return f"FOREIGN KEY ({column1_list}) REFERENCES {self.referenced_table} ({column2_list})"
@attr.s(auto_attribs=True)
class NotNullConstraint(Constraint):
"""A NOT NULL column constraint"""
column: str
def make_check_clause(self, table: str) -> str:
return f"{self.column} IS NOT NULL"
def make_constraint_clause_postgres(self) -> str:
return f"CHECK ({self.column} IS NOT NULL)"
class ValidateConstraintProgress(BaseModel):
"""The format of the progress JSON for validate constraint background
updates.
Used by `register_background_validate_constraint_and_delete_rows`.
"""
class State(str, Enum):
check = "check"
validate = "validate"
state: State = State.validate
lower_bound: Sequence[Any] = ()
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler: class _BackgroundUpdateHandler:
"""A handler for a given background update. """A handler for a given background update.
@ -740,6 +817,179 @@ class BackgroundUpdater:
logger.info("Adding index %s to %s", index_name, table) logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner) await self.db_pool.runWithConnection(runner)
def register_background_validate_constraint_and_delete_rows(
self,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
) -> None:
"""Helper for store classes to do a background validate constraint, and
delete rows that do not pass the constraint check.
Note: This deletes rows that don't match the constraint. This may not be
appropriate in all situations, and so the suitability of using this
method should be considered on a case-by-case basis.
This only applies on PostgreSQL.
For SQLite the table gets recreated as part of the schema delta and the
data is copied over synchronously (or whatever the correct way to
describe it as).
Args:
update_name: The name of the background update.
table: The table with the invalid constraint.
constraint_name: The name of the constraint
constraint: A `Constraint` object matching the type of constraint.
unique_columns: A sequence of columns that form a unique constraint
on the table. Used to iterate over the table.
"""
assert isinstance(
self.db_pool.engine, engines.PostgresEngine
), "validate constraint background update registered for non-Postres database"
async def updater(progress: JsonDict, batch_size: int) -> int:
return await self.validate_constraint_and_delete_in_background(
update_name=update_name,
table=table,
constraint_name=constraint_name,
constraint=constraint,
unique_columns=unique_columns,
progress=progress,
batch_size=batch_size,
)
self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
updater, oneshot=True
)
async def validate_constraint_and_delete_in_background(
self,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
progress: JsonDict,
batch_size: int,
) -> int:
"""Validates a table constraint that has been marked as `NOT VALID`,
deleting rows that don't pass the constraint check.
This will delete rows that do not meet the validation check.
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
unique_columns: Sequence[str],
"""
# We validate the constraint by:
# 1. Trying to validate the constraint as is. If this succeeds then
# we're done.
# 2. Otherwise, we manually scan the table to remove rows that don't
# match the constraint.
# 3. We try re-validating the constraint.
parsed_progress = ValidateConstraintProgress.parse_obj(progress)
if parsed_progress.state == ValidateConstraintProgress.State.check:
return_columns = ", ".join(unique_columns)
order_columns = ", ".join(unique_columns)
where_clause = ""
args: List[Any] = []
if parsed_progress.lower_bound:
where_clause = f"""WHERE ({order_columns}) > ({", ".join("?" for _ in unique_columns)})"""
args.extend(parsed_progress.lower_bound)
args.append(batch_size)
sql = f"""
SELECT
{return_columns},
{constraint.make_check_clause(table)} AS check
FROM {table}
{where_clause}
ORDER BY {order_columns}
LIMIT ?
"""
def validate_constraint_in_background_check(
txn: "LoggingTransaction",
) -> None:
txn.execute(sql, args)
rows = txn.fetchall()
new_progress = parsed_progress.copy()
if not rows:
new_progress.state = ValidateConstraintProgress.State.validate
self._background_update_progress_txn(
txn, update_name, new_progress.dict()
)
return
new_progress.lower_bound = rows[-1][:-1]
to_delete = [row[:-1] for row in rows if not row[-1]]
if to_delete:
logger.warning(
"Deleting %d rows that do not pass new constraint",
len(to_delete),
)
self.db_pool.simple_delete_many_batch_txn(
txn, table=table, keys=unique_columns, values=to_delete
)
self._background_update_progress_txn(
txn, update_name, new_progress.dict()
)
await self.db_pool.runInteraction(
"validate_constraint_in_background_check",
validate_constraint_in_background_check,
)
return batch_size
elif parsed_progress.state == ValidateConstraintProgress.State.validate:
sql = f"ALTER TABLE {table} VALIDATE CONSTRAINT {constraint_name}"
def validate_constraint_in_background_validate(
txn: "LoggingTransaction",
) -> None:
txn.execute(sql)
try:
await self.db_pool.runInteraction(
"validate_constraint_in_background_validate",
validate_constraint_in_background_validate,
)
await self._end_background_update(update_name)
except self.db_pool.engine.module.IntegrityError as e:
# If we get an integrity error here, then we go back and recheck the table.
logger.warning("Integrity error when validating constraint: %s", e)
await self._background_update_progress(
update_name,
ValidateConstraintProgress(
state=ValidateConstraintProgress.State.check
).dict(),
)
return batch_size
else:
raise Exception(
f"Unrecognized state '{parsed_progress.state}' when trying to validate_constraint_and_delete_in_background"
)
async def _end_background_update(self, update_name: str) -> None: async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue. """Removes a completed background update task from the queue.
@ -795,3 +1045,86 @@ class BackgroundUpdater:
keyvalues={"update_name": update_name}, keyvalues={"update_name": update_name},
updatevalues={"progress_json": progress_json}, updatevalues={"progress_json": progress_json},
) )
def run_validate_constraint_and_delete_rows_schema_delta(
txn: "LoggingTransaction",
ordering: int,
update_name: str,
table: str,
constraint_name: str,
constraint: Constraint,
sqlite_table_name: str,
sqlite_table_schema: str,
) -> None:
"""Runs a schema delta to add a constraint to the table. This should be run
in a schema delta file.
For PostgreSQL the constraint is added and validated in the background.
For SQLite the table is recreated and data copied across immediately. This
is done by the caller passing in a script to create the new table. Note that
table indexes and triggers are copied over automatically.
There must be a corresponding call to
`register_background_validate_constraint_and_delete_rows` to register the
background update in one of the data store classes.
Attributes:
txn ordering, update_name: For adding a row to background_updates table.
table: The table to add constraint to. constraint_name: The name of the
new constraint constraint: A `Constraint` object describing the
constraint sqlite_table_name: For SQLite the name of the empty copy of
table sqlite_table_schema: A SQL script for creating the above table.
"""
if isinstance(txn.database_engine, PostgresEngine):
# For postgres we can just add the constraint and mark it as NOT VALID,
# and then insert a background update to go and check the validity in
# the background.
txn.execute(
f"""
ALTER TABLE {table}
ADD CONSTRAINT {constraint_name} {constraint.make_constraint_clause_postgres()}
NOT VALID
"""
)
txn.execute(
"INSERT INTO background_updates (ordering, update_name, progress_json) VALUES (?, ?, '{}')",
(ordering, update_name),
)
else:
# For SQLite, we:
# 1. fetch all indexes/triggers/etc related to the table
# 2. create an empty copy of the table
# 3. copy across the rows (that satisfy the check)
# 4. replace the old table with the new able.
# 5. add back all the indexes/triggers/etc
# Fetch the indexes/triggers/etc. Note that `sql` column being null is
# due to indexes being auto created based on the class definition (e.g.
# PRIMARY KEY), and so don't need to be recreated.
txn.execute(
"""
SELECT sql FROM sqlite_master
WHERE tbl_name = ? AND type != 'table' AND sql IS NOT NULL
""",
(table,),
)
extras = [row[0] for row in txn]
txn.execute(sqlite_table_schema)
sql = f"""
INSERT INTO {sqlite_table_name} SELECT * FROM {table}
WHERE {constraint.make_check_clause(table)}
"""
txn.execute(sql)
txn.execute(f"DROP TABLE {table}")
txn.execute(f"ALTER TABLE {sqlite_table_name} RENAME TO {table}")
for extra in extras:
txn.execute(extra)

View File

@ -2313,6 +2313,43 @@ class DatabasePool:
return txn.rowcount return txn.rowcount
@staticmethod
def simple_delete_many_batch_txn(
txn: LoggingTransaction,
table: str,
keys: Collection[str],
values: Iterable[Iterable[Any]],
) -> None:
"""Executes a DELETE query on the named table.
The input is given as a list of rows, where each row is a list of values.
(Actually any iterable is fine.)
Args:
txn: The transaction to use.
table: string giving the table name
keys: list of column names
values: for each row, a list of values in the same order as `keys`
"""
if isinstance(txn.database_engine, PostgresEngine):
# We use `execute_values` as it can be a lot faster than `execute_batch`,
# but it's only available on postgres.
sql = "DELETE FROM %s WHERE (%s) IN (VALUES ?)" % (
table,
", ".join(k for k in keys),
)
txn.execute_values(sql, values, fetch=False)
else:
sql = "DELETE FROM %s WHERE (%s) = (%s)" % (
table,
", ".join(k for k in keys),
", ".join("?" for _ in keys),
)
txn.execute_batch(sql, values)
def get_cache_dict( def get_cache_dict(
self, self,
db_conn: LoggingDatabaseConnection, db_conn: LoggingDatabaseConnection,

View File

@ -38,6 +38,7 @@ from synapse.events import EventBase, make_event_from_dict
from synapse.logging.opentracing import tag_args, trace from synapse.logging.opentracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.background_updates import ForeignKeyConstraint
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
@ -140,6 +141,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000)
if isinstance(self.database_engine, PostgresEngine):
self.db_pool.updates.register_background_validate_constraint_and_delete_rows(
update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
table="event_forward_extremities",
constraint_name="event_forward_extremities_event_id",
constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
unique_columns=("event_id", "room_id"),
)
async def get_auth_chain( async def get_auth_chain(
self, room_id: str, event_ids: Collection[str], include_given: bool = False self, room_id: str, event_ids: Collection[str], include_given: bool = False
) -> List[EventBase]: ) -> List[EventBase]:

View File

@ -415,12 +415,6 @@ class PersistEventsStore:
backfilled=False, backfilled=False,
) )
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremities,
max_stream_order=max_stream_order,
)
# Ensure that we don't have the same event twice. # Ensure that we don't have the same event twice.
events_and_contexts = self._filter_events_and_contexts_for_duplicates( events_and_contexts = self._filter_events_and_contexts_for_duplicates(
events_and_contexts events_and_contexts
@ -439,6 +433,12 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts) self._store_event_txn(txn, events_and_contexts=events_and_contexts)
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremities,
max_stream_order=max_stream_order,
)
self._persist_transaction_ids_txn(txn, events_and_contexts) self._persist_transaction_ids_txn(txn, events_and_contexts)
# Insert into event_to_state_groups. # Insert into event_to_state_groups.

View File

@ -0,0 +1,51 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This migration adds foreign key constraint to `event_forward_extremities` table.
"""
from synapse.storage.background_updates import (
ForeignKeyConstraint,
run_validate_constraint_and_delete_rows_schema_delta,
)
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import BaseDatabaseEngine
FORWARD_EXTREMITIES_TABLE_SCHEMA = """
CREATE TABLE event_forward_extremities2(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
UNIQUE (event_id, room_id),
CONSTRAINT event_forward_extremities_event_id FOREIGN KEY (event_id) REFERENCES events (event_id)
)
"""
def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
run_validate_constraint_and_delete_rows_schema_delta(
cur,
ordering=7803,
update_name="event_forward_extremities_event_id_foreign_key_constraint_update",
table="event_forward_extremities",
constraint_name="event_forward_extremities_event_id",
constraint=ForeignKeyConstraint("events", [("event_id", "event_id")]),
sqlite_table_name="event_forward_extremities2",
sqlite_table_schema=FORWARD_EXTREMITIES_TABLE_SCHEMA,
)
# We can't add a similar constraint to `event_backward_extremities` as the
# events in there don't exist in the `events` table and `event_edges`
# doesn't have a unique constraint on `prev_event_id` (so we can't make a
# foreign key point to it).

View File

@ -20,7 +20,14 @@ from twisted.internet.defer import Deferred, ensureDeferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import (
BackgroundUpdater,
ForeignKeyConstraint,
NotNullConstraint,
run_validate_constraint_and_delete_rows_schema_delta,
)
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -404,3 +411,221 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
self.pump() self.pump()
self._update_ctx_manager.__aexit__.assert_called() self._update_ctx_manager.__aexit__.assert_called()
self.get_success(do_update_d) self.get_success(do_update_d)
class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase):
"""Tests the validate contraint and delete background handlers."""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates
# the base test class should have run the real bg updates for us
self.assertTrue(
self.get_success(self.updates.has_completed_background_updates())
)
self.store = self.hs.get_datastores().main
def test_not_null_constraint(self) -> None:
# Create the initial tables, where we have some invalid data.
"""Tests adding a not null constraint."""
table_sql = """
CREATE TABLE test_constraint(
a INT PRIMARY KEY,
b INT
);
"""
self.get_success(
self.store.db_pool.execute(
"test_not_null_constraint", lambda _: None, table_sql
)
)
# We add an index so that we can check that its correctly recreated when
# using SQLite.
index_sql = "CREATE INDEX test_index ON test_constraint(a)"
self.get_success(
self.store.db_pool.execute(
"test_not_null_constraint", lambda _: None, index_sql
)
)
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
)
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None})
)
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
)
# Now lets do the migration
table2_sqlite = """
CREATE TABLE test_constraint2(
a INT PRIMARY KEY,
b INT,
CONSTRAINT test_constraint_name CHECK (b is NOT NULL)
);
"""
def delta(txn: LoggingTransaction) -> None:
run_validate_constraint_and_delete_rows_schema_delta(
txn,
ordering=1000,
update_name="test_bg_update",
table="test_constraint",
constraint_name="test_constraint_name",
constraint=NotNullConstraint("b"),
sqlite_table_name="test_constraint2",
sqlite_table_schema=table2_sqlite,
)
self.get_success(
self.store.db_pool.runInteraction(
"test_not_null_constraint",
delta,
)
)
if isinstance(self.store.database_engine, PostgresEngine):
# Postgres uses a background update
self.updates.register_background_validate_constraint_and_delete_rows(
"test_bg_update",
table="test_constraint",
constraint_name="test_constraint_name",
constraint=NotNullConstraint("b"),
unique_columns=["a"],
)
# Tell the DataStore that it hasn't finished all updates yet
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
self.wait_for_background_updates()
# Check the correct values are in the new table.
rows = self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
)
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
# And check that invalid rows get correctly rejected.
self.get_failure(
self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}),
exc=self.store.database_engine.module.IntegrityError,
)
# Check the index is still there for SQLite.
if isinstance(self.store.database_engine, Sqlite3Engine):
# Ensure the index exists in the schema.
self.get_success(
self.store.db_pool.simple_select_one_onecol(
table="sqlite_master",
keyvalues={"tbl_name": "test_constraint"},
retcol="name",
)
)
def test_foreign_constraint(self) -> None:
"""Tests adding a not foreign key constraint."""
# Create the initial tables, where we have some invalid data.
base_sql = """
CREATE TABLE base_table(
b INT PRIMARY KEY
);
"""
table_sql = """
CREATE TABLE test_constraint(
a INT PRIMARY KEY,
b INT NOT NULL
);
"""
self.get_success(
self.store.db_pool.execute(
"test_foreign_key_constraint", lambda _: None, base_sql
)
)
self.get_success(
self.store.db_pool.execute(
"test_foreign_key_constraint", lambda _: None, table_sql
)
)
self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1}))
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1})
)
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2})
)
self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3}))
self.get_success(
self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3})
)
table2_sqlite = """
CREATE TABLE test_constraint2(
a INT PRIMARY KEY,
b INT NOT NULL,
CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b)
);
"""
def delta(txn: LoggingTransaction) -> None:
run_validate_constraint_and_delete_rows_schema_delta(
txn,
ordering=1000,
update_name="test_bg_update",
table="test_constraint",
constraint_name="test_constraint_name",
constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
sqlite_table_name="test_constraint2",
sqlite_table_schema=table2_sqlite,
)
self.get_success(
self.store.db_pool.runInteraction(
"test_foreign_key_constraint",
delta,
)
)
if isinstance(self.store.database_engine, PostgresEngine):
# Postgres uses a background update
self.updates.register_background_validate_constraint_and_delete_rows(
"test_bg_update",
table="test_constraint",
constraint_name="test_constraint_name",
constraint=ForeignKeyConstraint("base_table", [("b", "b")]),
unique_columns=["a"],
)
# Tell the DataStore that it hasn't finished all updates yet
self.store.db_pool.updates._all_done = False
# Now let's actually drive the updates to completion
self.wait_for_background_updates()
# Check the correct values are in the new table.
rows = self.get_success(
self.store.db_pool.simple_select_list(
table="test_constraint",
keyvalues={},
retcols=("a", "b"),
)
)
self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}])
# And check that invalid rows get correctly rejected.
self.get_failure(
self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}),
exc=self.store.database_engine.module.IntegrityError,
)

View File

@ -20,6 +20,7 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.api.room_versions import ( from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS, KNOWN_ROOM_VERSIONS,
EventFormatVersions, EventFormatVersions,
@ -98,8 +99,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
room2 = "#room2" room2 = "#room2"
room3 = "#room3" room3 = "#room3"
def insert_event(txn: Cursor, i: int, room_id: str) -> None: def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None:
event_id = "$event_%i:local" % i event_id = "$event_%i:local" % i
# We need to insert into events table to get around the foreign key constraint.
self.store.db_pool.simple_insert_txn(
txn,
table="events",
values={
"instance_name": "master",
"stream_ordering": self.store._stream_id_gen.get_next_txn(txn),
"topological_ordering": 1,
"depth": 1,
"event_id": event_id,
"room_id": room_id,
"type": EventTypes.Message,
"processed": True,
"outlier": False,
"origin_server_ts": 0,
"received_ts": 0,
"sender": "@user:local",
"contains_url": False,
"state_key": None,
"rejection_reason": None,
},
)
txn.execute( txn.execute(
( (
"INSERT INTO event_forward_extremities (room_id, event_id) " "INSERT INTO event_forward_extremities (room_id, event_id) "
@ -113,10 +138,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.store.db_pool.runInteraction("insert", insert_event, i, room1) self.store.db_pool.runInteraction("insert", insert_event, i, room1)
) )
self.get_success( self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i, room2) self.store.db_pool.runInteraction(
"insert", insert_event, i + 100, room2
)
) )
self.get_success( self.get_success(
self.store.db_pool.runInteraction("insert", insert_event, i, room3) self.store.db_pool.runInteraction(
"insert", insert_event, i + 200, room3
)
) )
# Test simple case # Test simple case