Source code for kilter.service.runner

# Copyright 2022-2023 Dominik Sekotill <dom.sekotill@kodo.org.uk>
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""
Coordinate receiving and sending raw messages with a filter and Session object

The primary class in this module (`Runner`) is intended to be used with an
`anyio.abc.Listener`, which can be obtained, for instance, from
`anyio.create_tcp_listener()`.
"""

from __future__ import annotations

import logging
from collections import defaultdict
from collections.abc import AsyncGenerator
from typing import Final
from typing import TypeAlias
from warnings import warn

import anyio.abc
from anyio.streams.stapled import StapledObjectStream
from async_generator import aclosing
from typing_extensions import Self

from kilter.protocol.buffer import SimpleBuffer
from kilter.protocol.core import EditMessage
from kilter.protocol.core import EventMessage
from kilter.protocol.core import FilterProtocol
from kilter.protocol.core import ResponseMessage
from kilter.protocol.messages import ProtocolFlags

from .options import get_flags
from .options import get_macros
from .session import *
from .util import Broadcast
from .util import qualname

MessageChannel: TypeAlias = anyio.abc.ObjectStream[Message]
Sender: TypeAlias = AsyncGenerator[None, ResponseMessage|EditMessage|Negotiate|Skip]

kiB: Final = 2**10
MiB: Final = 2**20

# TODO: Convert to Union type alias once python/mypy#14242 is fixed
_VALID_FINAL_RESPONSES: Final = Reject, Discard, Accept, TemporaryFailure, ReplyCode
_VALID_EVENT_MESSAGE: TypeAlias = Helo | EnvelopeFrom | EnvelopeRecipient | Data | \
	Unknown | Header | EndOfHeaders | Body | EndOfMessage | Abort

_logger = logging.getLogger(__package__)


[docs]class NegotiationError(Exception): """ An error raised when MTAs are not compatible with the filter """
class _CloseFilter: def __init__(self, filtr: Filter): self.filter = filtr class _Broadcast(Broadcast[EventMessage]): def __init__(self) -> None: super().__init__() self.task_status: anyio.abc.TaskStatus[None]|None = None async def shutdown_hook(self) -> None: await self.pre_receive_hook() async def pre_receive_hook(self) -> None: if self.task_status is not None: self.task_status.started() self.task_status = None
[docs]class Runner: """ A filter runner that coordinates passing data between a stream and multiple filters Instances can be used as handlers that can be passed to `anyio.abc.Listener.serve()` or used with any `anyio.abc.ByteStream`. """ def __init__(self, *filters: Filter): if len(filters) == 0: # pragma: no-cover raise TypeError("Runner requires at least one filter to run") self.filters = list(filters) self.use_skip = True async def __call__(self, client: anyio.abc.ByteStream) -> None: """ Return an awaitable that starts and coordinates filters """ buff = SimpleBuffer(1*MiB) proto = FilterProtocol(abort_on_unknown=True) sender = _sender(client, proto) macro: Macro|None = None aborted = False await sender.asend(None) # type: ignore # initialise async with ( anyio.create_task_group() as tasks, aclosing(sender), aclosing(client), _TaskRunner(tasks) as runner, ): while 1: try: buff[:] = await client.receive(buff.available) except ( anyio.EndOfStream, anyio.ClosedResourceError, anyio.BrokenResourceError, ): await runner.aclose() return for message in proto.read_from(buff): if __debug__: _logger.debug(f"received: {message}") match message: case Negotiate(): await sender.asend(await self._negotiate(message)) case Macro() as macro: # Note that this Macro will hang around as "macro"; this is for # Connect messages. await runner.set_macros(macro) case Connect(): await self._prepare_filters(message, sender, runner) if macro: await runner.set_macros(macro) needs_response = proto.needs_response(message) match await runner.start(needs_response, True, self.use_skip): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case c_resp if needs_response: assert c_resp is not None and not isinstance(c_resp, _CloseFilter) await sender.asend(c_resp) case c_resp: raise RuntimeError(f"unexpected response: {c_resp}") case Abort(): aborted = True await runner.abort(message) case Close(): await runner.aclose() return case _: if aborted: aborted = False await runner.start(True, False, self.use_skip) needs_response = proto.needs_response(message) match await runner.message_events(message, needs_response): case None: assert not needs_response case _CloseFilter() as notif: self.filters.remove(notif.filter) case resp if needs_response: assert resp is not None and not isinstance(resp, _CloseFilter) await sender.asend(resp) case resp: raise RuntimeError(f"unexpected response: {resp}") async def _negotiate(self, message: Negotiate) -> Negotiate: _logger.info("Negotiating with MTA") optmask = ProtocolFlags.NONE options = \ ProtocolFlags.SKIP | \ ProtocolFlags.NO_HELO | \ ProtocolFlags.NO_SENDER | ProtocolFlags.NO_RECIPIENT | \ ProtocolFlags.NO_DATA | ProtocolFlags.NO_BODY | \ ProtocolFlags.NO_HEADERS | ProtocolFlags.NO_END_OF_HEADERS | \ ProtocolFlags.NR_CONNECT | ProtocolFlags.NR_HELO | \ ProtocolFlags.NR_SENDER | ProtocolFlags.NR_RECIPIENT | \ ProtocolFlags.NR_DATA | ProtocolFlags.NR_BODY | \ ProtocolFlags.NR_HEADER | ProtocolFlags.NR_END_OF_HEADERS actions = ActionFlags.NONE macros = defaultdict(set) options &= message.protocol_flags # Remove unoffered initial flags, they are not required for filtr in self.filters: flags = get_flags(filtr) optmask |= flags.unset_options options |= flags.set_options actions |= flags.set_actions for stage, names in get_macros(filtr).items(): macros[stage].update(names) options &= ~optmask if (missing_actions := actions & ~message.action_flags): raise NegotiationError(f"MTA does not accept {missing_actions}") if (missing_options := options & ~message.protocol_flags): raise NegotiationError(f"MTA does not offer {missing_options}") self.use_skip = ProtocolFlags.SKIP in options return Negotiate(6, actions, options, dict(macros)) async def _prepare_filters( self, message: Connect, sender: Sender, runner: _TaskRunner, ) -> None: _logger.info(f"Client connected from {message.hostname}") for fltr in self.filters: session = Session(message, sender, _Broadcast()) runner.add_filter(fltr, session)
class _TaskRunner: def __init__(self, tasks: anyio.abc.TaskGroup): self.tasks = tasks self.filters = list[tuple[Filter, Session]]() self.channels = dict[MessageChannel, Filter]() async def __aenter__(self) -> Self: return self async def __aexit__(self, *_: object) -> None: await self.aclose() def add_filter(self, flter: Filter, session: Session, /) -> None: self.filters.append((flter, session)) async def start( self, needs_response: bool, first_connect: bool, use_skip: bool, ) -> ResponseMessage|_CloseFilter|None: if self.channels: raise RuntimeError(f"{self} is already running tasks") final: ResponseMessage = Accept() for flter, session in self.filters: lchannel, rchannel = _make_message_channel() self.channels[lchannel] = flter match await self.tasks.start(self._runner, flter, session, rchannel, use_skip): case Accept(): del self.channels[lchannel] case Continue(): continue case TemporaryFailure() as final: # replaces final pass case Reject()|Discard()|ReplyCode() as resp: if not first_connect: _logger.warning( f"Ignoring unexpected response from filter after restart: " f"{qualname(flter)} -> {resp}", ) continue if not needs_response: _logger.warning( f"Unexpected response from filter {qualname(flter)}", ) return _CloseFilter(flter) return resp case _ as arg: # pragma: no-cover raise TypeError(f"task_status.started called with bad type: {arg!r}") if not needs_response: return None return final if len(self.channels) == 0 else Continue() async def set_macros(self, message: Macro) -> None: if self.channels: for channel in self.channels: await channel.send(message) else: for _, session in self.filters: await session.deliver(message) async def message_events( self, message: _VALID_EVENT_MESSAGE, needs_response: bool, ) -> ResponseMessage|Skip|_CloseFilter|None: skip = isinstance(message, Body) for channel in list(self.channels): await channel.send(message) match (await channel.receive()): case Skip(): continue case Continue(): skip = False case Accept() as resp: flter = await self.close_channel(channel) if len(self.channels) == 0: _logger.info(f"Returning response Accept from {qualname(flter)}") return resp _logger.info(f"Holding response Accept from {qualname(flter)}") case (Reject() | Discard() | TemporaryFailure() | ReplyCode()) as resp: flter = await self.close_channel(channel) if not needs_response: _logger.warning(f"Unexpected response from filter {qualname(flter)}") return _CloseFilter(flter) _logger.info(f"Returning response {type(resp).__name__} from {qualname(flter)}") return resp assert len(self.channels) > 0, "Running filters reached zero without a response?!" if not needs_response: return None return Skip() if skip else Continue() async def close_channel(self, channel: MessageChannel) -> Filter: await channel.aclose() return self.channels.pop(channel) async def abort(self, abort: Abort) -> None: if not self.channels: return _logger.info("Aborting filters") for channel in self.channels: await channel.send(abort) await channel.receive() await channel.aclose() self.channels.clear() async def aclose(self) -> None: if self.channels: _logger.info("Closing filters") self.tasks.cancel_scope.cancel() self.channels.clear() @staticmethod async def _runner( fltr: Filter, session: Session, channel: MessageChannel, use_skip: bool, *, task_status: anyio.abc.TaskStatus[ResponseMessage], ) -> None: final_resp: ResponseMessage|None = None async def _filter_wrap( task_status: anyio.abc.TaskStatus[None], ) -> None: nonlocal final_resp async with session: assert isinstance(session.broadcast, _Broadcast) session.broadcast.task_status = task_status try: final_resp = await fltr(session) except Aborted: _logger.debug(f"Aborted filter {qualname(fltr)}") return except Exception: _logger.exception(f"Error in filter {qualname(fltr)}") final_resp = TemporaryFailure() if not isinstance(final_resp, _VALID_FINAL_RESPONSES): warn(f"expected a valid response from {qualname(fltr)}, got {final_resp}") final_resp = TemporaryFailure() async with anyio.create_task_group() as tasks: await tasks.start(_filter_wrap) task_status.started(final_resp or Continue()) while final_resp is None: try: message = await channel.receive() except (anyio.EndOfStream, anyio.ClosedResourceError): tasks.cancel_scope.cancel() return if isinstance(message, Macro): await session.deliver(message) continue # TODO: Upgrade and remove ignores once python/mypy#14242 is in assert isinstance(message, _VALID_EVENT_MESSAGE) # type: ignore[misc,arg-type] resp = await session.deliver(message) # type: ignore[arg-type] if isinstance(message, Abort): await channel.send(Continue()) await channel.aclose() return if final_resp is not None: break # type: ignore[unreachable] await channel.send(Skip() if use_skip and resp == Skip else Continue()) await channel.send(final_resp) def _make_message_channel() -> tuple[MessageChannel, MessageChannel]: lsend, rrecv = anyio.create_memory_object_stream(1, Message) # type: ignore rsend, lrecv = anyio.create_memory_object_stream(1, Message) # type: ignore return StapledObjectStream(lsend, lrecv), StapledObjectStream(rsend, rrecv) async def _sender(client: anyio.abc.ByteSendStream, proto: FilterProtocol) -> Sender: buff = SimpleBuffer(1*kiB) while 1: proto.write_to(buff, (message := (yield))) if __debug__: _logger.debug(f"sent: {message}") await client.send(buff[:]) del buff[:]