Skip to content

rs_dpr_service/utils/middlewares.md

<< Back to index

Common functions for fastapi middlewares.

NOTE: COPY-PASTED FROM RS-SERVER.

HandleExceptionsMiddleware

Bases: BaseHTTPMiddleware

Middleware that catches all exceptions and returns a JSONResponse instead of raising them. This is useful in FastAPI when HTTPExceptions are raised within the code but need to be handled gracefully.

Attributes:

Name Type Description
rfc7807 bool

If True, the returned content is compliant with RFC 7807 (used by pygeoapi/OGC services). False by default, in which case the returned content is compliant with the STAC specification.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class HandleExceptionsMiddleware(BaseHTTPMiddleware):
    """
    Middleware to catch all exceptions and return a JSONResponse instead of raising them.
    This is useful in FastAPI when HttpExceptions are raised within the code but need to be handled gracefully.

    Error responses (status code 4xx/5xx) returned by the application are reformatted too.

    Attributes:
        rfc7807 (bool): If true, the returned content is compliant with RFC 7807. This is used by pygeoapi/ogc services.
        False by default = compliant to Stac specifications.
    """

    def __init__(self, app, rfc7807: bool = False, dispatch=None):
        """Constructor

        Args:
            app: application wrapped by this middleware, forwarded to BaseHTTPMiddleware.
            rfc7807: if True, format error responses per RFC 7807, else per the STAC specification.
            dispatch: optional dispatch callable, forwarded to BaseHTTPMiddleware.
        """
        self.rfc7807: bool = rfc7807
        super().__init__(app, dispatch)

    @staticmethod
    def disable_default_exception_handler(app: FastAPI):
        """
        Disable the default FastAPI exception handler for HTTPException and StarletteHTTPException.
        We just re-raise the exceptions so they'll be handled by HandleExceptionsMiddleware.
        """

        @app.exception_handler(HTTPException)
        @app.exception_handler(StarletteHTTPException)
        async def exception_handler(_request: Request, _exc: HTTPException):
            """Implement disable_default_exception_handler"""
            # Note: we could raise(_exc) but it would increase the stack trace length with this module info.
            # We can just call raise because this function is called from an except clause.
            raise  # pylint: disable=misplaced-bare-raise

    async def dispatch(self, request: Request, call_next: Callable):
        """
        Call the next middleware/endpoint, then reformat its response if it is an error.
        Any exception raised meanwhile is converted into a JSON error response.
        """
        try:
            # Call next middleware, get and return response, handle errors
            response = await call_next(request)
            return await self.handle_errors(response)

        except Exception as exc:  # pylint: disable=broad-exception-caught
            return await self.handle_exceptions(request, exc)

    @staticmethod
    def format_code(status_code: int) -> str:
        """Convert e.g. HTTP_500_INTERNAL_SERVER_ERROR (500) into 'InternalServerError'.

        Status codes unknown to http.HTTPStatus (e.g. 499 or 520, sent by some proxies) are
        converted into 'Error<code>' (e.g. 'Error499') instead of raising ValueError.
        """
        try:
            phrase = HTTPStatus(status_code).phrase
        except ValueError:
            # Raising here would turn e.g. a 499 response into a 500, or escape from handle_exceptions.
            return f"Error{status_code}"
        return "".join(word.title() for word in phrase.split())

    @staticmethod
    def rfc7807_response(status_code: int, detail: str) -> Rfc7807ErrorResponse:
        """Return Rfc7807ErrorResponse instance"""
        return Rfc7807ErrorResponse(
            type=f"https://developer.mozilla.org/en/docs/Web/HTTP/Reference/Status/{status_code}",
            status=status_code,
            detail=detail,
        )

    async def handle_errors(self, response: StreamingResponse) -> Response:
        """
        If no errors, just return the original response.
        In case of errors, log, format and return the response contents.

        Args:
            response: response returned by the next middleware/endpoint.

        Returns:
            The original response if its status code is not 4xx/5xx or if its content cannot be read,
            else a JSONResponse with the same status code and a Rfc7807ErrorResponse or StacErrorResponse content.
        """
        if not 400 <= response.status_code < 600:
            return response  # no error, return the original response

        # Read response content
        try:
            content = await read_streaming_response(response)

        # If we fail to read content (e.g. not JSON), just return the original response
        except Exception as exc:  # pylint: disable=broad-exception-caught
            logger.error(exc)
            return response

        # Keep the content as is if it is already formatted as the expected XxxErrorResponse
        formatted: Rfc7807ErrorResponse | StacErrorResponse | None = None
        try:
            if self.rfc7807:
                formatted = Rfc7807ErrorResponse(
                    type=str(content["type"]),
                    status=int(content["status"]),
                    detail=str(content["detail"]),
                )
            else:
                formatted = StacErrorResponse(code=str(content["code"]), description=str(content["description"]))
            if formatted != content:
                formatted = None
        except Exception:  # pylint: disable=broad-exception-caught # nosec B110
            # Best effort: content is not a dict with the expected keys/values, it is wrapped below.
            formatted = None

        # Else wrap the content into the expected format
        if formatted is None:
            # content comes from json.loads, so it is never a set (which json.dumps could not serialize)
            description = json.dumps(content) if isinstance(content, (dict, list)) else str(content)
            if self.rfc7807:
                formatted = self.rfc7807_response(response.status_code, detail=description)
            else:
                formatted = StacErrorResponse(code=self.format_code(response.status_code), description=description)

        logger.error("%s: %s", response.status_code, json.dumps(formatted))
        return JSONResponse(status_code=response.status_code, content=formatted)

    async def handle_exceptions(self, request: Request, exc: Exception) -> JSONResponse:
        """In case of exceptions, log the exception with its stack trace and return a JSON error response."""

        # Log current stack trace
        logger.exception(exc)

        # Calculate HTTP response status code (int) and StacErrorResponse code (str) and description (str)
        if isinstance(exc, StarletteHTTPException):
            status_code = exc.status_code
            # Format int status code into str
            str_code = self.format_code(exc.status_code)
            description = str(exc.detail)

        else:
            # Use generic 400 or 500 code.
            # Call through the class name so that an override assigned to the class attribute is used.
            status_code = (
                status.HTTP_400_BAD_REQUEST
                if HandleExceptionsMiddleware.is_bad_request(request, exc)
                else status.HTTP_500_INTERNAL_SERVER_ERROR
            )
            str_code = exc.__class__.__name__
            description = str(exc)

        error_response: Rfc7807ErrorResponse | StacErrorResponse
        if self.rfc7807:
            error_response = self.rfc7807_response(status_code, detail=description)
        else:
            error_response = StacErrorResponse(code=str_code, description=description)
        return JSONResponse(status_code=status_code, content=error_response)

    @staticmethod
    def is_bad_request(request: Request, e: Exception) -> bool:
        """
        Determines if the request that raised this exception shall be considered as a bad request
        and return a 400 error code.

        This function can be overridden by the caller if needed with:
        HandleExceptionsMiddleware.is_bad_request = my_callable
        """
        return "bbox" in request.query_params and (
            str(e).endswith(" must have 4 or 6 values.") or str(e).startswith("could not convert string to float: ")
        )

