Skip to content

rs_server_common/utils/pytest/pytest_common_tests.md

<< Back to index

Implement tests that are common to several services.

test_handle_exceptions_middleware(client, mocker, rfc7807=False)

Test that HandleExceptionsMiddleware handles and logs errors as expected.

Parameters:

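| Name | Type | Description | Default |
|------|------|-------------|---------|
| `rfc7807` | `bool` | If true, the returned content is compliant with RFC 7807. This is used by pygeoapi/ogc services. Otherwise, it is compliant with the STAC specifications. | `False` |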
Source code in docs/rs-server/services/common/rs_server_common/utils/pytest/pytest_common_tests.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
def test_handle_exceptions_middleware(client, mocker, rfc7807: bool = False):
    """
    Test that HandleExceptionsMiddleware handles and logs errors as expected.

    For each test case, a temporary endpoint is registered on the application, called through the test client,
    then removed. The returned HTTP status and content, and the calls to the middleware logger, are checked.

    Args:
        client: HTTP test client exposing the application under test as ``client.app``
            (presumably a FastAPI ``TestClient``).
        mocker: pytest-mock fixture, used to spy on the middleware ``logger.error`` calls.
        rfc7807 (bool): If true, the returned content is compliant with RFC 7807. This is used by pygeoapi/ogc
            services. False by default = compliant to Stac specifications.
    """
    app = client.app

    # Spy calls to logger.error(...)
    spy_log_error = mocker.spy(middlewares.logger, "error")

    def test_case(
        mocked_endpoint: Callable,
        expected_status: int,
        expected_content: StacErrorResponse | Rfc7807ErrorResponse,
        raise_from_func: bool,
        raise_from_dependency: bool,
    ):
        """
        Run one test case: register a temporary endpoint, call it, check the response and the logs, clean up.

        Args:
            mocked_endpoint: mocked endpoint implementation. It should return an error or raise an exception.
            expected_status: expected http response status code
            expected_content: expected http response content
            raise_from_func: will the endpoint raise an exception ?
            raise_from_dependency: will the endpoint dependency raise an exception ?
        """

        # Implement a new endpoint that will call our mock
        endpoint_path = "/test_endpoint"

        # Raise exception from the endpoint dependency
        if raise_from_dependency:

            @app.get(endpoint_path)
            def test_endpoint_func(_param=Depends(mocked_endpoint)):
                return "ok"

        # Other cases
        else:

            @app.get(endpoint_path)
            def test_endpoint_func():
                return mocked_endpoint()

        # The cleanup is done in a finally clause so that a failing assertion does not leave the temporary route
        # registered: the app may be shared with other tests, and routes are matched in registration order, so a
        # leftover route would take precedence over any route registered later on the same path.
        try:
            # Call the endpoint
            response = client.get(endpoint_path)

            # Check the expected http response
            assert response.status_code == expected_status  # int status
            # {"code": "xxx", "description": yyy"} or {"type": "xxx", status: yyy, "detail": "zzz"}
            assert response.json() == expected_content

            # Check that logger.error was called once
            spy_log_error.assert_called_once()
            logged_message = spy_log_error.call_args[0][0]

            if raise_from_func or raise_from_dependency:
                # If an exception was raised, then the log was called with the stack trace (exc_info=True arg)
                assert spy_log_error.call_args[1]["exc_info"] is True

                # The logged stack trace should contain either
                # HTTPException(status_code=<expected_status>, detail=<expected_content>)
                # or <ErrorType>(<expected_content>)
                if rfc7807:
                    assert expected_content["detail"] in str(logged_message)
                else:
                    assert expected_content["description"] in str(logged_message)

            # If no exception, we should have logged the str: '<status>: <message>'
            else:
                assert str(expected_status) in logged_message
                assert json.dumps(expected_content) in logged_message

        finally:
            # Reset the spy
            spy_log_error.reset_mock()

            # Remove the mocked endpoint. getattr is used because not every route type has a "path" attribute.
            app.router.routes = [
                route for route in app.router.routes if getattr(route, "path", None) != endpoint_path
            ]

    ###############
    # Test case 1 #
    ###############

    content = "message from return_error_1"
    if rfc7807:
        error_response = rfc7807_response(status.HTTP_418_IM_A_TEAPOT, detail=content)
    else:
        error_response = StacErrorResponse(code="I'MATeapot", description=content)

    def return_error_1():
        """Test case when the endpoint returns a JSONResponse with a dict content == the expected ErrorResponse"""
        return JSONResponse(status_code=status.HTTP_418_IM_A_TEAPOT, content=error_response)

    test_case(
        mocked_endpoint=return_error_1,
        expected_status=status.HTTP_418_IM_A_TEAPOT,
        expected_content=error_response,
        raise_from_func=False,
        raise_from_dependency=False,
    )

    ###############
    # Test case 2 #
    ###############

    dict_content = {"custom field": "message from return_error_2"}
    if rfc7807:
        expected_content = rfc7807_response(status.HTTP_418_IM_A_TEAPOT, detail=json.dumps(dict_content))
    else:
        expected_content = StacErrorResponse(code="I'MATeapot", description=json.dumps(dict_content))

    def return_error_2():
        """Test case when the endpoint returns a JSONResponse with a dict content != StacErrorResponse"""
        return JSONResponse(status_code=status.HTTP_418_IM_A_TEAPOT, content=dict_content)

    test_case(
        mocked_endpoint=return_error_2,
        expected_status=status.HTTP_418_IM_A_TEAPOT,
        # The returned error content is formatted by HandleExceptionsMiddleware
        expected_content=expected_content,
        raise_from_func=False,
        raise_from_dependency=False,
    )

    ###############
    # Test case 3 #
    ###############

    content = "message from return_error_3"
    if rfc7807:
        expected_content = rfc7807_response(status.HTTP_418_IM_A_TEAPOT, detail=content)
    else:
        expected_content = StacErrorResponse(code="I'MATeapot", description=content)

    def return_error_3():
        """Test case when the endpoint returns a JSONResponse with a string content"""
        return JSONResponse(status_code=status.HTTP_418_IM_A_TEAPOT, content=content)

    test_case(
        mocked_endpoint=return_error_3,
        expected_status=status.HTTP_418_IM_A_TEAPOT,
        # The returned error content is formatted by HandleExceptionsMiddleware
        expected_content=expected_content,
        raise_from_func=False,
        raise_from_dependency=False,
    )

    ###############
    # Test case 4 #
    ###############

    content = "message from raise_http"
    if rfc7807:
        expected_content = rfc7807_response(status.HTTP_418_IM_A_TEAPOT, detail=content)
    else:
        expected_content = StacErrorResponse(code="I'MATeapot", description=content)

    for exception_type in HTTPException, StarletteHTTPException:

        # NOTE: exception_type is intentionally captured from the loop (not bound as a default argument):
        # when used through Depends(...), FastAPI would expose any parameter as a request parameter.
        # The closure is only called within the current loop iteration, so late binding is not an issue.
        def raise_http():
            """Test case when the endpoint or dependency raises an HTTPException or StarletteHTTPException"""
            raise exception_type(status.HTTP_418_IM_A_TEAPOT, content)

        for raise_case in True, False:  # raise from either endpoint or dependency
            test_case(
                mocked_endpoint=raise_http,
                expected_status=status.HTTP_418_IM_A_TEAPOT,
                expected_content=expected_content,
                raise_from_func=raise_case,
                raise_from_dependency=not raise_case,
            )

    ###############
    # Test case 5 #
    ###############

    content = "message from raise_value_error"
    if rfc7807:
        expected_content = rfc7807_response(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=content)
    else:
        expected_content = StacErrorResponse(code="ValueError", description=content)

    def raise_value_error():
        """Test case when the endpoint or dependency raises any Exception different than HTTPException"""
        raise ValueError(content)

    for raise_case in True, False:  # raise from either endpoint or dependency
        test_case(
            mocked_endpoint=raise_value_error,
            expected_status=status.HTTP_500_INTERNAL_SERVER_ERROR,  # a generic 500 server-side error is logged
            expected_content=expected_content,
            raise_from_func=raise_case,
            raise_from_dependency=not raise_case,
        )

    # The server can override the HandleExceptionsMiddleware.is_bad_request function
    # that determines if a generic 400 client-side error is logged instead of 500
    old_bad_request = HandleExceptionsMiddleware.is_bad_request
    try:
        HandleExceptionsMiddleware.is_bad_request = lambda *_, **__: True  # always log 400

        # RFC 7807 content embeds the status, so it must be rebuilt with 400.
        # The STAC content (code="ValueError") is expected to be unchanged: only the HTTP status differs.
        if rfc7807:
            expected_content = rfc7807_response(status.HTTP_400_BAD_REQUEST, detail=content)

        for raise_case in True, False:  # raise from either endpoint or dependency
            test_case(
                mocked_endpoint=raise_value_error,
                expected_status=status.HTTP_400_BAD_REQUEST,
                expected_content=expected_content,
                raise_from_func=raise_case,
                raise_from_dependency=not raise_case,
            )

    finally:
        # Restore old function
        HandleExceptionsMiddleware.is_bad_request = old_bad_request

test_middleware_order(client, use_auth_middleware)

Check that the FastAPI application middlewares were inserted in the right order.

When sending a request, the order of the middlewares must be: Health -> CORS -> HandleExceptions -> Session -> Authentication -> [any other middlewares ...]. Then, after processing the request, the response is sent in the opposite order: [any other middlewares ...] -> Authentication -> Session -> HandleExceptions -> CORS -> Health.

The reasons for this order are:
  • Health returns an HTTP 200 OK status for the /health and /ping probe endpoints, it must be first to be as responsive as possible.
  • Then CORS will respond to requests coming from the stac browsers.
  • Then HandleExceptions is used to format error responses coming from all following middlewares and service.
  • Then Authentication blocks unauthorized users from accessing all following middlewares and the service. It must be preceded by the SessionMiddleware.

But some services don't use the AuthenticationMiddleware; instead, the authentication is implemented as an endpoint dependency. In this case, the SessionMiddleware only needs to come somewhere after the HandleExceptionsMiddleware.

Source code in docs/rs-server/services/common/rs_server_common/utils/pytest/pytest_common_tests.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def test_middleware_order(client, use_auth_middleware: bool):
    """
    Check that the FastAPI application middlewares were inserted in the right order.

    When sending a request, the order of the middlewares must be:
    Health -> CORS -> HandleExceptions -> Session -> Authentication -> [any other middlewares ...]
    Then after processing the request, the response is sent in the opposite order:
    [any other middlewares ...] -> Authentication -> Session -> HandleExceptions -> CORS -> Health

    The reason for this is that:
      - Health returns an HTTP 200 OK status for the /health and /ping probe endpoints, it must be first to be
        as responsive as possible.
      - Then CORS will respond to requests coming from the stac browsers. CORS is optional: its position is only
        checked when the service registered it.
      - Then HandleExceptions is used to format error responses coming from all following middlewares and service.
      - Then Authentication will block access to unauthorized users to all following middlewares and service. But it
        must be preceded by the SessionMiddleware.

    But some services don't use the AuthenticationMiddleware, instead the authentication is implemented as an
    endpoint dependency. In this case the SessionMiddleware only needs to be somewhere after HandleExceptions.

    Args:
        client: test client exposing the FastAPI application as ``client.app``.
        use_auth_middleware: True if the service uses the AuthenticationMiddleware (e.g. catalog and staging).
    """
    # This test expects index 0 to be the first middleware to receive the request.
    service_middlewares = [m.cls for m in client.app.user_middleware]
    str_middlewares = "\n  - ".join([""] + [m.__name__ for m in service_middlewares])
    logger.debug(f"FastAPI middlewares: {str_middlewares}")

    # (middleware class, is it optional ?)
    tested_middlewares = [
        (HealthMiddleware, False),
        (CORSMiddleware, True),
        (HandleExceptionsMiddleware, False),
    ]

    # Does the service use the AuthenticationMiddleware ? (catalog and staging)
    if use_auth_middleware:
        tested_middlewares += [
            (SessionMiddleware, False),
            (AuthenticationMiddleware, False),
        ]

    # Test the order of the middlewares
    for tested, optional in tested_middlewares:
        if optional and (tested not in service_middlewares):
            continue

        # Fail with an explicit message rather than an IndexError from pop() when middlewares are missing.
        assert service_middlewares, (
            f"Missing middleware {tested.__name__}. Service middlewares:{str_middlewares}"
        )

        # The tested middleware must be at the top of the remaining list.
        # Remove it so that the next tested middleware is compared with the following one.
        first = service_middlewares.pop(0)
        assert first == tested, (
            f"Expected middleware {tested.__name__} but found {first.__name__}. "
            f"Service middlewares:{str_middlewares}"
        )

    # Just check that the SessionMiddleware is somewhere after
    if not use_auth_middleware:
        assert SessionMiddleware in service_middlewares, (
            f"SessionMiddleware not found after HandleExceptionsMiddleware. Service middlewares:{str_middlewares}"
        )