Files
000-platform/sdk/stats_client.py
111 daa8125c58
All checks were successful
continuous-integration/drone/push Build is passing
Initial commit: 000-platform project skeleton
2026-01-23 14:32:09 +08:00

149 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI统计上报客户端"""
import os
from datetime import datetime
from decimal import Decimal
from typing import Optional, List
from dataclasses import dataclass, asdict
from .trace import get_trace_id, get_tenant_id, get_user_id
@dataclass
class AICallEvent:
    """A single AI model invocation, recorded as a statistics event.

    ``event_time``, ``trace_id`` and ``user_id`` may be omitted. When left as
    ``None`` they are filled in by ``__post_init__``:

    - ``event_time`` from the current clock;
    - ``trace_id`` from ``get_trace_id()``;
    - ``user_id`` from ``get_user_id()``.

    Both getters are imported from ``.trace``.
    """

    tenant_id: int
    app_code: str
    module_code: str
    prompt_name: str
    model: str
    input_tokens: int = 0
    output_tokens: int = 0
    # Decimal is immutable, so sharing this default instance is safe.
    cost: Decimal = Decimal("0")
    latency_ms: int = 0
    # Outcome of the call, e.g. "success" or "error".
    status: str = "success"
    user_id: Optional[int] = None
    trace_id: Optional[str] = None
    # Declared Optional because None is the "not provided" sentinel.
    # __post_init__ replaces it with the current time.
    event_time: Optional[datetime] = None

    def __post_init__(self):
        """Fill any unset time/trace/user fields after construction."""
        if self.event_time is None:
            # NOTE(review): this is a naive local-time timestamp. Confirm
            # whether downstream consumers expect timezone-aware/UTC values.
            self.event_time = datetime.now()
        if self.trace_id is None:
            self.trace_id = get_trace_id()
        if self.user_id is None:
            self.user_id = get_user_id()
class StatsClient:
    """Client that buffers and reports AI-call statistics.

    Events are held in an in-memory buffer and flushed in any of these cases:

    - the buffer reaches ``buffer_size`` events;
    - ``flush()`` is called, or ``flush=True`` is passed to
      ``report_ai_call``;
    - a ``with`` block using the client exits;
    - best-effort, when the object is garbage-collected.

    Usage example::

        stats = StatsClient(tenant_id=1, app_code="011-ai-interview")
        # Report an AI call
        stats.report_ai_call(
            module_code="interview",
            prompt_name="generate_question",
            model="gpt-4",
            input_tokens=100,
            output_tokens=200,
            latency_ms=1500
        )

        # Or flush deterministically on scope exit:
        with StatsClient(tenant_id=1, app_code="011-ai-interview") as stats:
            stats.report_ai_call(...)
    """

    def __init__(
        self,
        tenant_id: int,
        app_code: str,
        platform_url: Optional[str] = None,
        api_key: Optional[str] = None,
        local_only: bool = True,
        buffer_size: int = 10,
    ):
        """Create a client bound to one tenant and application.

        Args:
            tenant_id: Tenant ID stamped on every reported event.
            app_code: Application code stamped on every reported event.
            platform_url: Platform base URL. Falls back to the
                ``PLATFORM_URL`` environment variable (empty string if unset).
            api_key: Platform API key. Falls back to the
                ``PLATFORM_API_KEY`` environment variable (empty string if
                unset).
            local_only: If true, flushed events are only printed locally.
                Local mode is also forced when no platform URL is available.
            buffer_size: Number of buffered events that triggers an automatic
                flush. A value of 1 or less flushes on every report.
        """
        self.tenant_id = tenant_id
        self.app_code = app_code
        self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
        self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
        # Without a platform URL there is nowhere to send events, so fall back
        # to local printing regardless of what the caller asked for.
        self.local_only = local_only or not self.platform_url
        # Buffer for batched reporting.
        self._buffer: List[AICallEvent] = []
        # Auto-flush once the buffer holds this many events.
        self._buffer_size = buffer_size

    def report_ai_call(
        self,
        module_code: str,
        prompt_name: str,
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
        cost: Decimal = Decimal("0"),
        latency_ms: int = 0,
        status: str = "success",
        user_id: Optional[int] = None,
        flush: bool = False
    ) -> AICallEvent:
        """Record one AI call event, flushing the buffer if needed.

        Args:
            module_code: Module code.
            prompt_name: Prompt name.
            model: Model name.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            cost: Cost of the call.
            latency_ms: Latency in milliseconds.
            status: Status ("success"/"error").
            user_id: User ID. Optional; when None, ``AICallEvent`` resolves it
                from the request context.
            flush: If true, send the buffer immediately after adding this event.

        Returns:
            The created event object.
        """
        event = AICallEvent(
            tenant_id=self.tenant_id,
            app_code=self.app_code,
            module_code=module_code,
            prompt_name=prompt_name,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost=cost,
            latency_ms=latency_ms,
            status=status,
            user_id=user_id
        )
        self._buffer.append(event)
        if flush or len(self._buffer) >= self._buffer_size:
            self.flush()
        return event

    def flush(self):
        """Send every buffered event and leave the buffer empty.

        The buffer is swapped out before sending, so events reported during
        the send go into a fresh buffer. If sending raises, the swapped-out
        events are not re-queued.
        """
        if not self._buffer:
            return
        events, self._buffer = self._buffer, []

        if self.local_only:
            # Local mode: print only.
            for event in events:
                print(f"[STATS] {event.app_code}/{event.module_code}: "
                      f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens")
        else:
            # Remote reporting.
            self._send_to_platform(events)

    def _send_to_platform(self, events: List[AICallEvent]):
        """Send events to the platform (not yet implemented).

        NOTE(review): this is currently a no-op. In remote mode
        (``local_only`` false), flushed events are therefore discarded.
        """
        # TODO: send asynchronously via httpx.
        pass

    def __enter__(self) -> "StatsClient":
        """Return self so the client can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Flush remaining events on scope exit. Exceptions still propagate."""
        self.flush()

    def __del__(self):
        """Best-effort flush of any remaining events on garbage collection."""
        # getattr guards against an instance whose __init__ never completed,
        # e.g. one created via __new__ or whose subclass __init__ raised early.
        if getattr(self, "_buffer", None):
            self.flush()