Prevent memory leak from reoccurring when presence is disabled. (#12656)

Erik Johnston 2022-05-06 17:41:57 +01:00 committed by GitHub
parent 2607b3e181
commit 4337d33a73
3 changed files with 56 additions and 30 deletions

changelog.d/12656.misc (new file)

@@ -0,0 +1 @@
+Prevent memory leak from reoccurring when presence is disabled.

synapse/handlers/presence.py

@@ -659,6 +659,7 @@ class PresenceHandler(BasePresenceHandler):
         )

         now = self.clock.time_msec()
+        if self._presence_enabled:
             for state in self.user_to_current_state.values():
                 self.wheel_timer.insert(
                     now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
@@ -804,6 +805,13 @@ class PresenceHandler(BasePresenceHandler):
         This is currently used to bump the max presence stream ID without changing any
         user's presence (see PresenceHandler.add_users_to_send_full_presence_to).
         """
+        if not self._presence_enabled:
+            # We shouldn't get here if presence is disabled, but we check anyway
+            # to ensure that we don't a) send out presence federation and b)
+            # don't add things to the wheel timer that will never be handled.
+            logger.warning("Tried to update presence states when presence is disabled")
+            return
+
         now = self.clock.time_msec()

         with Measure(self.clock, "presence_update_states"):
@@ -1229,6 +1237,10 @@ class PresenceHandler(BasePresenceHandler):
         ):
             raise SynapseError(400, "Invalid presence state")

+        # If presence is disabled, no-op
+        if not self.hs.config.server.use_presence:
+            return
+
         user_id = target_user.to_string()

         prev_state = await self.current_state_for_user(user_id)
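
The three hunks above all guard the same failure mode: when presence is disabled, the background presence timeout handler never drains the wheel timer, so anything inserted into it is kept forever. Below is a minimal sketch of the pattern, not Synapse code apart from the WheelTimer import; the flag, loop, and user IDs are invented for illustration.

from synapse.util.wheel_timer import WheelTimer

presence_enabled = False  # stand-in for hs.config.server.use_presence
timer: WheelTimer[str] = WheelTimer()

now = 0
for i in range(1000):
    now += 5_000
    # The commit adds the equivalent of this check in front of every code
    # path that feeds the timer; without it the timer would grow on each
    # iteration and, because nothing fetches from it while presence is off,
    # would never shrink again.
    if presence_enabled:
        timer.insert(now=now, obj=f"@user{i}:example.com", then=now + 30_000)

assert len(timer) == 0  # guarded: nothing accumulates while presence is off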

synapse/util/wheel_timer.py

@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Generic, List, TypeVar
+import logging
+from typing import Generic, Hashable, List, Set, TypeVar

-T = TypeVar("T")
+import attr

+logger = logging.getLogger(__name__)

+T = TypeVar("T", bound=Hashable)

+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class _Entry(Generic[T]):
-    __slots__ = ["end_key", "queue"]
-
-    def __init__(self, end_key: int) -> None:
-        self.end_key: int = end_key
-        self.queue: List[T] = []
+    end_key: int
+    elements: Set[T] = attr.Factory(set)


 class WheelTimer(Generic[T]):
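
Assembled, the new _Entry is a frozen, slotted attrs class whose bucket is a set rather than a list. The small standalone sketch below (the user ID is invented) shows that re-adding the same item is a no-op, which is what keeps duplicate timer entries for a single user from piling up.

import attr
from typing import Generic, Hashable, Set, TypeVar

T = TypeVar("T", bound=Hashable)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _Entry(Generic[T]):
    end_key: int
    elements: Set[T] = attr.Factory(set)


entry: _Entry[str] = _Entry(end_key=42)
entry.elements.add("@alice:example.com")
entry.elements.add("@alice:example.com")  # set semantics: still one element
assert len(entry.elements) == 1
# frozen=True stops the attributes being rebound, but the set they refer to
# is mutable, so buckets can still be filled in place.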
@@ -48,17 +51,27 @@ class WheelTimer(Generic[T]):
             then: When to return the object strictly after.
         """
         then_key = int(then / self.bucket_size) + 1
+        now_key = int(now / self.bucket_size)

         if self.entries:
             min_key = self.entries[0].end_key
             max_key = self.entries[-1].end_key

+            if min_key < now_key - 10:
+                # If we have ten buckets that are due and still nothing has
+                # called `fetch()` then we likely have a bug that is causing a
+                # memory leak.
+                logger.warning(
+                    "Inserting into a wheel timer that hasn't been read from recently. Item: %s",
+                    obj,
+                )
+
             if then_key <= max_key:
                 # The max here is to protect against inserts for times in the past
-                self.entries[max(min_key, then_key) - min_key].queue.append(obj)
+                self.entries[max(min_key, then_key) - min_key].elements.add(obj)
                 return

-        next_key = int(now / self.bucket_size) + 1
+        next_key = now_key + 1
         if self.entries:
             last_key = self.entries[-1].end_key
         else:
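
For scale: assuming WheelTimer's default 5000 ms bucket size, the new warning only fires once the oldest bucket's end_key is more than ten bucket widths behind the current bucket, i.e. the timer has gone unread for upwards of 50 seconds. A worked example with invented numbers:

bucket_size = 5000
min_key = 100                       # oldest bucket ends at key 100 (t ~= 500_000 ms)
now = 555_001                       # a little over 55 s later
now_key = int(now / bucket_size)    # 111
assert min_key < now_key - 10       # 100 < 101, so the warning would be logged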
@@ -71,7 +84,7 @@ class WheelTimer(Generic[T]):
         # to insert. This ensures there are no gaps.
         self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))

-        self.entries[-1].queue.append(obj)
+        self.entries[-1].elements.add(obj)

     def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
@@ -84,11 +97,11 @@ class WheelTimer(Generic[T]):
         """
         now_key = int(now / self.bucket_size)

-        ret = []
+        ret: List[T] = []
         while self.entries and self.entries[0].end_key <= now_key:
-            ret.extend(self.entries.pop(0).queue)
+            ret.extend(self.entries.pop(0).elements)

         return ret

     def __len__(self) -> int:
-        return sum(len(entry.queue) for entry in self.entries)
+        return sum(len(entry.elements) for entry in self.entries)
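
Taken together, the wheel timer keeps its insert/fetch contract but now deduplicates repeated inserts of the same object and counts distinct members in __len__. A usage sketch with invented times and user IDs, assuming the bucket_size constructor argument and its 5000 ms default:

from synapse.util.wheel_timer import WheelTimer

timer: WheelTimer[str] = WheelTimer(bucket_size=5000)  # 5 s buckets

timer.insert(now=0, obj="@alice:example.com", then=7_000)
timer.insert(now=0, obj="@alice:example.com", then=7_000)  # same bucket: deduplicated
timer.insert(now=0, obj="@bob:example.com", then=12_000)

assert len(timer) == 2  # one entry per distinct object

assert timer.fetch(now=4_000) == []                       # nothing due yet
assert timer.fetch(now=10_000) == ["@alice:example.com"]  # returned strictly after then=7_000
assert timer.fetch(now=20_000) == ["@bob:example.com"]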