| | import asyncio |
| | import asyncio.streams |
| | import sys |
| | import traceback |
| | import warnings |
| | from collections import deque |
| | from contextlib import suppress |
| | from html import escape as html_escape |
| | from http import HTTPStatus |
| | from logging import Logger |
| | from typing import ( |
| | TYPE_CHECKING, |
| | Any, |
| | Awaitable, |
| | Callable, |
| | Deque, |
| | Optional, |
| | Sequence, |
| | Tuple, |
| | Type, |
| | Union, |
| | cast, |
| | ) |
| |
|
| | import attr |
| | import yarl |
| |
|
| | from .abc import AbstractAccessLogger, AbstractStreamWriter |
| | from .base_protocol import BaseProtocol |
| | from .helpers import ceil_timeout |
| | from .http import ( |
| | HttpProcessingError, |
| | HttpRequestParser, |
| | HttpVersion10, |
| | RawRequestMessage, |
| | StreamWriter, |
| | ) |
| | from .http_exceptions import BadHttpMethod |
| | from .log import access_logger, server_logger |
| | from .streams import EMPTY_PAYLOAD, StreamReader |
| | from .tcp_helpers import tcp_keepalive |
| | from .web_exceptions import HTTPException, HTTPInternalServerError |
| | from .web_log import AccessLogger |
| | from .web_request import BaseRequest |
| | from .web_response import Response, StreamResponse |
| |
|
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")


if TYPE_CHECKING:
    # Imported for annotations only (presumably to avoid an import cycle
    # with the server module -- confirm).
    from .web_server import Server


# Coroutine function that maps a request to a response.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]

# Builds a request object; RequestHandler.start() calls it with the parsed
# message, its payload stream, the protocol, a writer and the handler task.
_RequestFactory = Callable[
    [
        RawRequestMessage,  # parsed request line + headers
        StreamReader,  # request payload
        "RequestHandler",  # the protocol instance
        AbstractStreamWriter,  # response writer
        "asyncio.Task[None]",  # task running RequestHandler.start()
    ],
    BaseRequest,
]


# Placeholder message substituted for a request that failed to parse, so the
# request factory still has something to build a request object from.
ERROR = RawRequestMessage(
    "UNKNOWN", "/", HttpVersion10,
    {}, {},  # type: ignore[arg-type]
    True, None, False, False,
    yarl.URL("/"),
)
| |
|
| |
|
class RequestPayloadError(Exception):
    """Error raised when an incoming request body fails to parse.

    RequestHandler hands this class to its request parser as the
    ``payload_exception`` type.
    """
| |
|
| |
|
class PayloadAccessError(Exception):
    """Error for reading a request body after its response has been sent."""


# One shared instance: RequestHandler.start() installs it on every finished
# request's payload via ``payload.set_exception(...)``.
_PAYLOAD_ACCESS_ERROR = PayloadAccessError()
| |
|
| |
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ErrInfo:
    """Parse failure queued in place of a request message.

    data_received() creates one when the request parser raises
    HttpProcessingError; start() turns it into an error response.
    """

    status: int  # HTTP status code to respond with
    exc: BaseException  # exception raised by the parser
    message: str  # error message passed on to handle_error()


# Queue entry in RequestHandler._messages: a parsed request (or a parse
# failure) paired with its payload stream.
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]
| |
|
| |
|
class RequestHandler(BaseProtocol):
    """HTTP server protocol: one instance per client connection.

    RequestHandler parses incoming bytes into HTTP requests (request line,
    request headers and request payload), builds a request object with the
    manager's request factory and dispatches it to the manager's request
    handler, one request at a time. Pipelined requests are queued until the
    previous response has been sent.

    Malformed requests (bad request line, bad headers, ...) are answered
    with an error response built by handle_error(); such responses force
    the connection to close.

    keepalive_timeout -- number of seconds an idle keep-alive
                         connection is kept open before it is closed

    tcp_keepalive -- enable TCP keep-alive on the socket, default is on

    debug -- enable debug mode (debug logging, tracebacks in 500 responses)

    logger -- custom logger object

    access_log_class -- custom class for access_logger

    access_log -- custom logging object; access logging is disabled
                  when this is falsy

    access_log_format -- access log format string

    loop -- event loop the protocol runs on (required)

    max_line_size -- Optional maximum header line size

    max_field_size -- Optional maximum header field size

    max_headers -- Optional maximum header size

    lingering_time -- maximum number of seconds spent reading and
                      discarding an unread request payload after the
                      response has been sent

    read_bufsize -- buffer size passed to the request parser

    auto_decompress -- passed to the request parser (presumably enables
                       automatic decompression of request bodies)

    timeout_ceil_threshold -- Optional value to specify
                              threshold to ceil() timeout
                              values
    """

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_next_keepalive_close_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_handler_waiter",
        "_waiter",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "_reading_paused",
        "logger",
        "debug",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
        "_timeout_ceil_threshold",
        "_request_in_progress",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        # NOTE(review): ~1 hour; presumably chosen to outlast typical
        # reverse-proxy idle timeouts -- confirm.
        keepalive_timeout: float = 3630,
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        debug: bool = False,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2**16,
        auto_decompress: bool = True,
        timeout_ceil_threshold: float = 5,
    ):
        super().__init__(loop)

        # Number of request messages parsed on this connection
        # (incremented in data_received()).
        self._request_count = 0
        self._keepalive = False
        self._current_request: Optional[BaseRequest] = None
        self._manager: Optional[Server] = manager
        self._request_handler: Optional[_RequestHandler] = manager.request_handler
        self._request_factory: Optional[_RequestFactory] = manager.request_factory

        self._tcp_keepalive = tcp_keepalive
        # Deadline (loop time) for closing an idle keep-alive connection;
        # written by start(), checked by _process_keepalive().
        self._next_keepalive_close_time = 0.0
        self._keepalive_handle: Optional[asyncio.Handle] = None
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Parsed-but-not-yet-handled requests (pipelining queue).
        self._messages: Deque[_MsgType] = deque()
        # Bytes received after an upgrade, held until a payload parser is set.
        self._message_tail = b""

        # Resolved by data_received() when a new message arrives while
        # start() is idle.
        self._waiter: Optional[asyncio.Future[None]] = None
        # Created only by shutdown() while a request is in progress;
        # resolved by _handle_request() when that request finishes.
        self._handler_waiter: Optional[asyncio.Future[None]] = None
        self._task_handler: Optional[asyncio.Task[None]] = None

        self._upgrade = False
        self._payload_parser: Any = None
        self._request_parser: Optional[HttpRequestParser] = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
            auto_decompress=auto_decompress,
        )

        # Fall back to 5 when the supplied value is not convertible to float.
        self._timeout_ceil_threshold: float = 5
        try:
            self._timeout_ceil_threshold = float(timeout_ceil_threshold)
        except (TypeError, ValueError):
            pass

        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        if access_log:
            self.access_logger: Optional[AbstractAccessLogger] = access_log_class(
                access_log, access_log_format
            )
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False
        self._request_in_progress = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        """Seconds an idle keep-alive connection is kept open."""
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Do worker process exit preparations.

        We need to clean up everything and stop accepting requests.
        It is especially important for keep-alive connections.

        First waits (up to ``timeout``) for an in-progress request to
        finish, then cancels the current request and waits (up to
        ``timeout`` again) for the handler task, and finally cancels the
        handler task and force-closes the transport. Timeouts are swallowed;
        on Python 3.11+ a cancellation of the calling task is re-raised.
        """
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        # Give an in-progress request a chance to complete gracefully.
        if self._request_in_progress:
            # The future is only created when shutting down while a request
            # is being handled, so idle requests never pay for it.
            self._handler_waiter = self._loop.create_future()
            try:
                async with ceil_timeout(timeout):
                    await self._handler_waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                self._handler_waiter = None
                if (
                    sys.version_info >= (3, 11)
                    and (task := asyncio.current_task())
                    and task.cancelling()
                ):
                    raise
        # Then cancel the current request and wait for the handler task.
        try:
            async with ceil_timeout(timeout):
                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    # shield(): a timeout here must not cancel the handler
                    # task itself; it is cancelled explicitly below.
                    await asyncio.shield(self._task_handler)
        except (asyncio.CancelledError, asyncio.TimeoutError):
            if (
                sys.version_info >= (3, 11)
                and (task := asyncio.current_task())
                and task.cancelling()
            ):
                raise

        # Force-cancel a handler task that is still running.
        if self._task_handler is not None:
            self._task_handler.cancel()

        self.force_close()

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Register with the manager and start the request-processing task."""
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

        loop = self._loop
        if sys.version_info >= (3, 12):
            # Eager start: start() runs synchronously until it first awaits.
            task = asyncio.Task(self.start(), loop=loop, eager_start=True)
        else:
            task = loop.create_task(self.start())
        self._task_handler = task

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Tear down per-connection state after the transport closed.

        Idempotent: a second call is a no-op because ``_manager`` is
        cleared on the first one.
        """
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        # Read before _manager is cleared below.
        handler_cancellation = self._manager.handler_cancellation

        self.force_close()
        super().connection_lost(exc)
        self._manager = None
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        # Only cancel the running handler when the manager opted in.
        if handler_cancellation and self._task_handler is not None:
            self._task_handler.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        """Install a payload parser for an upgraded connection.

        ``parser`` must provide ``feed_data(data) -> (eof, tail)`` and
        ``feed_eof()`` (as used by data_received() and connection_lost()).
        Any bytes buffered since the upgrade are fed to it immediately.
        """
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass

    def data_received(self, data: bytes) -> None:
        """Feed received bytes to the HTTP parser or the payload parser."""
        if self._force_close or self._close:
            return
        # Parse HTTP messages.
        messages: Sequence[_MsgType]
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # Queue the failure so start() answers it with an error
                # response instead of a handler call.
                messages = [
                    (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
                ]
                upgraded = False
                tail = b""

            for msg, payload in messages or ():
                self._request_count += 1
                self._messages.append((msg, payload))

            waiter = self._waiter
            if messages and waiter is not None and not waiter.done():
                # Wake start(); never set the result twice.
                waiter.set_result(None)

            self._upgrade = upgraded
            if upgraded and tail:
                self._message_tail = tail

        # Upgraded but no payload parser yet: buffer the bytes.
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # Feed the installed payload parser.
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        Any pending keep-alive timer is cancelled.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Close connection.

        Stop accepting new pipelining messages and close
        connection when handlers done processing messages.
        """
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def log_access(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> None:
        """Log ``request``/``response``; ``time`` is the loop time it started."""
        if self.access_logger is not None and self.access_logger.enabled:
            self.access_logger.log(request, response, self._loop.time() - time)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        """Log at debug level, only when debug mode is on."""
        if self.debug:
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        """Keep-alive timer callback: close the connection if it is idle."""
        self._keepalive_handle = None
        if self._force_close or not self._keepalive:
            return

        loop = self._loop
        now = loop.time()
        close_time = self._next_keepalive_close_time
        if now < close_time:
            # The deadline was pushed back by a later request; reschedule.
            self._keepalive_handle = loop.call_at(close_time, self._process_keepalive)
            return

        # Close only if start() is idle, waiting for the next request.
        if self._waiter and not self._waiter.done():
            self.force_close()

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
        request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
    ) -> Tuple[StreamResponse, bool]:
        """Run ``request_handler`` for ``request`` and send the response.

        A raised HTTPException is sent as the response. TimeoutError becomes
        a 504 response and any other Exception a 500 response.
        CancelledError propagates.

        Returns ``(response, reset)``, where ``reset`` is True when the
        client disconnected while the response was being written.
        """
        self._request_in_progress = True
        try:
            try:
                self._current_request = request
                resp = await request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = exc
            resp, reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            resp, reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            resp, reset = await self.finish_response(request, resp, start_time)
        else:
            # Deprecation warning (see #2415).
            if getattr(resp, "__http_exception__", False):
                warnings.warn(
                    "returning HTTPException object is deprecated "
                    "(#2415) and will be removed, "
                    "please raise the exception instead",
                    DeprecationWarning,
                )

            resp, reset = await self.finish_response(request, resp, start_time)
        finally:
            self._request_in_progress = False
            # Wake shutdown() if it is waiting for this request. The waiter
            # may already be cancelled (shutdown() timed out or was cancelled
            # and has not yet cleared it); set_result() on a done future
            # raises InvalidStateError, so only resolve a pending waiter.
            waiter = self._handler_waiter
            if waiter is not None and not waiter.done():
                waiter.set_result(None)

        return resp, reset

    async def start(self) -> None:
        """Process incoming requests until the connection is closed.

        Waits for messages queued by data_received(), builds a request for
        each one and runs it through _handle_request() in a new task. It
        also drains unread request payloads (bounded by lingering_time) and
        manages the keep-alive timer.

        The connection is closed after a response unless the response
        allows keep-alive and the connection was not marked for closing.
        """
        loop = self._loop
        handler = asyncio.current_task(loop)
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # Wait for the next request (resolved by data_received()).
                    self._waiter = loop.create_future()
                    await self._waiter
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            if isinstance(message, _ErrInfo):
                # Parse failure: respond via an error handler and give the
                # request factory the placeholder message.
                request_handler = self._make_error_handler(message)
                message = ERROR
            else:
                request_handler = self._request_handler

            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # A separate task gives each request its own copy of the
                # context variables.
                coro = self._handle_request(request, start, request_handler)
                if sys.version_info >= (3, 12):
                    task = asyncio.Task(coro, loop=loop, eager_start=True)
                else:
                    task = loop.create_task(coro)
                try:
                    resp, reset = await task
                except ConnectionError:
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Drop the reference to the finished task early.
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # Keep the connection open only if the response allows it.
                self._keepalive = bool(resp.keep_alive)

                # Drain a request payload the handler did not fully read.
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        try:
                            while not payload.is_eof() and now < end_t:
                                async with ceil_timeout(end_t - now):
                                    # Read and discard.
                                    await payload.readany()
                                now = loop.time()
                        except (asyncio.CancelledError, asyncio.TimeoutError):
                            if (
                                sys.version_info >= (3, 11)
                                and (t := asyncio.current_task())
                                and t.cancelling()
                            ):
                                raise

                    # Payload still incomplete: stop reusing the connection.
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                # Later reads of this payload raise PayloadAccessError.
                payload.set_exception(_PAYLOAD_ACCESS_ERROR)

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection")
                raise
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # (Re)arm the keep-alive timer; an already scheduled
                        # handle reschedules itself from the new deadline.
                        if keepalive_timeout is not None:
                            now = loop.time()
                            close_time = now + keepalive_timeout
                            self._next_keepalive_close_time = close_time
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    close_time, self._process_keepalive
                                )
                    else:
                        break

        # Handler finished: close the transport unless force_close() already did.
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> Tuple[StreamResponse, bool]:
        """Prepare the response and write_eof, then log access.

        This has to be called within the context of any exception so the
        access logger can get exception information.

        If ``resp`` has no ``prepare`` method (e.g. the handler returned
        None), an error is logged and a 500 response is sent instead.

        Returns ``(response, reset)``: the response actually sent and True
        if the client disconnected prematurely (ConnectionError while
        writing), otherwise False.
        """
        request._finish()
        if self._request_parser is not None:
            # Leave upgrade mode and re-feed any bytes buffered meanwhile.
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            if resp is None:
                self.log_exception("Missing return statement on request handler")
            else:
                self.log_exception(
                    "Web-handler should return a response instance, "
                    "got {!r}".format(resp)
                )
            exc = HTTPInternalServerError()
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            prepare_meth = resp.prepare
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            self.log_access(request, resp, start_time)
            return resp, True

        self.log_access(request, resp, start_time)
        return resp, False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection.

        For status 500 the body is generated here (plain text or HTML,
        depending on the Accept header, with a traceback in debug mode) and
        ``message`` is ignored; otherwise ``message`` is used as the body.

        Raises ConnectionError if part of a response was already written.
        """
        if self._request_count == 1 and isinstance(exc, BadHttpMethod):
            # A bad method on the very first request is logged at debug
            # level only; presumably this is typically non-HTTP (e.g. TLS)
            # traffic hitting a plain HTTP port and would otherwise be noise.
            self.logger.debug("Error handling request", exc_info=exc)
        else:
            self.log_exception("Error handling request", exc_info=exc)

        # Some data already got sent; a second response cannot follow.
        if request.writer.output_size > 0:
            raise ConnectionError(
                "Response is sent already, cannot send another response "
                "with the error message"
            )

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self.debug:
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    # Escape before embedding the traceback in HTML.
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        return resp

    def _make_error_handler(
        self, err_info: _ErrInfo
    ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
        """Return a request handler that answers with ``err_info``'s error."""

        async def handler(request: BaseRequest) -> StreamResponse:
            return self.handle_error(
                request, err_info.status, err_info.exc, err_info.message
            )

        return handler
| |
|