__init__(app, rfc7807=False, dispatch=None)

Constructor

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
94
95
96
97
def __init__(self, app, rfc7807: bool = False, dispatch=None):
    """Constructor

    Args:
        app: application wrapped by this middleware, forwarded to BaseHTTPMiddleware.
        rfc7807: when True, error payloads follow RFC 7807; otherwise they follow the STAC error format.
        dispatch: optional dispatch callable, forwarded to BaseHTTPMiddleware.
    """
    # Store the output format flag before delegating the rest of the setup to the base class
    self.rfc7807: bool = rfc7807
    super().__init__(app, dispatch=dispatch)

disable_default_exception_handler(app) staticmethod

Disable the default FastAPI exception handler for HTTPException and StarletteHTTPException. We just re-raise the exceptions so they'll be handled by HandleExceptionsMiddleware.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
@staticmethod
def disable_default_exception_handler(app: FastAPI):
    """
    Disable the default FastAPI exception handler for HTTPException and StarletteHTTPException.
    We just re-raise the exceptions so they'll be handled by HandleExceptionsMiddleware.

    Args:
        app: the FastAPI application on which the re-raising handler is registered.
    """

    # Register one single handler for both HTTPException and StarletteHTTPException.
    @app.exception_handler(HTTPException)
    @app.exception_handler(StarletteHTTPException)
    async def exception_handler(_request: Request, _exc: HTTPException):
        """Implement disable_default_exception_handler"""
        # Note: we could raise(_exc) but it would increase the stack trace length with this module info.
        # We can just call raise because this function is called from an except clause.
        raise  # pylint: disable=misplaced-bare-raise

format_code(status_code) staticmethod

