89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
"""FastAPI中间件"""
|
|
import time
|
|
from typing import Callable
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
|
|
from .trace import generate_trace_id, set_trace_id, get_trace_id
|
|
from .logger import get_logger
|
|
|
|
|
|
class TraceMiddleware(BaseHTTPMiddleware):
    """Distributed-tracing middleware.

    Gives every request a trace ID. A non-empty ``X-Trace-ID`` request
    header is reused as-is; otherwise a new ID is obtained from
    ``generate_trace_id()``. The chosen ID is passed to ``set_trace_id()``
    before the downstream app runs and is echoed back to the client in the
    ``X-Trace-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Attach a trace ID to the current request and to its response.

        Args:
            request: The incoming request.
            call_next: Starlette callback that runs the rest of the app.

        Returns:
            The downstream response with the ``X-Trace-ID`` header set.
        """
        # NOTE(review): the inbound header is client-controlled and is used
        # verbatim (no format or length check) in storage and the response.
        incoming = request.headers.get("X-Trace-ID")
        # A missing or empty header falls back to a freshly generated ID.
        current_id = incoming if incoming else generate_trace_id()

        # Stored before call_next so code running during the request can look
        # it up. NOTE(review): relies on how .trace scopes this value — confirm.
        set_trace_id(current_id)

        downstream_response = await call_next(request)
        downstream_response.headers["X-Trace-ID"] = current_id
        return downstream_response
|
|
|
|
|
|
class LoggingMiddleware(BaseHTTPMiddleware):
    """Request-logging middleware.

    Emits an info log when a request starts. When the downstream app
    returns, it emits an info log with the status code and elapsed time in
    milliseconds. When the downstream app raises, it emits an error log
    and re-raises. It also makes sure a trace ID exists for the request and
    returns it in the ``X-Trace-ID`` response header.
    """

    def __init__(self, app, app_code: str = "unknown"):
        """Wrap *app* with request logging.

        Args:
            app: The ASGI application being wrapped.
            app_code: Name passed to ``get_logger`` to obtain this
                middleware's logger; also kept on ``self.app_code``.
        """
        super().__init__(app)
        self.app_code = app_code
        self.logger = get_logger(app_code)

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - start) * 1000)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request lifecycle and tag the response with the trace ID.

        Args:
            request: The incoming request.
            call_next: Starlette callback that runs the rest of the app.

        Returns:
            The downstream response with the ``X-Trace-ID`` header set.

        Raises:
            Exception: Anything raised by the downstream app is logged and
                re-raised unchanged.
        """
        # Reuse a trace ID that is already stored (presumably by
        # TraceMiddleware earlier in the chain); otherwise take it from the
        # header or generate one.
        # NOTE(review): if .trace does not scope storage per request (e.g. it
        # is not a contextvar), a stale ID from an earlier request could be
        # picked up here — confirm.
        trace_id = get_trace_id()
        if not trace_id:
            trace_id = request.headers.get("X-Trace-ID") or generate_trace_id()
            set_trace_id(trace_id)

        method = request.method
        path = str(request.url.path)
        # perf_counter() is monotonic, unlike the wall clock (time.time()),
        # so clock adjustments cannot yield negative or skewed durations.
        start = time.perf_counter()

        self.logger.info(
            f"Request started: {method} {path}",
            method=method,
            path=path,
        )

        # Only the downstream call is guarded: a failure in the logging or
        # header code below must not be misreported as "Request failed".
        try:
            response = await call_next(request)
        except Exception as e:
            self.logger.error(
                f"Request failed: {str(e)}",
                error=e,
                method=method,
                path=path,
                duration_ms=self._elapsed_ms(start),
            )
            raise

        self.logger.info(
            f"Request completed: {response.status_code}",
            method=method,
            path=path,
            status_code=response.status_code,
            duration_ms=self._elapsed_ms(start),
        )

        response.headers["X-Trace-ID"] = trace_id
        return response
|