Skip to content

rs_server_common/middlewares.md

<< Back to index

Common functions for fastapi middlewares

AuthenticationMiddleware

Bases: BaseHTTPMiddleware

Implement authentication verification.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class AuthenticationMiddleware(BaseHTTPMiddleware):  # pylint: disable=too-few-public-methods
    """
    Middleware that enforces authentication on selected request paths.

    In cluster mode, every request whose path is accepted by the
    ``must_be_authenticated`` predicate is authenticated (API key header or
    OAuth2) before being forwarded. When authentication signals that an
    interactive login is needed, the OAuth2 login flow is started instead.
    """

    def __init__(self, app, must_be_authenticated, dispatch=None):
        """
        Args:
            app: The application wrapped by this middleware.
            must_be_authenticated: Callable receiving the URL path and returning a
                truthy value when that path requires authentication.
            dispatch: Optional dispatch function forwarded to ``BaseHTTPMiddleware``.
        """
        self.must_be_authenticated = must_be_authenticated
        super().__init__(app, dispatch)

    async def dispatch(self, request: Request, call_next: Callable):
        """
        Authenticate the request when required, then hand it to the next middleware.

        Returns:
            The OAuth2 login response when the authentication check raises
            ``LoginAndRedirect``; otherwise the response of the next handler.
        """
        protected = common_settings.CLUSTER_MODE and self.must_be_authenticated(request.url.path)
        if not protected:
            # Nothing to check: call the next middleware directly
            return await call_next(request)

        api_key = request.headers.get(APIKEY_HEADER, None)
        try:
            # Validate the API key passed in the HTTP header, or the OAuth2 (Keycloak) authentication
            await authentication.authenticate(request=request, apikey_value=api_key)
        except LoginAndRedirect:
            # Log in, then get redirected to the calling endpoint
            return await oauth2.login(request)

        # Call the next middleware
        return await call_next(request)

dispatch(request, call_next) async

Middleware implementation.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
async def dispatch(self, request: Request, call_next: Callable):
    """
    Authenticate the request when required, then hand it to the next middleware.

    Authentication is enforced only in cluster mode and only for paths accepted by
    ``self.must_be_authenticated``. If the authentication check raises
    ``LoginAndRedirect``, the OAuth2 login response is returned instead of
    forwarding the request.
    """
    protected = common_settings.CLUSTER_MODE and self.must_be_authenticated(request.url.path)
    if not protected:
        # Nothing to check: call the next middleware directly
        return await call_next(request)

    api_key = request.headers.get(APIKEY_HEADER, None)
    try:
        # Validate the API key passed in the HTTP header, or the OAuth2 (Keycloak) authentication
        await authentication.authenticate(request=request, apikey_value=api_key)
    except LoginAndRedirect:
        # Log in, then get redirected to the calling endpoint
        return await oauth2.login(request)

    # Call the next middleware
    return await call_next(request)

ErrorResponse

Bases: TypedDict

A JSON error response returned by the API.

The STAC API spec expects that code and description are both present in the payload.

Attributes:

Name Type Description
code str

A code representing the error, semantics are up to implementor.

description str

A description of the error.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
58
59
60
61
62
63
64
65
66
67
68
69
70
class ErrorResponse(TypedDict):
    """A JSON error response returned by the API.

    The STAC API spec expects that `code` and `description` are both present in
    the payload.

    In this module it is the body built by `HandleExceptionsMiddleware` for
    exceptions classified as bad requests (HTTP 400).

    Attributes:
        code: A code representing the error, semantics are up to implementor.
            `HandleExceptionsMiddleware` fills it with the exception class name.
        description: A description of the error (there, `str(exception)`).
    """

    # Machine-readable error identifier.
    code: str
    # Human-readable explanation of the error.
    description: str

HandleExceptionsMiddleware

Bases: BaseHTTPMiddleware