Convert an HTTP status code into a PascalCase code, e.g. 500 (HTTP_500_INTERNAL_SERVER_ERROR) into 'InternalServerError'.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
123
124
125
126
127
@staticmethod
def format_code(status_code: int) -> str:
    """Convert e.g. HTTP_500_INTERNAL_SERVER_ERROR (500) into 'InternalServerError'.

    Status codes unknown to http.HTTPStatus (e.g. 499 or 520, sent by some proxies) are
    converted into 'Error<code>' (e.g. 'Error499') instead of raising ValueError.
    """
    try:
        phrase = HTTPStatus(status_code).phrase
    except ValueError:
        # Raising here would turn e.g. a 499 response into a 500, or escape from handle_exceptions.
        return f"Error{status_code}"
    return "".join(word.title() for word in phrase.split())

handle_errors(response) async

If no errors, just return the original response. In case of errors, log, format and return the response contents.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
async def handle_errors(self, response: StreamingResponse) -> Response:
    """
    If no errors, just return the original response.
    In case of errors, log, format and return the response contents.

    Args:
        response: response returned by the next middleware/endpoint.

    Returns:
        The original response if its status code is not 4xx/5xx or if its content cannot be read,
        else a JSONResponse with the same status code and a Rfc7807ErrorResponse or StacErrorResponse content.
    """
    if not 400 <= response.status_code < 600:
        return response  # no error, return the original response

    # Read response content
    try:
        content = await read_streaming_response(response)

    # If we fail to read content (e.g. not JSON), just return the original response
    except Exception as exc:  # pylint: disable=broad-exception-caught
        logger.error(exc)
        return response

    # Keep the content as is if it is already formatted as the expected XxxErrorResponse
    formatted: Rfc7807ErrorResponse | StacErrorResponse | None = None
    try:
        if self.rfc7807:
            formatted = Rfc7807ErrorResponse(
                type=str(content["type"]),
                status=int(content["status"]),
                detail=str(content["detail"]),
            )
        else:
            formatted = StacErrorResponse(code=str(content["code"]), description=str(content["description"]))
        if formatted != content:
            formatted = None
    except Exception:  # pylint: disable=broad-exception-caught # nosec B110
        # Best effort: content is not a dict with the expected keys/values, it is wrapped below.
        formatted = None

    # Else wrap the content into the expected format
    if formatted is None:
        # content comes from json.loads, so it is never a set (which json.dumps could not serialize)
        description = json.dumps(content) if isinstance(content, (dict, list)) else str(content)
        if self.rfc7807:
            formatted = self.rfc7807_response(response.status_code, detail=description)
        else:
            formatted = StacErrorResponse(code=self.format_code(response.status_code), description=description)

    logger.error("%s: %s", response.status_code, json.dumps(formatted))
    return JSONResponse(status_code=response.status_code, content=formatted)

handle_exceptions(request, exc) async

In case of exceptions, log the exception with its stack trace and return a JSON error response.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
async def handle_exceptions(self, request: Request, exc: Exception) -> JSONResponse:
    """
    Log the given exception with its stack trace and convert it into a JSON error response.

    A StarletteHTTPException keeps its own status code and detail. Any other exception is mapped
    to 400 when HandleExceptionsMiddleware.is_bad_request says so (500 otherwise), and its class
    name is used as the STAC error code.
    """
    logger.exception(exc)

    if not isinstance(exc, StarletteHTTPException):
        # Looked up on the class so that a caller-assigned override is honored
        is_client_error = HandleExceptionsMiddleware.is_bad_request(request, exc)
        http_code = status.HTTP_400_BAD_REQUEST if is_client_error else status.HTTP_500_INTERNAL_SERVER_ERROR
        error_code = exc.__class__.__name__
        message = str(exc)
    else:
        http_code = exc.status_code
        error_code = self.format_code(exc.status_code)
        message = str(exc.detail)

    payload: Rfc7807ErrorResponse | StacErrorResponse = (
        self.rfc7807_response(http_code, detail=message)
        if self.rfc7807
        else StacErrorResponse(code=error_code, description=message)
    )
    return JSONResponse(status_code=http_code, content=payload)

is_bad_request(request, e) staticmethod

Determines if the request that raised this exception shall be considered as a bad request and return a 400 error code.

