"""计费逻辑与 usage_records 持久化。""" from __future__ import annotations import os from typing import Any import storage # 价格单位:USD / 1M tokens DEFAULT_INPUT_PRICE_PER_1M = float( os.environ.get("OPENAI_DEFAULT_INPUT_PRICE_PER_1M", "0.15") ) DEFAULT_OUTPUT_PRICE_PER_1M = float( os.environ.get("OPENAI_DEFAULT_OUTPUT_PRICE_PER_1M", "0.60") ) MODEL_PRICES_PER_1M: dict[str, tuple[float, float]] = { "gpt-4o-mini": (0.15, 0.60), "gpt-4.1-mini": (0.40, 1.60), "gpt-4.1": (2.00, 8.00), "gpt-4o": (2.50, 10.00), } def calc_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float: """计算一次请求的美元成本。""" model_rates = MODEL_PRICES_PER_1M.get(model, None) if model_rates is None: in_rate = DEFAULT_INPUT_PRICE_PER_1M out_rate = DEFAULT_OUTPUT_PRICE_PER_1M else: in_rate, out_rate = model_rates cost = (prompt_tokens * in_rate + completion_tokens * out_rate) / 1_000_000.0 return round(cost, 8) def record_usage( *, username: str, job_id: str | None, model: str, prompt_tokens: int, completion_tokens: int, total_tokens: int, ) -> None: """记录一次模型调用的 token 使用情况与成本。""" cost_usd = calc_cost_usd(model, prompt_tokens, completion_tokens) storage.db_execute( """ INSERT INTO usage_records( username, job_id, model, prompt_tokens, completion_tokens, total_tokens, cost_usd, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, ( username, job_id, model, prompt_tokens, completion_tokens, total_tokens, cost_usd, storage.now_iso(), ), ) def get_billing_summary(username: str) -> dict[str, Any]: """汇总某个用户的累计账单信息。""" row = storage.db_fetchone( """ SELECT COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens, COALESCE(SUM(completion_tokens), 0) AS completion_tokens, COALESCE(SUM(total_tokens), 0) AS total_tokens, COALESCE(SUM(cost_usd), 0) AS total_cost_usd FROM usage_records WHERE username = ? """, (username,), ) if row is None: return { "username": username, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "total_cost_usd": 0.0, } return { "username": username, "prompt_tokens": row["prompt_tokens"], "completion_tokens": row["completion_tokens"], "total_tokens": row["total_tokens"], "total_cost_usd": round(float(row["total_cost_usd"]), 8), } def get_billing_records(username: str, limit: int) -> list[dict[str, Any]]: """获取用户近期账单记录。""" rows = storage.db_fetchall( """ SELECT id, username, job_id, model, prompt_tokens, completion_tokens, total_tokens, cost_usd, created_at FROM usage_records WHERE username = ? ORDER BY created_at DESC LIMIT ? """, (username, limit), ) return [dict(row) for row in rows]