149 lines
4.3 KiB
Python
149 lines
4.3 KiB
Python
"""AI统计上报客户端"""
|
||
import os
|
||
from datetime import datetime
|
||
from decimal import Decimal
|
||
from typing import Optional, List
|
||
from dataclasses import dataclass, asdict
|
||
|
||
from .trace import get_trace_id, get_tenant_id, get_user_id
|
||
|
||
|
||
@dataclass
class AICallEvent:
    """A single AI invocation event for usage/cost reporting.

    Fields left as ``None`` (``user_id``, ``trace_id``, ``event_time``) are
    filled in automatically in ``__post_init__`` from the tracing context
    and the current clock.
    """

    tenant_id: int     # owning tenant
    app_code: str      # application identifier, e.g. "011-ai-interview"
    module_code: str   # sub-module within the application
    prompt_name: str   # logical name of the prompt that was executed
    model: str         # model identifier, e.g. "gpt-4"
    input_tokens: int = 0
    output_tokens: int = 0
    cost: Decimal = Decimal("0")
    latency_ms: int = 0
    status: str = "success"          # "success" or "error"
    user_id: Optional[int] = None    # defaults to the context user id
    trace_id: Optional[str] = None   # defaults to the context trace id
    # Annotation fixed: the original declared `datetime = None`, which is
    # wrong for the default value — the field is genuinely optional.
    event_time: Optional[datetime] = None

    def __post_init__(self) -> None:
        """Backfill unset fields from the ambient context / wall clock."""
        if self.event_time is None:
            # NOTE(review): naive local time — confirm whether UTC is expected.
            self.event_time = datetime.now()
        if self.trace_id is None:
            self.trace_id = get_trace_id()
        if self.user_id is None:
            self.user_id = get_user_id()
|
||
|
||
|
||
class StatsClient:
|
||
"""统计上报客户端
|
||
|
||
使用示例:
|
||
stats = StatsClient(tenant_id=1, app_code="011-ai-interview")
|
||
|
||
# 上报AI调用
|
||
stats.report_ai_call(
|
||
module_code="interview",
|
||
prompt_name="generate_question",
|
||
model="gpt-4",
|
||
input_tokens=100,
|
||
output_tokens=200,
|
||
latency_ms=1500
|
||
)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
tenant_id: int,
|
||
app_code: str,
|
||
platform_url: Optional[str] = None,
|
||
api_key: Optional[str] = None,
|
||
local_only: bool = True
|
||
):
|
||
self.tenant_id = tenant_id
|
||
self.app_code = app_code
|
||
self.platform_url = platform_url or os.getenv("PLATFORM_URL", "")
|
||
self.api_key = api_key or os.getenv("PLATFORM_API_KEY", "")
|
||
self.local_only = local_only or not self.platform_url
|
||
|
||
# 批量上报缓冲区
|
||
self._buffer: List[AICallEvent] = []
|
||
self._buffer_size = 10 # 达到此数量时自动上报
|
||
|
||
def report_ai_call(
|
||
self,
|
||
module_code: str,
|
||
prompt_name: str,
|
||
model: str,
|
||
input_tokens: int = 0,
|
||
output_tokens: int = 0,
|
||
cost: Decimal = Decimal("0"),
|
||
latency_ms: int = 0,
|
||
status: str = "success",
|
||
user_id: Optional[int] = None,
|
||
flush: bool = False
|
||
) -> AICallEvent:
|
||
"""上报AI调用事件
|
||
|
||
Args:
|
||
module_code: 模块编码
|
||
prompt_name: Prompt名称
|
||
model: 模型名称
|
||
input_tokens: 输入token数
|
||
output_tokens: 输出token数
|
||
cost: 成本
|
||
latency_ms: 延迟毫秒
|
||
status: 状态 (success/error)
|
||
user_id: 用户ID(可选,默认从上下文获取)
|
||
flush: 是否立即发送
|
||
|
||
Returns:
|
||
创建的事件对象
|
||
"""
|
||
event = AICallEvent(
|
||
tenant_id=self.tenant_id,
|
||
app_code=self.app_code,
|
||
module_code=module_code,
|
||
prompt_name=prompt_name,
|
||
model=model,
|
||
input_tokens=input_tokens,
|
||
output_tokens=output_tokens,
|
||
cost=cost,
|
||
latency_ms=latency_ms,
|
||
status=status,
|
||
user_id=user_id
|
||
)
|
||
|
||
self._buffer.append(event)
|
||
|
||
if flush or len(self._buffer) >= self._buffer_size:
|
||
self.flush()
|
||
|
||
return event
|
||
|
||
def flush(self):
|
||
"""发送缓冲区中的所有事件"""
|
||
if not self._buffer:
|
||
return
|
||
|
||
events = self._buffer.copy()
|
||
self._buffer.clear()
|
||
|
||
if self.local_only:
|
||
# 本地模式:仅打印
|
||
for event in events:
|
||
print(f"[STATS] {event.app_code}/{event.module_code}: "
|
||
f"{event.prompt_name} - {event.input_tokens}+{event.output_tokens} tokens")
|
||
else:
|
||
# 远程上报
|
||
self._send_to_platform(events)
|
||
|
||
def _send_to_platform(self, events: List[AICallEvent]):
|
||
"""发送事件到平台(异步,后续实现)"""
|
||
# TODO: 使用httpx异步发送
|
||
pass
|
||
|
||
def __del__(self):
|
||
"""析构时发送剩余事件"""
|
||
if self._buffer:
|
||
self.flush()
|