This function can be overridden by the caller if needed with: HandleExceptionsMiddleware.is_bad_request = my_callable

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
212
213
214
215
216
217
218
219
220
221
222
223
@staticmethod
def is_bad_request(request: Request, e: Exception) -> bool:
    """
    Tell whether the exception raised by this request should be reported as a 400 Bad Request.

    Only requests carrying a "bbox" query parameter qualify, and only when the exception message
    reports a malformed bbox (wrong number of values, or a value that is not a float).

    Callers may replace this policy with:
    HandleExceptionsMiddleware.is_bad_request = my_callable
    """
    if "bbox" not in request.query_params:
        return False
    message = str(e)
    return message.endswith(" must have 4 or 6 values.") or message.startswith("could not convert string to float: ")

rfc7807_response(status_code, detail) staticmethod

Return Rfc7807ErrorResponse instance

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
129
130
131
132
133
134
135
136
@staticmethod
def rfc7807_response(status_code: int, detail: str) -> Rfc7807ErrorResponse:
    """Build an RFC 7807 error payload whose `type` points to the MDN page of `status_code`."""
    mdn_page = f"https://developer.mozilla.org/en/docs/Web/HTTP/Reference/Status/{status_code}"
    return Rfc7807ErrorResponse(type=mdn_page, status=status_code, detail=detail)

HealthMiddleware

Bases: BaseHTTPMiddleware

