"""FastAPI中间件""" import time from typing import Callable from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response from .trace import generate_trace_id, set_trace_id, get_trace_id from .logger import get_logger class TraceMiddleware(BaseHTTPMiddleware): """链路追踪中间件 为每个请求生成或传递trace_id """ async def dispatch(self, request: Request, call_next: Callable) -> Response: # 从header获取或生成trace_id trace_id = request.headers.get("X-Trace-ID") or generate_trace_id() set_trace_id(trace_id) response = await call_next(request) # 在响应header中返回trace_id response.headers["X-Trace-ID"] = trace_id return response class LoggingMiddleware(BaseHTTPMiddleware): """请求日志中间件 记录每个请求的基本信息和耗时 """ def __init__(self, app, app_code: str = "unknown"): super().__init__(app) self.app_code = app_code self.logger = get_logger(app_code) async def dispatch(self, request: Request, call_next: Callable) -> Response: # 确保有trace_id trace_id = get_trace_id() if not trace_id: trace_id = request.headers.get("X-Trace-ID") or generate_trace_id() set_trace_id(trace_id) start_time = time.time() # 记录请求开始 self.logger.info( f"Request started: {request.method} {request.url.path}", method=request.method, path=str(request.url.path) ) try: response = await call_next(request) # 计算耗时 duration_ms = int((time.time() - start_time) * 1000) # 记录请求完成 self.logger.info( f"Request completed: {response.status_code}", method=request.method, path=str(request.url.path), status_code=response.status_code, duration_ms=duration_ms ) # 添加trace_id到响应header response.headers["X-Trace-ID"] = trace_id return response except Exception as e: duration_ms = int((time.time() - start_time) * 1000) self.logger.error( f"Request failed: {str(e)}", error=e, method=request.method, path=str(request.url.path), duration_ms=duration_ms ) raise