Middleware to catch all exceptions and return a JSONResponse instead of raising them. This is useful in FastAPI when HttpExceptions are raised within the code but need to be handled gracefully.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class HandleExceptionsMiddleware(BaseHTTPMiddleware):  # pylint: disable=too-few-public-methods
    """
    Middleware to catch all exceptions and return a JSONResponse instead of raising them.
    This is useful in FastAPI when HttpExceptions are raised within the code but need to be handled gracefully.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        """
        Forward the request and turn any exception raised downstream into a JSON response.

        - ``StarletteHTTPException``: keeps its status code, uses ``str(detail)`` as the
          body and forwards the headers attached to the exception.
        - Any other exception: HTTP 400 with an ``ErrorResponse`` body when
          ``is_bad_request`` matches it, otherwise HTTP 500 with ``str(exception)`` as body.

        The stack trace is logged in every case.
        """
        try:
            return await call_next(request)
        except StarletteHTTPException as http_exception:
            # Log stack trace and return HTTP exception details
            logger.error(traceback.format_exc())
            return JSONResponse(
                status_code=http_exception.status_code,
                content=str(http_exception.detail),
                # Forward headers set by the raiser (e.g. an authentication challenge); they
                # used to be dropped. getattr() tolerates exception classes without `headers`.
                headers=getattr(http_exception, "headers", None),
            )
        except Exception as exception:  # pylint: disable=broad-exception-caught
            # Log stack trace and return generic error response
            logger.error(traceback.format_exc())
            if self.is_bad_request(request, exception):
                return JSONResponse(
                    content=ErrorResponse(code=exception.__class__.__name__, description=str(exception)),
                    status_code=status.HTTP_400_BAD_REQUEST,
                )
            return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=str(exception))

    def is_bad_request(self, request: Request, e: Exception) -> bool:
        """Determines if the request that raised this exception shall be considered as a bad request

        Only requests with a ``bbox`` query parameter qualify; the exception message is
        then matched against the known "wrong number of values" and "not a float" errors.
        """
        if "bbox" not in request.query_params:
            return False
        message = str(e)
        return message.endswith(" must have 4 or 6 values.") or message.startswith(
            "could not convert string to float: ",
        )

is_bad_request(request, e)

Determines if the request that raised this exception shall be considered as a bad request

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
128
129
130
131
132
def is_bad_request(self, request: Request, e: Exception) -> bool:
    """
    Tell whether ``e`` stems from a malformed ``bbox`` query parameter.

    Only requests carrying a ``bbox`` query parameter qualify. For those, the exception
    message is matched against the "must have 4 or 6 values." suffix and the
    "could not convert string to float: " prefix.
    """
    if "bbox" not in request.query_params:
        return False
    message = str(e)
    wrong_value_count = message.endswith(" must have 4 or 6 values.")
    not_a_float = message.startswith("could not convert string to float: ")
    return wrong_value_count or not_a_float

PaginationLinksMiddleware

Bases: BaseHTTPMiddleware

Middleware to implement 'first' button's functionality in STAC Browser

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
class PaginationLinksMiddleware(BaseHTTPMiddleware):
    """
    Middleware to implement 'first' button's functionality in STAC Browser.

    For the search endpoints in ``_SEARCH_PATHS``, a ``rel="first"`` link is appended to
    the ``links`` of the JSON response whenever that response already contains a
    ``rel="previous"`` link. All other requests are forwarded untouched.
    """

    # Search endpoints (auxip, cadip, prip, catalog) handled by this middleware.
    _SEARCH_PATHS = ("/auxip/search", "/cadip/search", "/prip/search", "/catalog/search")

    async def dispatch(
        self,
        request: Request,
        call_next: Callable,
    ):
        """
        Add a 'first' pagination link to search responses.

        The response body is read entirely, Brotli-decoded when ``content-encoding`` is
        ``br``, patched, then re-encoded the same way. Bodies that cannot be processed
        as a JSON object are returned byte-for-byte as received (still compressed if
        they were).

        Args:
            request: The incoming request.
            call_next: The next handler in the middleware chain.

        Returns:
            The (possibly rebuilt) response.
        """
        if request.url.path not in self._SEARCH_PATHS:
            return await call_next(request)

        first_link = await self._build_first_link(request)

        response = await call_next(request)

        encoding = response.headers.get("content-encoding", "")
        # Join once instead of repeated `bytes +=`, which copies the growing buffer each time.
        raw_body = b"".join([section async for section in response.body_iterator])
        if encoding == "br":
            response_body = brotli.decompress(raw_body)
            # NOTE(review): "auth:refs" is only added for Brotli-encoded /catalog/search
            # responses; confirm whether uncompressed responses should get it too.
            if request.url.path == "/catalog/search":
                first_link["auth:refs"] = ["apikey", "openid", "oauth2"]
        else:
            response_body = raw_body

        # The body is rewritten, so drop the stale length and let Response recompute it.
        headers = dict(response.headers)
        headers.pop("content-length", None)

        try:
            data = json.loads(response_body)

            links = data.get("links", [])
            if any(link.get("rel") == "previous" for link in links):
                links.append(first_link)
                data["links"] = links

            new_body = json.dumps(data).encode("utf-8")
            if encoding == "br":
                new_body = brotli.compress(new_body)
        except Exception:  # pylint: disable = broad-exception-caught
            # Not a JSON object we can patch: pass the body through unchanged. Use the raw
            # bytes so a Brotli body still matches its "content-encoding: br" header
            # (returning the decompressed bytes would corrupt the response).
            return Response(
                content=raw_body,
                status_code=response.status_code,
                headers=headers,
                media_type=response.headers.get("content-type"),
            )

        return Response(
            content=new_body,
            status_code=response.status_code,
            headers=headers,
            media_type="application/json",
        )

    @staticmethod
    async def _build_first_link(request: Request) -> dict[str, Any]:
        """Build the rel="first" link pointing to the first page of the current search."""
        first_link: dict[str, Any] = {
            "rel": "first",
            "type": "application/geo+json",
            "method": request.method,
            "href": f"{str(request.base_url).rstrip('/')}{request.url.path}",
            "title": "First link",
        }

        if common_settings.CLUSTER_MODE:
            first_link["href"] = f"https://{str(request.base_url.hostname).rstrip('/')}{request.url.path}"

        if request.method == "GET":
            # Drop the pagination token and reset 'page' to 1. multi_items() keeps repeated
            # parameters, which dict(request.query_params) would collapse to their last value.
            query_items: list[tuple[str, str]] = []
            page_reset = False
            for key, value in request.query_params.multi_items():
                if key == "token":
                    continue
                if key == "page":
                    if not page_reset:
                        query_items.append(("page", "1"))
                        page_reset = True
                    continue
                query_items.append((key, value))
            first_link["href"] += f"?{urlencode(query_items, doseq=True)}"

        elif request.method == "POST":
            try:
                query = await request.json()
                body = {key: query[key] for key in ("datetime", "limit") if key in query and query[key] is not None}

                if "token" in query and request.url.path != "/catalog/search":
                    body["token"] = "page=1"  # nosec

                first_link["body"] = body
            except Exception:  # pylint: disable = broad-exception-caught
                logger.error(traceback.format_exc())

        return first_link

StacLinksTitleMiddleware

Bases: BaseHTTPMiddleware

Middleware used to update links with title

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
class StacLinksTitleMiddleware(BaseHTTPMiddleware):
    """Middleware used to update links with title"""

    def __init__(self, app: FastAPI, title: str = "Default Title"):
        """
        Initialize the middleware.

        Args:
            app: The FastAPI application instance to attach the middleware to.
            title: Value stored as ``self.title``. NOTE(review): ``dispatch`` does not
                currently read it; link titles are computed by ``get_link_title``.
        """
        super().__init__(app)
        self.title = title

    async def dispatch(self, request: Request, call_next):
        """
        Intercept and modify outgoing responses to ensure all STAC links have proper titles.

        This middleware method:
        1. Awaits the response from the next handler.
        2. Reads the whole response body and tries to parse it as JSON.
        3. For a JSON object with "links", normalizes each link's "href" query string
           with `normalize_href` and sets its "title" with `get_link_title`
           (a title already present on the link is kept).
        4. Rebuilds the response without the original Content-Length header to prevent mismatches.
        5. If the body cannot be processed as JSON, returns it unchanged.

        Args:
            request: The incoming FastAPI Request object.
            call_next: The next ASGI handler in the middleware chain.

        Returns:
            A FastAPI Response object with updated STAC link titles.
        """
        response = await call_next(request)

        # Join the chunks once; repeated `bytes +=` would copy the growing buffer each time.
        body = b"".join([chunk async for chunk in response.body_iterator])

        # The body may be rewritten, so drop the stale length and let Response recompute it.
        headers = dict(response.headers)
        headers.pop("content-length", None)

        try:
            data = json.loads(body)

            if isinstance(data, dict) and "links" in data:
                for link in data["links"]:
                    if isinstance(link, dict):
                        # re-encode the href query string in canonical form
                        if "href" in link:
                            link["href"] = normalize_href(link["href"])
                        # update title
                        link["title"] = get_link_title(link, data)

            content = json.dumps(data, ensure_ascii=False).encode("utf-8")
            media_type = "application/json"
        except Exception:  # pylint: disable = broad-exception-caught
            # Not processable as JSON: return the body untouched
            content = body
            media_type = response.headers.get("content-type")

        return Response(
            content=content,
            status_code=response.status_code,
            headers=headers,
            media_type=media_type,
        )

__init__(app, title='Default Title')

Initialize the middleware.

Parameters:

Name Type Description Default
app FastAPI

The FastAPI application instance to attach the middleware to.

required
title str

Title stored on the instance as self.title. Note: dispatch does not currently read it; link titles are computed by get_link_title.

'Default Title'
Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
278
279
280
281
282
283
284
285
286
287
def __init__(self, app: FastAPI, title: str = "Default Title"):
    """
    Initialize the middleware.

    Args:
        app: The FastAPI application instance to attach the middleware to.
        title: Value stored as ``self.title``. NOTE(review): the middleware's
            ``dispatch`` does not appear to read it; link titles are computed by
            ``get_link_title``.
    """
    super().__init__(app)
    self.title = title

dispatch(request, call_next) async

Intercept and modify outgoing responses to ensure all STAC links have proper titles.

This middleware method:

1. Awaits the response from the next handler.
2. Reads and parses the response body as JSON.
3. Updates the "title" property of each link using get_link_title.
4. Rebuilds the response without the original Content-Length header to prevent mismatches.
5. If the response body is not JSON, returns it unchanged.

Parameters:

Name Type Description Default
request Request

The incoming FastAPI Request object.

required
call_next

The next ASGI handler in the middleware chain.

required

Returns:

Type Description

A FastAPI Response object with updated STAC link titles.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
async def dispatch(self, request: Request, call_next):
    """
    Intercept and modify outgoing responses to ensure all STAC links have proper titles.

    This middleware method:
    1. Awaits the response from the next handler.
    2. Reads the whole response body and tries to parse it as JSON.
    3. For a JSON object with "links", normalizes each link's "href" query string
       with `normalize_href` and sets its "title" with `get_link_title`
       (a title already present on the link is kept).
    4. Rebuilds the response without the original Content-Length header to prevent mismatches.
    5. If the body cannot be processed as JSON, returns it unchanged.

    Args:
        request: The incoming FastAPI Request object.
        call_next: The next ASGI handler in the middleware chain.

    Returns:
        A FastAPI Response object with updated STAC link titles.
    """
    response = await call_next(request)

    # Join the chunks once; repeated `bytes +=` would copy the growing buffer each time.
    body = b"".join([chunk async for chunk in response.body_iterator])

    # The body may be rewritten, so drop the stale length and let Response recompute it.
    headers = dict(response.headers)
    headers.pop("content-length", None)

    try:
        data = json.loads(body)

        if isinstance(data, dict) and "links" in data:
            for link in data["links"]:
                if isinstance(link, dict):
                    # re-encode the href query string in canonical form
                    if "href" in link:
                        link["href"] = normalize_href(link["href"])
                    # update title
                    link["title"] = get_link_title(link, data)

        content = json.dumps(data, ensure_ascii=False).encode("utf-8")
        media_type = "application/json"
    except Exception:  # pylint: disable = broad-exception-caught
        # Not processable as JSON: return the body untouched
        content = body
        media_type = response.headers.get("content-type")

    return Response(
        content=content,
        status_code=response.status_code,
        headers=headers,
        media_type=media_type,
    )

apply_middlewares(app)

Applies necessary middlewares and authentication routes to the FastAPI application.

This function ensures that:

1. SessionMiddleware is inserted after HandleExceptionsMiddleware to enable cookie storage.
2. OAuth2 authentication routes are added to the FastAPI application.

Parameters:

Name Type Description Default
app FastAPI

The FastAPI application instance.

required

Raises:

Type Description
RuntimeError

If the function is called after the application has already started.

Returns:

Name Type Description
FastAPI

The modified FastAPI application instance with the required middleware and authentication routes.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
def apply_middlewares(app: FastAPI):
    """
    Wire cookie sessions and the OAuth2 routes into a FastAPI application.

    Steps performed:
    1. ``SessionMiddleware`` is inserted right after ``HandleExceptionsMiddleware``
       (cookie storage), keyed with the ``RSPY_COOKIE_SECRET`` environment variable.
    2. The OAuth2 router is mounted under ``AUTH_PREFIX`` with the "Authentication" tag
       and included in the OpenAPI schema.

    Args:
        app (FastAPI): The FastAPI application instance.

    Raises:
        KeyError: If the ``RSPY_COOKIE_SECRET`` environment variable is not set.
        RuntimeError: If the function is called after the application has already started.
        NOTE(review): if ``HandleExceptionsMiddleware`` is not registered on ``app``,
            locating it presumably raises ``ValueError`` — confirm callers register it first.

    Returns:
        FastAPI: The same application instance, modified in place.
    """
    # Same effect as app.add_middleware(SessionMiddleware, secret_key=...), but placed
    # right after HandleExceptionsMiddleware in the user middleware list.
    secret = os.environ["RSPY_COOKIE_SECRET"]
    insert_middleware_after(app, HandleExceptionsMiddleware, SessionMiddleware, secret_key=secret)

    # Mount the authentication endpoints
    app.include_router(
        oauth2.get_router(app),
        tags=["Authentication"],
        prefix=AUTH_PREFIX,
        include_in_schema=True,
    )
    return app

get_link_title(link, entity)

Determine a human-readable STAC link title based on the link relation and context.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def get_link_title(link: dict, entity: dict) -> str:
    """
    Determine a human-readable STAC link title based on the link relation and context.
    """
    rel = link.get("rel")
    href = link.get("href", "")
    if "title" in link:
        # don't overwrite
        return link["title"]
    match rel:
        # --- special cases needing entity context ---
        case "collection":
            return entity.get("title") or entity.get("id") or REL_TITLES["collection"]
        case "item":
            return entity.get("title") or entity.get("id") or REL_TITLES["item"]
        case "self" if entity.get("type") == "Catalog":
            return "STAC Landing Page"
        case "self" if href.endswith("/collections"):
            return "All Collections"
        case "child":
            path = urlparse(href).path
            collection_id = path.split("/")[-1] if path else "unknown"
            return f"All from collection {collection_id}"
        # --- all others: just lookup in REL_TITLES ---
        case _:
            return REL_TITLES.get(rel, href or "Unknown Entity")  # type: ignore

insert_middleware_after(app, previous_mw_class, middleware_class, *args, **kwargs)

Insert the given middleware after an existing one in a FastAPI application.

Parameters:

Name Type Description Default
app FastAPI

FastAPI application

required
previous_mw_class _MiddlewareFactory

Class of middleware after which the new middleware has to be inserted

required
middleware_class Middleware

Class of middleware to insert

required
args args

args for middleware_class constructor

()
kwargs kwargs

kwargs for middleware_class constructor

{}

Raises:

Type Description
RuntimeError

if the application has already started

Returns:

Name Type Description
FastAPI

The modified FastAPI application instance with the required middleware.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def insert_middleware_after(
    app: FastAPI,
    previous_mw_class: _MiddlewareFactory,
    middleware_class: _MiddlewareFactory[P],
    *args: P.args,
    **kwargs: P.kwargs,
):
    """Insert the given middleware after an existing one in a FastAPI application.

    Args:
        app (FastAPI): FastAPI application
        previous_mw_class (_MiddlewareFactory): Class of the already-registered middleware
            after which the new middleware has to be inserted
        middleware_class (_MiddlewareFactory): Class of middleware to insert
        args: args for middleware_class constructor
        kwargs: kwargs for middleware_class constructor

    Raises:
        ValueError: if ``previous_mw_class`` is not registered in ``app.user_middleware``
        RuntimeError: if the application has already started

    Returns:
        FastAPI: The modified FastAPI application instance with the required middleware.
    """
    # Classes of the middlewares already registered, in registration order
    registered_classes = [middleware.cls for middleware in app.user_middleware]
    try:
        previous_index = registered_classes.index(previous_mw_class)
    except ValueError as error:
        # Same exception type as before, but with a message that names both classes
        raise ValueError(
            f"Cannot insert {middleware_class!r}: {previous_mw_class!r} is not a registered middleware",
        ) from error
    return insert_middleware_at(app, previous_index + 1, Middleware(middleware_class, *args, **kwargs))

insert_middleware_at(app, index, middleware)

Insert the given middleware at the specified index in a FastAPI application.

Parameters:

Name Type Description Default
app FastAPI

FastAPI application

required
index int

index at which the middleware has to be inserted

required
middleware Middleware

Middleware to insert

required

Raises:

Type Description
RuntimeError

if the application has already started

Returns:

Name Type Description
FastAPI

The modified FastAPI application instance with the required middleware.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def insert_middleware_at(app: FastAPI, index: int, middleware: Middleware):
    """Register ``middleware`` in ``app.user_middleware`` at position ``index``.

    Nothing is inserted when a middleware of the same class is already registered,
    so repeated calls with the same class are harmless.

    Args:
        app (FastAPI): FastAPI application
        index (int): index at which the middleware has to be inserted
        middleware (Middleware): Middleware to insert

    Raises:
        RuntimeError: if the application has already started (its middleware stack is built)

    Returns:
        FastAPI: ``app`` itself, with the middleware registered.
    """
    if app.middleware_stack:
        raise RuntimeError("Cannot add middleware after an application has started")

    already_registered = any(existing.cls == middleware.cls for existing in app.user_middleware)
    if already_registered:
        return app

    logger.debug("Adding %s", middleware)
    app.user_middleware.insert(index, middleware)
    return app

normalize_href(href)

Encode query parameters in href to match expected STAC format.

Source code in docs/rs-server/services/common/rs_server_common/middlewares.py
268
269
270
271
272
def normalize_href(href: str) -> str:
    """Encode query parameters in href to match expected STAC format.

    The query string is decoded with ``parse_qsl`` and re-encoded with
    ``urlencode(..., safe="")``. Reserved characters such as ":" become "%3A" and
    spaces become "+". Only the query component is rewritten.

    Parameters with an empty value (``?a=``) or without "=" (``?flag``) are kept
    (the latter becomes ``flag=``) instead of being silently dropped.
    """
    parsed = urlparse(href)
    # keep_blank_values=True: by default parse_qsl discards blank-valued parameters.
    query = urlencode(parse_qsl(parsed.query, keep_blank_values=True), safe="")  # encode ":" -> "%3A"
    return urlunparse(parsed._replace(query=query))