当前位置: 首页 > 工具软件 > Fast Assert > 使用案例 >

fastapi 在中间件中获取requestBody

平航
2023-12-01

FastAPI是基于 Starlette 并实现了ASGI规范,所以可以使用任何 ASGI 中间件

创建 ASGI 中间件

创建 ASGI 中间件最常用的方法是使用类。

class ASGIMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        await self.app(scope, receive, send)

上面的中间件是最基本的ASGI中间件。它接收一个父 ASGI 应用程序作为其构造函数的参数,并实现一个async def __call__调用该父应用程序的方法.

BaseHTTPMiddleware

statlette 提供了BaseHTTPMiddleware抽象类,方便用户实现中间件,要使用 实现中间件类BaseHTTPMiddleware,必须重写该 async def dispatch(request, call_next)方法。

可以先看下BaseHTTPMiddleware的源码:


class BaseHTTPMiddleware:
    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def call_next(request: Request) -> Response:
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def coro() -> None:
                nonlocal app_exc

                async with send_stream:
                    try:
                        await self.app(scope, request.receive, send_stream.send) #调用app
                    except Exception as exc:
                        app_exc = exc

            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
            except anyio.EndOfStream:
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:  # 获取response
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        yield message.get("body", b"")

                if app_exc is not None:
                    raise app_exc

            response = StreamingResponse(
                status_code=message["status"], content=body_stream()
            )
            response.raw_headers = message["headers"]
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            task_group.cancel_scope.cancel()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        raise NotImplementedError()  # pragma: no cover

在中间件中获取requestBody

BaseHTTPMiddleware的__call__方法中通过 调用 await self.dispatch_func(request, call_next), 执行用户重写的dispatch方法。用户在dispatch中接收到的call_next参数,在BaseHTTPMiddleware的__call__方法中已经定义,他的主要作用分两部分,一是调用ASGIApp, 二是返回了response.

由于因为响应主体在从流中读取它时会被消耗,每个请求周期只能存活一次,在BaseHTTPMiddleware.call_next()中调用ASGIApp时被消耗,所以,直接在BaseHTTPMiddleware.dispatch方法中无法获取到body.

class BaseHTTPMiddleware(BaseHTTPMiddleware):
    
    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        print(request._body) # 结果为空

解决方案

  1. 使用原生ASGI 中间件
    
class MyMiddleware:
    def __init__(
            self,
            app: ASGIApp,
    ) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        done = False
        chunks: "List[bytes]" = []
        async def wrapped_receive() -> Message:
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                done = True
                return message
            body = message.get("body", b"")

            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        try:
            await self.app(scope, wrapped_receive, send)
        finally:
            while not done:
                await wrapped_receive()

            body = b"".join(chunks)
            print(body)
           
    

以上通过定义done检查响应流是否加载完毕,将wrapped_receive传给app的同时使用chunks记录body。

但是这样,如果我们需要Response对象,需要重新实现。
我们可以借助BaseHTTPMiddleware, 重写dispatch, 只需要在receive被消耗前记录body.
中先实例化了request,Request(scope, receive=receive)。 将request传给call_next().
最后在调用app,把request.receive传给app.因此我们可以实现 wrapped_receive(),把wrapped_receive赋值给request.receive实现记录body.

实现如下:

class MyMiddleware(BaseHTTPMiddleware):
    
    def __init__(self,  app: ASGIApp, dispatch: DispatchFunction = None):
        super(BehaviorRecord, self).__init__(app=app, dispatch=dispatch)

    async def dispatch(self, request: Request, call_next):
        done = False
        chunks: "List[bytes]" = []
        receive = request.receive

        async def wrapped_receive() -> Message:  # 取body
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                done = True
                return message
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        request._receive = wrapped_receive  # 赋值给_receive, 达到在call_next使用wrapped_receive的目的
        start_time = time.time()
        response = await call_next(request)
        while not done:
            await wrapped_receive()
        process_time = (time.time() - start_time)
        response.headers["Response-Time"] = str(process_time)  # 可以使用response, 添加信息
        body = b"".join(chunks)
        logging.info({'requestBody':body})
        return response

 类似资料: