Initial commit: 000-platform project skeleton
All checks were successful
continuous-integration/drone/push Build is passing
All checks were successful
continuous-integration/drone/push Build is passing
This commit is contained in:
129
sdk/README.md
Normal file
129
sdk/README.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# Platform SDK
|
||||
|
||||
平台基础设施客户端SDK,提供统一的统计上报、日志记录、链路追踪等功能。
|
||||
|
||||
## 安装
|
||||
|
||||
SDK作为共享模块使用,通过软链接引用:
|
||||
|
||||
```bash
|
||||
# 在 _shared 目录创建软链接
|
||||
cd AgentWD/_shared
|
||||
ln -s ../projects/000-platform/sdk platform
|
||||
```
|
||||
|
||||
## 在项目中使用
|
||||
|
||||
### 1. 添加路径
|
||||
|
||||
在项目的 `main.py` 开头添加:
|
||||
|
||||
```python
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加 _shared 到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent / "_shared"))
|
||||
|
||||
from platform import get_logger, StatsClient, LoggingMiddleware
|
||||
```
|
||||
|
||||
### 2. 日志
|
||||
|
||||
```python
|
||||
from platform import get_logger
|
||||
|
||||
logger = get_logger("011-ai-interview")
|
||||
|
||||
# 基础日志
|
||||
logger.info("用户开始面试", user_id=123)
|
||||
logger.error("面试出错", error=e)
|
||||
|
||||
# 审计日志
|
||||
logger.audit("create", "interview", "123", operator="admin")
|
||||
```
|
||||
|
||||
### 3. AI统计上报
|
||||
|
||||
```python
|
||||
from platform import StatsClient
|
||||
|
||||
stats = StatsClient(tenant_id=1, app_code="011-ai-interview")
|
||||
|
||||
# 上报AI调用
|
||||
stats.report_ai_call(
|
||||
module_code="interview",
|
||||
prompt_name="generate_question",
|
||||
model="gpt-4",
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
latency_ms=1500
|
||||
)
|
||||
```
|
||||
|
||||
### 4. 链路追踪
|
||||
|
||||
```python
|
||||
from platform import TraceContext, get_trace_id
|
||||
|
||||
# 在请求处理中
|
||||
with TraceContext(tenant_id=1, user_id=100) as ctx:
|
||||
print(f"当前trace_id: {ctx.trace_id}")
|
||||
# 所有操作共享同一个trace_id
|
||||
```
|
||||
|
||||
### 5. FastAPI中间件
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from platform import LoggingMiddleware, TraceMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 添加中间件(顺序重要)
|
||||
app.add_middleware(LoggingMiddleware, app_code="011-ai-interview")
|
||||
app.add_middleware(TraceMiddleware)
|
||||
```
|
||||
|
||||
### 6. HTTP客户端
|
||||
|
||||
```python
|
||||
from platform import PlatformHttpClient
|
||||
|
||||
client = PlatformHttpClient(base_url="https://api.example.com")
|
||||
|
||||
# 自动传递trace_id和API Key
|
||||
response = await client.get("/users/1")
|
||||
```
|
||||
|
||||
### 7. 配置读取
|
||||
|
||||
```python
|
||||
from platform import ConfigReader
|
||||
|
||||
config = ConfigReader(tenant_id=1)
|
||||
|
||||
# 读取平台配置
|
||||
app_id = await config.get("wechat", "app_id")
|
||||
|
||||
# 读取环境变量
|
||||
debug = config.get_env("DEBUG", False)
|
||||
```
|
||||
|
||||
## 环境变量
|
||||
|
||||
SDK使用以下环境变量:
|
||||
|
||||
| 变量 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| `PLATFORM_URL` | 平台服务地址 | - |
|
||||
| `PLATFORM_API_KEY` | 平台API Key | - |
|
||||
|
||||
## 更新日志
|
||||
|
||||
### v0.1.0
|
||||
|
||||
- 初始版本
|
||||
- 支持日志、统计、链路追踪
|
||||
- 支持FastAPI中间件
|
||||
- 支持配置读取
|
||||
46
sdk/__init__.py
Normal file
46
sdk/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Platform SDK - 平台基础设施客户端SDK
|
||||
|
||||
提供统一的统计上报、日志记录、链路追踪等功能。
|
||||
|
||||
使用示例:
|
||||
from platform import get_logger, StatsClient, LoggingMiddleware
|
||||
|
||||
# 日志
|
||||
logger = get_logger("my-app")
|
||||
logger.info("Hello")
|
||||
|
||||
# 统计
|
||||
stats = StatsClient(tenant_id=1, app_code="my-app")
|
||||
stats.report_ai_call(...)
|
||||
|
||||
# FastAPI中间件
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
"""
|
||||
|
||||
from .logger import get_logger, PlatformLogger
|
||||
from .stats_client import StatsClient
|
||||
from .trace import TraceContext, get_trace_id, generate_trace_id
|
||||
from .middleware import LoggingMiddleware, TraceMiddleware
|
||||
from .http_client import PlatformHttpClient
|
||||
from .config_reader import ConfigReader
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
__all__ = [
|
||||
# Logger
|
||||
"get_logger",
|
||||
"PlatformLogger",
|
||||
# Stats
|
||||
"StatsClient",
|
||||
# Trace
|
||||
"TraceContext",
|
||||
"get_trace_id",
|
||||
"generate_trace_id",
|
||||
# Middleware
|
||||
"LoggingMiddleware",
|
||||
"TraceMiddleware",
|
||||
# HTTP
|
||||
"PlatformHttpClient",
|
||||
# Config
|
||||
"ConfigReader",
|
||||
]
|
||||
105
sdk/config_reader.py
Normal file
105
sdk/config_reader.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""配置读取客户端"""
|
||||
import os
|
||||
from typing import Optional, Any, Dict
|
||||
from functools import lru_cache
|
||||
|
||||
from .http_client import PlatformHttpClient
|
||||
|
||||
|
||||
class ConfigReader:
|
||||
"""配置读取器
|
||||
|
||||
从平台服务读取租户配置
|
||||
|
||||
使用示例:
|
||||
config = ConfigReader(tenant_id=1)
|
||||
|
||||
# 读取单个配置
|
||||
value = await config.get("wechat", "app_id")
|
||||
|
||||
# 读取配置组
|
||||
wechat_config = await config.get_group("wechat")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: int,
|
||||
platform_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
cache_ttl: int = 300 # 缓存时间(秒)
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
|
||||
self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
|
||||
self.cache_ttl = cache_ttl
|
||||
|
||||
self._client = PlatformHttpClient(
|
||||
base_url=self.platform_url,
|
||||
api_key=self.api_key
|
||||
)
|
||||
|
||||
# 本地缓存
|
||||
self._cache: Dict[str, Any] = {}
|
||||
|
||||
async def get(
|
||||
self,
|
||||
config_type: str,
|
||||
config_key: str,
|
||||
default: Any = None
|
||||
) -> Any:
|
||||
"""读取配置值
|
||||
|
||||
Args:
|
||||
config_type: 配置类型
|
||||
config_key: 配置键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
cache_key = f"{config_type}:{config_key}"
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
# 从平台获取
|
||||
try:
|
||||
response = await self._client.get(
|
||||
f"/api/config/{config_type}/{config_key}",
|
||||
params={"tenant_id": self.tenant_id}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
value = data.get("config_value", default)
|
||||
self._cache[cache_key] = value
|
||||
return value
|
||||
elif response.status_code == 404:
|
||||
return default
|
||||
else:
|
||||
raise Exception(f"Config read failed: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
# 失败时返回默认值
|
||||
return default
|
||||
|
||||
def get_env(
|
||||
self,
|
||||
env_key: str,
|
||||
default: Any = None
|
||||
) -> Any:
|
||||
"""从环境变量读取配置(同步方法)
|
||||
|
||||
Args:
|
||||
env_key: 环境变量名
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
"""
|
||||
return os.getenv(env_key, default)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清除缓存"""
|
||||
self._cache.clear()
|
||||
83
sdk/http_client.py
Normal file
83
sdk/http_client.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""HTTP客户端封装"""
|
||||
import os
|
||||
from typing import Optional, Any, Dict
|
||||
import httpx
|
||||
|
||||
from .trace import get_trace_id
|
||||
|
||||
|
||||
class PlatformHttpClient:
|
||||
"""平台HTTP客户端
|
||||
|
||||
自动传递trace_id和API Key
|
||||
|
||||
使用示例:
|
||||
client = PlatformHttpClient(base_url="https://api.example.com")
|
||||
response = await client.get("/users/1")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: float = 30.0
|
||||
):
|
||||
self.base_url = base_url or os.getenv("PLATFORM_URL", "")
|
||||
self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
|
||||
self.timeout = timeout
|
||||
|
||||
def _get_headers(self, extra_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
|
||||
"""获取请求头"""
|
||||
headers = {}
|
||||
|
||||
# 添加trace_id
|
||||
trace_id = get_trace_id()
|
||||
if trace_id:
|
||||
headers["X-Trace-ID"] = trace_id
|
||||
|
||||
# 添加API Key
|
||||
if self.api_key:
|
||||
headers["X-API-Key"] = self.api_key
|
||||
|
||||
# 合并额外header
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
return headers
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
**kwargs
|
||||
) -> httpx.Response:
|
||||
"""发送HTTP请求"""
|
||||
headers = self._get_headers(kwargs.pop("headers", None))
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout
|
||||
) as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=path,
|
||||
headers=headers,
|
||||
**kwargs
|
||||
)
|
||||
return response
|
||||
|
||||
async def get(self, path: str, **kwargs) -> httpx.Response:
|
||||
"""GET请求"""
|
||||
return await self.request("GET", path, **kwargs)
|
||||
|
||||
async def post(self, path: str, **kwargs) -> httpx.Response:
|
||||
"""POST请求"""
|
||||
return await self.request("POST", path, **kwargs)
|
||||
|
||||
async def put(self, path: str, **kwargs) -> httpx.Response:
|
||||
"""PUT请求"""
|
||||
return await self.request("PUT", path, **kwargs)
|
||||
|
||||
async def delete(self, path: str, **kwargs) -> httpx.Response:
|
||||
"""DELETE请求"""
|
||||
return await self.request("DELETE", path, **kwargs)
|
||||
125
sdk/logger.py
Normal file
125
sdk/logger.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""统一日志模块"""
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Any
|
||||
from functools import lru_cache
|
||||
|
||||
from .trace import get_trace_id, get_tenant_id, get_user_id
|
||||
|
||||
|
||||
class PlatformLogger:
|
||||
"""平台日志器
|
||||
|
||||
支持本地输出和远程上报两种模式
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_code: str,
|
||||
platform_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
local_only: bool = True
|
||||
):
|
||||
self.app_code = app_code
|
||||
self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
|
||||
self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
|
||||
self.local_only = local_only or not self.platform_url
|
||||
|
||||
# 设置本地日志
|
||||
self._logger = logging.getLogger(app_code)
|
||||
self._logger.setLevel(logging.DEBUG)
|
||||
|
||||
if not self._logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s | %(levelname)s | %(name)s | %(message)s'
|
||||
))
|
||||
self._logger.addHandler(handler)
|
||||
|
||||
def _format_message(self, message: str, **kwargs) -> str:
|
||||
"""格式化日志消息,包含trace_id"""
|
||||
trace_id = get_trace_id()
|
||||
prefix = f"[{trace_id[:8]}] " if trace_id else ""
|
||||
|
||||
if kwargs:
|
||||
extra = " | " + " ".join(f"{k}={v}" for k, v in kwargs.items())
|
||||
else:
|
||||
extra = ""
|
||||
|
||||
return f"{prefix}{message}{extra}"
|
||||
|
||||
def _log(
|
||||
self,
|
||||
level: str,
|
||||
message: str,
|
||||
log_type: str = "app",
|
||||
category: Optional[str] = None,
|
||||
context: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""内部日志方法"""
|
||||
formatted = self._format_message(message, **kwargs)
|
||||
|
||||
# 本地日志
|
||||
log_method = getattr(self._logger, level.lower(), self._logger.info)
|
||||
log_method(formatted)
|
||||
|
||||
# TODO: 远程上报(异步)
|
||||
if not self.local_only:
|
||||
self._send_to_platform(level, message, log_type, category, context, kwargs)
|
||||
|
||||
def _send_to_platform(
|
||||
self,
|
||||
level: str,
|
||||
message: str,
|
||||
log_type: str,
|
||||
category: Optional[str],
|
||||
context: Optional[dict],
|
||||
extra: dict
|
||||
):
|
||||
"""发送日志到平台(异步,后续实现)"""
|
||||
# TODO: 使用httpx异步发送
|
||||
pass
|
||||
|
||||
def debug(self, message: str, **kwargs):
|
||||
"""调试日志"""
|
||||
self._log("debug", message, **kwargs)
|
||||
|
||||
def info(self, message: str, **kwargs):
|
||||
"""信息日志"""
|
||||
self._log("info", message, **kwargs)
|
||||
|
||||
def warn(self, message: str, **kwargs):
|
||||
"""警告日志"""
|
||||
self._log("warn", message, **kwargs)
|
||||
|
||||
def warning(self, message: str, **kwargs):
|
||||
"""警告日志(别名)"""
|
||||
self.warn(message, **kwargs)
|
||||
|
||||
def error(self, message: str, error: Optional[Exception] = None, **kwargs):
|
||||
"""错误日志"""
|
||||
if error:
|
||||
kwargs["error_type"] = type(error).__name__
|
||||
kwargs["error_msg"] = str(error)
|
||||
self._log("error", message, log_type="error", **kwargs)
|
||||
|
||||
def audit(self, action: str, target_type: str, target_id: str, **kwargs):
|
||||
"""审计日志"""
|
||||
self._log(
|
||||
"info",
|
||||
f"AUDIT: {action} {target_type}:{target_id}",
|
||||
log_type="audit",
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_logger(app_code: str) -> PlatformLogger:
|
||||
"""获取日志器(单例)"""
|
||||
return PlatformLogger(app_code)
|
||||
88
sdk/middleware.py
Normal file
88
sdk/middleware.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""FastAPI中间件"""
|
||||
import time
|
||||
from typing import Callable
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from .trace import generate_trace_id, set_trace_id, get_trace_id
|
||||
from .logger import get_logger
|
||||
|
||||
|
||||
class TraceMiddleware(BaseHTTPMiddleware):
|
||||
"""链路追踪中间件
|
||||
|
||||
为每个请求生成或传递trace_id
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 从header获取或生成trace_id
|
||||
trace_id = request.headers.get("X-Trace-ID") or generate_trace_id()
|
||||
set_trace_id(trace_id)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# 在响应header中返回trace_id
|
||||
response.headers["X-Trace-ID"] = trace_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""请求日志中间件
|
||||
|
||||
记录每个请求的基本信息和耗时
|
||||
"""
|
||||
|
||||
def __init__(self, app, app_code: str = "unknown"):
|
||||
super().__init__(app)
|
||||
self.app_code = app_code
|
||||
self.logger = get_logger(app_code)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 确保有trace_id
|
||||
trace_id = get_trace_id()
|
||||
if not trace_id:
|
||||
trace_id = request.headers.get("X-Trace-ID") or generate_trace_id()
|
||||
set_trace_id(trace_id)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 记录请求开始
|
||||
self.logger.info(
|
||||
f"Request started: {request.method} {request.url.path}",
|
||||
method=request.method,
|
||||
path=str(request.url.path)
|
||||
)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算耗时
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 记录请求完成
|
||||
self.logger.info(
|
||||
f"Request completed: {response.status_code}",
|
||||
method=request.method,
|
||||
path=str(request.url.path),
|
||||
status_code=response.status_code,
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
|
||||
# 添加trace_id到响应header
|
||||
response.headers["X-Trace-ID"] = trace_id
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
self.logger.error(
|
||||
f"Request failed: {str(e)}",
|
||||
error=e,
|
||||
method=request.method,
|
||||
path=str(request.url.path),
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
raise
|
||||
148
sdk/stats_client.py
Normal file
148
sdk/stats_client.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""AI统计上报客户端"""
|
||||
import os
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Optional, List
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
from .trace import get_trace_id, get_tenant_id, get_user_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class AICallEvent:
|
||||
"""AI调用事件"""
|
||||
tenant_id: int
|
||||
app_code: str
|
||||
module_code: str
|
||||
prompt_name: str
|
||||
model: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cost: Decimal = Decimal("0")
|
||||
latency_ms: int = 0
|
||||
status: str = "success"
|
||||
user_id: Optional[int] = None
|
||||
trace_id: Optional[str] = None
|
||||
event_time: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.event_time is None:
|
||||
self.event_time = datetime.now()
|
||||
if self.trace_id is None:
|
||||
self.trace_id = get_trace_id()
|
||||
if self.user_id is None:
|
||||
self.user_id = get_user_id()
|
||||
|
||||
|
||||
class StatsClient:
|
||||
"""统计上报客户端
|
||||
|
||||
使用示例:
|
||||
stats = StatsClient(tenant_id=1, app_code="011-ai-interview")
|
||||
|
||||
# 上报AI调用
|
||||
stats.report_ai_call(
|
||||
module_code="interview",
|
||||
prompt_name="generate_question",
|
||||
model="gpt-4",
|
||||
input_tokens=100,
|
||||
output_tokens=200,
|
||||
latency_ms=1500
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: int,
|
||||
app_code: str,
|
||||
platform_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
local_only: bool = True
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.app_code = app_code
|
||||
self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
|
||||
self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
|
||||
self.local_only = local_only or not self.platform_url
|
||||
|
||||
# 批量上报缓冲区
|
||||
self._buffer: List[AICallEvent] = []
|
||||
self._buffer_size = 10 # 达到此数量时自动上报
|
||||
|
||||
def report_ai_call(
|
||||
self,
|
||||
module_code: str,
|
||||
prompt_name: str,
|
||||
model: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cost: Decimal = Decimal("0"),
|
||||
latency_ms: int = 0,
|
||||
status: str = "success",
|
||||
user_id: Optional[int] = None,
|
||||
flush: bool = False
|
||||
) -> AICallEvent:
|
||||
"""上报AI调用事件
|
||||
|
||||
Args:
|
||||
module_code: 模块编码
|
||||
prompt_name: Prompt名称
|
||||
model: 模型名称
|
||||
input_tokens: 输入token数
|
||||
output_tokens: 输出token数
|
||||
cost: 成本
|
||||
latency_ms: 延迟毫秒
|
||||
status: 状态 (success/error)
|
||||
user_id: 用户ID(可选,默认从上下文获取)
|
||||
flush: 是否立即发送
|
||||
|
||||
Returns:
|
||||
创建的事件对象
|
||||
"""
|
||||
event = AICallEvent(
|
||||
tenant_id=self.tenant_id,
|
||||
app_code=self.app_code,
|
||||
module_code=module_code,
|
||||
prompt_name=prompt_name,
|
||||
model=model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
cost=cost,
|
||||
latency_ms=latency_ms,
|
||||
status=status,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
self._buffer.append(event)
|
||||
|
||||
if flush or len(self._buffer) >= self._buffer_size:
|
||||
self.flush()
|
||||
|
||||
return event
|
||||
|
||||
def flush(self):
|
||||
"""发送缓冲区中的所有事件"""
|
||||
if not self._buffer:
|
||||
return
|
||||
|
||||
events = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
|
||||
if self.local_only:
|
||||
# 本地模式:仅打印
|
||||
for event in events:
|
||||
print(f"[STATS] {event.app_code}/{event.module_code}: "
|
||||
f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens")
|
||||
else:
|
||||
# 远程上报
|
||||
self._send_to_platform(events)
|
||||
|
||||
def _send_to_platform(self, events: List[AICallEvent]):
|
||||
"""发送事件到平台(异步,后续实现)"""
|
||||
# TODO: 使用httpx异步发送
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
"""析构时发送剩余事件"""
|
||||
if self._buffer:
|
||||
self.flush()
|
||||
79
sdk/trace.py
Normal file
79
sdk/trace.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""链路追踪模块"""
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional
|
||||
|
||||
# 上下文变量存储trace_id
|
||||
_trace_id_var: ContextVar[Optional[str]] = ContextVar("trace_id", default=None)
|
||||
_tenant_id_var: ContextVar[Optional[int]] = ContextVar("tenant_id", default=None)
|
||||
_user_id_var: ContextVar[Optional[int]] = ContextVar("user_id", default=None)
|
||||
|
||||
|
||||
def generate_trace_id() -> str:
|
||||
"""生成新的trace_id"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def get_trace_id() -> Optional[str]:
|
||||
"""获取当前trace_id"""
|
||||
return _trace_id_var.get()
|
||||
|
||||
|
||||
def set_trace_id(trace_id: str) -> None:
|
||||
"""设置当前trace_id"""
|
||||
_trace_id_var.set(trace_id)
|
||||
|
||||
|
||||
def get_tenant_id() -> Optional[int]:
|
||||
"""获取当前租户ID"""
|
||||
return _tenant_id_var.get()
|
||||
|
||||
|
||||
def set_tenant_id(tenant_id: int) -> None:
|
||||
"""设置当前租户ID"""
|
||||
_tenant_id_var.set(tenant_id)
|
||||
|
||||
|
||||
def get_user_id() -> Optional[int]:
|
||||
"""获取当前用户ID"""
|
||||
return _user_id_var.get()
|
||||
|
||||
|
||||
def set_user_id(user_id: int) -> None:
|
||||
"""设置当前用户ID"""
|
||||
_user_id_var.set(user_id)
|
||||
|
||||
|
||||
class TraceContext:
|
||||
"""链路追踪上下文管理器
|
||||
|
||||
使用示例:
|
||||
with TraceContext(tenant_id=1, user_id=100) as ctx:
|
||||
print(ctx.trace_id)
|
||||
# 在此上下文中的所有操作都会使用相同的trace_id
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_id: Optional[str] = None,
|
||||
tenant_id: Optional[int] = None,
|
||||
user_id: Optional[int] = None
|
||||
):
|
||||
self.trace_id = trace_id or generate_trace_id()
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self._tokens = []
|
||||
|
||||
def __enter__(self):
|
||||
self._tokens.append(_trace_id_var.set(self.trace_id))
|
||||
if self.tenant_id is not None:
|
||||
self._tokens.append(_tenant_id_var.set(self.tenant_id))
|
||||
if self.user_id is not None:
|
||||
self._tokens.append(_user_id_var.set(self.user_id))
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for token in self._tokens:
|
||||
# 重置为之前的值
|
||||
pass
|
||||
return False
|
||||
Reference in New Issue
Block a user