FastAPI是基于 Starlette 并实现了ASGI规范,所以可以使用任何 ASGI 中间件
创建 ASGI 中间件最常用的方法是使用类。
class ASGIMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
await self.app(scope, receive, send)
上面的中间件是最基本的ASGI中间件。它接收一个父 ASGI 应用程序作为其构造函数的参数,并实现一个async def __call__调用该父应用程序的方法.
statlette 提供了BaseHTTPMiddleware抽象类,方便用户实现中间件,要使用 实现中间件类BaseHTTPMiddleware,必须重写该 async def dispatch(request, call_next)方法。
可以先看下BaseHTTPMiddleware的源码:
class BaseHTTPMiddleware:
def __init__(
self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
) -> None:
self.app = app
self.dispatch_func = self.dispatch if dispatch is None else dispatch
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream, recv_stream = anyio.create_memory_object_stream()
async def coro() -> None:
nonlocal app_exc
async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send) #调用app
except Exception as exc:
app_exc = exc
task_group.start_soon(coro)
try:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
raise RuntimeError("No response returned.")
assert message["type"] == "http.response.start"
async def body_stream() -> typing.AsyncGenerator[bytes, None]: # 获取response
async with recv_stream:
async for message in recv_stream:
assert message["type"] == "http.response.body"
yield message.get("body", b"")
if app_exc is not None:
raise app_exc
response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
response.raw_headers = message["headers"]
return response
async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover
BaseHTTPMiddleware的__call__方法中通过 调用 await self.dispatch_func(request, call_next), 执行用户重写的dispatch方法。用户在dispatch中接收到的call_next参数,在BaseHTTPMiddleware的__call__方法中已经定义,他的主要作用分两部分,一是调用ASGIApp, 二是返回了response.
由于因为响应主体在从流中读取它时会被消耗,每个请求周期只能存活一次,在BaseHTTPMiddleware.call_next()中调用ASGIApp时被消耗,所以,直接在BaseHTTPMiddleware.dispatch方法中无法获取到body.
class BaseHTTPMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
print(request._body) # 结果为空
class MyMiddleware:
def __init__(
self,
app: ASGIApp,
) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
done = False
chunks: "List[bytes]" = []
async def wrapped_receive() -> Message:
nonlocal done
message = await receive()
if message["type"] == "http.disconnect":
done = True
return message
body = message.get("body", b"")
more_body = message.get("more_body", False)
if not more_body:
done = True
chunks.append(body)
return message
try:
await self.app(scope, wrapped_receive, send)
finally:
while not done:
await wrapped_receive()
body = b"".join(chunks)
print(body)
以上通过定义done检查响应流是否加载完毕,将wrapped_receive传给app的同时使用chunks记录body。
但是这样,如果我们需要Response对象,需要重新实现。
我们可以借助BaseHTTPMiddleware, 重写dispatch, 只需要在receive被消耗前记录body.
中先实例化了request,Request(scope, receive=receive)
。 将request传给call_next().
最后在调用app,把request.receive传给app.因此我们可以实现 wrapped_receive(),把wrapped_receive赋值给request.receive实现记录body.
实现如下:
class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, dispatch: DispatchFunction = None):
super(BehaviorRecord, self).__init__(app=app, dispatch=dispatch)
async def dispatch(self, request: Request, call_next):
done = False
chunks: "List[bytes]" = []
receive = request.receive
async def wrapped_receive() -> Message: # 取body
nonlocal done
message = await receive()
if message["type"] == "http.disconnect":
done = True
return message
body = message.get("body", b"")
more_body = message.get("more_body", False)
if not more_body:
done = True
chunks.append(body)
return message
request._receive = wrapped_receive # 赋值给_receive, 达到在call_next使用wrapped_receive的目的
start_time = time.time()
response = await call_next(request)
while not done:
await wrapped_receive()
process_time = (time.time() - start_time)
response.headers["Response-Time"] = str(process_time) # 可以使用response, 添加信息
body = b"".join(chunks)
logging.info({'requestBody':body})
return response