When Kubernetes calls the /health or /ping endpoint of this service, return the response immediately, because if the latency is too high (>2s) Kubernetes will kill and restart the pod.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class HealthMiddleware(BaseHTTPMiddleware):
    """
    Answer the /health and /ping probes (and their /_mgmt variants) directly from this middleware,
    without reaching the application: if the latency is too high (>2s), Kubernetes kills and
    restarts the pod.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        """Middleware implementation"""
        path = request.url.path

        # NOTE: for the catalog we could call "await self.api.health_check(request)" like in stac_fastapi.api.app
        # but this async call may be slow and so may kill the pod. So we hardcode the response instead.
        if path in ("/health", "/_mgmt/health", "/catalog/_mgmt/health"):
            return JSONResponse({"healthy": True}, status.HTTP_200_OK)

        if path in ("/ping", "/_mgmt/ping", "/catalog/_mgmt/ping"):
            return JSONResponse({"message": "PONG"}, status.HTTP_200_OK)

        # Any other route goes through the regular middleware/application chain
        return await call_next(request)

dispatch(request, call_next) async

Middleware implementation

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
232
233
234
235
236
237
238
239
240
241
242
243
async def dispatch(self, request: Request, call_next: Callable):
    """Middleware implementation: short-circuit health/ping probes, forward everything else."""
    path = request.url.path

    # NOTE: for the catalog we could call "await self.api.health_check(request)" like in stac_fastapi.api.app
    # but this async call may be slow and so may kill the pod. So we hardcode the response instead.
    if path in ("/health", "/_mgmt/health", "/catalog/_mgmt/health"):
        return JSONResponse({"healthy": True}, status.HTTP_200_OK)

    if path in ("/ping", "/_mgmt/ping", "/catalog/_mgmt/ping"):
        return JSONResponse({"message": "PONG"}, status.HTTP_200_OK)

    # Any other route goes through the regular middleware/application chain
    return await call_next(request)

Rfc7807ErrorResponse

Bases: TypedDict

A JSON error response returned by the API, compliant with the RFC 7807 specification.

Attributes:

Name Type Description
type str

https://developer.mozilla.org/en/docs/Web/HTTP/Reference/Status/{status_code}

status int

HTTP response status code

detail str

A description of the error.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
70
71
72
73
74
75
76
77
78
79
80
81
class Rfc7807ErrorResponse(TypedDict):
    """A JSON error response returned by the API, compliant with the RFC 7807 specification.

    Attributes:
        type: URI identifying the error type, e.g.
            https://developer.mozilla.org/en/docs/Web/HTTP/Reference/Status/{status_code}
        status: HTTP response status code
        detail: A description of the error.
    """

    type: str  # URI of the error type (rfc7807_response uses the MDN page of the status code)
    status: int  # HTTP response status code
    detail: str  # human-readable description of the error

StacErrorResponse

Bases: TypedDict

A JSON error response returned by the API, compliant with the STAC specification.

The STAC API spec expects that code and description are both present in the payload.

Attributes:

Name Type Description
code str

A code representing the error, semantics are up to implementor.

description str

A description of the error.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
55
56
57
58
59
60
61
62
63
64
65
66
67
class StacErrorResponse(TypedDict):
    """A JSON error response returned by the API, compliant with the STAC specification.

    The STAC API spec expects that `code` and `description` are both present in
    the payload.

    Attributes:
        code: A code representing the error, semantics are up to implementor.
        description: A description of the error.
    """

    code: str  # error code, e.g. a PascalCase HTTP reason phrase or an exception class name
    description: str  # human-readable description of the error

insert_middleware_after(app, previous_mw_class, middleware_class, *args, **kwargs)

Insert the given middleware after an existing one in a FastAPI application.

Parameters:

Name Type Description Default
app FastAPI

FastAPI application

required
previous_mw_class _MiddlewareFactory

Class of middleware after which the new middleware has to be inserted

required
middleware_class Middleware

Class of middleware to insert

required
args args

args for middleware_class constructor

()
kwargs kwargs

kwargs for middleware_class constructor

{}

Raises:

Type Description
RuntimeError

if the application has already started

Returns:

Name Type Description
FastAPI

The modified FastAPI application instance with the required middleware.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def insert_middleware_after(
    app: FastAPI,
    previous_mw_class: _MiddlewareFactory,
    middleware_class: _MiddlewareFactory[P],
    *args: P.args,
    **kwargs: P.kwargs,
):
    """Insert the given middleware after an existing one in a FastAPI application.

    Args:
        app (FastAPI): FastAPI application
        previous_mw_class (_MiddlewareFactory): Class of middleware after which the new middleware has to be inserted
        middleware_class (_MiddlewareFactory): Class of middleware to insert
        args: args for middleware_class constructor
        kwargs: kwargs for middleware_class constructor

    Raises:
        ValueError: if previous_mw_class is not registered in the application
        RuntimeError: if the application has already started

    Returns:
        FastAPI: The modified FastAPI application instance with the required middleware.
    """
    existing_middlewares = [middleware.cls for middleware in app.user_middleware]
    try:
        middleware_index = existing_middlewares.index(previous_mw_class)
    except ValueError as exc:
        # Same exception type as before, but with a message that names the missing middleware
        raise ValueError(f"Middleware {previous_mw_class!r} is not registered in the application") from exc
    return insert_middleware_at(app, middleware_index + 1, Middleware(middleware_class, *args, **kwargs))

insert_middleware_at(app, index, middleware)

Insert the given middleware at the specified index in a FastAPI application.

Parameters:

Name Type Description Default
app FastAPI

FastAPI application

required
index int

index at which the middleware has to be inserted

required
middleware Middleware

Middleware to insert

required

Raises:

Type Description
RuntimeError

if the application has already started

Returns:

Name Type Description
FastAPI

The modified FastAPI application instance with the required middleware.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def insert_middleware_at(app: FastAPI, index: int, middleware: Middleware):
    """Insert the given middleware at the specified index in a FastAPI application.

    The insertion is skipped when a middleware of the same class is already registered.

    Args:
        app (FastAPI): FastAPI application
        index (int): index at which the middleware has to be inserted
        middleware (Middleware): Middleware to insert

    Raises:
        RuntimeError: if the application has already started

    Returns:
        FastAPI: The application instance, modified in place.
    """
    if app.middleware_stack:
        raise RuntimeError("Cannot add middleware after an application has started")

    # Guard clause: never register the same middleware class twice
    if any(registered.cls == middleware.cls for registered in app.user_middleware):
        return app

    logger.debug("Adding %s", middleware)
    app.user_middleware.insert(index, middleware)
    return app

read_streaming_response(response) async

Read the JSON-formatted content of a streaming response.

Source code in docs/rs-dpr-service/rs_dpr_service/utils/middlewares.py
42
43
44
45
46
47
48
49
50
51
52
async def read_streaming_response(response: StreamingResponse) -> Any | None:
    """Read a json-formatted streaming response content.

    The body iterator is consumed, then replaced by a new iterator over the same chunks, so that the
    response can still be sent afterwards, even when its body is not valid UTF-8 or JSON.

    Returns:
        The decoded JSON content, or None if the body is empty.

    Raises:
        UnicodeDecodeError: if the body is not valid UTF-8.
        json.JSONDecodeError: if the body is not valid JSON.
    """
    body = [chunk async for chunk in response.body_iterator]

    # Reset the StreamingResponse *before* parsing, so it can be used again even if parsing fails
    # (HandleExceptionsMiddleware.handle_errors returns the original response in that case).
    response.body_iterator = iterate_in_threadpool(iter(body))

    # Chunks may be str or bytes-like objects (bytes, memoryview, ...)
    raw = b"".join(chunk.encode() if isinstance(chunk, str) else bytes(chunk) for chunk in body)
    str_content = raw.decode()
    return json.loads(str_content) if str_content else None