PDF / src /billing.py
BirkhoffLee's picture
refactor: 资源拆分
d3a7520 unverified
"""计费逻辑与 usage_records 持久化。"""
from __future__ import annotations

import os
import re
from typing import Any

import storage
# All prices below are expressed in USD per 1M tokens.


def _price_from_env(var_name: str, fallback: str) -> float:
    """Read a per-1M-token price from environment variable *var_name*, else *fallback*."""
    return float(os.environ.get(var_name, fallback))


# Fallback rates for any model without an entry in MODEL_PRICES_PER_1M;
# each one can be overridden through its environment variable.
DEFAULT_INPUT_PRICE_PER_1M = _price_from_env("OPENAI_DEFAULT_INPUT_PRICE_PER_1M", "0.15")
DEFAULT_OUTPUT_PRICE_PER_1M = _price_from_env("OPENAI_DEFAULT_OUTPUT_PRICE_PER_1M", "0.60")

# Model name -> (input rate, output rate), both in USD per 1M tokens.
MODEL_PRICES_PER_1M: dict[str, tuple[float, float]] = {
    "gpt-4o-mini": (0.15, 0.60),
    "gpt-4.1-mini": (0.40, 1.60),
    "gpt-4.1": (2.00, 8.00),
    "gpt-4o": (2.50, 10.00),
}
# Dated snapshot model names such as "gpt-4o-mini-2024-07-18": captures the
# base model name in front of a trailing "-YYYY-MM-DD".
_SNAPSHOT_MODEL_RE = re.compile(r"(?P<base>.+)-\d{4}-\d{2}-\d{2}")


def _resolve_model_rates(model: str) -> tuple[float, float]:
    """Return the ``(input, output)`` USD-per-1M-token rates for *model*.

    Lookup order:

    1. an exact entry in ``MODEL_PRICES_PER_1M``;
    2. the entry for the base name after a trailing dated snapshot suffix is
       removed (``"gpt-4o-2024-08-06"`` -> ``"gpt-4o"``), so pinned snapshots
       are not silently billed at the fallback rates;
    3. ``DEFAULT_INPUT_PRICE_PER_1M`` / ``DEFAULT_OUTPUT_PRICE_PER_1M``.

    If a snapshot is priced differently from its base alias, give it its own
    exact entry in ``MODEL_PRICES_PER_1M``. Exact entries take precedence.
    """
    rates = MODEL_PRICES_PER_1M.get(model)
    if rates is None:
        snapshot = _SNAPSHOT_MODEL_RE.fullmatch(model)
        if snapshot is not None:
            rates = MODEL_PRICES_PER_1M.get(snapshot.group("base"))
    if rates is None:
        return DEFAULT_INPUT_PRICE_PER_1M, DEFAULT_OUTPUT_PRICE_PER_1M
    return rates


def calc_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Compute the USD cost of a single model request.

    Args:
        model: Model name. Dated snapshot names resolve to their base model's
            price, and unknown models use the default rates.
        prompt_tokens: Number of input (prompt) tokens.
        completion_tokens: Number of output (completion) tokens.

    Returns:
        The cost in USD, rounded to 8 decimal places.
    """
    in_rate, out_rate = _resolve_model_rates(model)
    # Rates are per 1M tokens, so scale the weighted token sum down by 1e6.
    cost = (prompt_tokens * in_rate + completion_tokens * out_rate) / 1_000_000.0
    return round(cost, 8)
def record_usage(
    *,
    username: str,
    job_id: str | None,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
) -> None:
    """Persist one model call's token counts and computed USD cost.

    Inserts a single row into ``usage_records``. The cost comes from
    ``calc_cost_usd(model, prompt_tokens, completion_tokens)`` and the
    ``created_at`` timestamp from ``storage.now_iso()``. ``total_tokens`` is
    stored as given; it is not recomputed from the other two counts.
    """
    insert_sql = """
        INSERT INTO usage_records(
            username, job_id, model,
            prompt_tokens, completion_tokens, total_tokens,
            cost_usd, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    """
    params = (
        username,
        job_id,
        model,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        calc_cost_usd(model, prompt_tokens, completion_tokens),
        storage.now_iso(),
    )
    storage.db_execute(insert_sql, params)
def get_billing_summary(username: str) -> dict[str, Any]:
    """Aggregate a user's cumulative token usage and spend.

    Args:
        username: The user whose ``usage_records`` rows are summed.

    Returns:
        A dict with ``username``, summed ``prompt_tokens``,
        ``completion_tokens`` and ``total_tokens``, and ``total_cost_usd``
        rounded to 8 decimal places. Every total is zero when the query
        returns no row.
    """
    summary_sql = """
        SELECT
            COALESCE(SUM(prompt_tokens), 0) AS prompt_tokens,
            COALESCE(SUM(completion_tokens), 0) AS completion_tokens,
            COALESCE(SUM(total_tokens), 0) AS total_tokens,
            COALESCE(SUM(cost_usd), 0) AS total_cost_usd
        FROM usage_records
        WHERE username = ?
    """
    totals = storage.db_fetchone(summary_sql, (username,))
    if totals is not None:
        return {
            "username": username,
            "prompt_tokens": totals["prompt_tokens"],
            "completion_tokens": totals["completion_tokens"],
            "total_tokens": totals["total_tokens"],
            # COALESCE may yield the integer 0, so normalise to float first.
            "total_cost_usd": round(float(totals["total_cost_usd"]), 8),
        }
    return {
        "username": username,
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
        "total_cost_usd": 0.0,
    }
def get_billing_records(username: str, limit: int) -> list[dict[str, Any]]:
    """Fetch a user's most recent usage records, newest first.

    Args:
        username: The user whose ``usage_records`` rows are returned.
        limit: Maximum number of rows to return (passed straight to SQL
            ``LIMIT``). NOTE(review): if the backend is SQLite, a negative
            LIMIT means "no limit". Validate at the caller if this value can
            come from user input.

    Returns:
        One dict per row, with the columns ``id``, ``username``, ``job_id``,
        ``model``, the three token counts, ``cost_usd`` and ``created_at``.
    """
    rows = storage.db_fetchall(
        """
        SELECT
            id, username, job_id, model,
            prompt_tokens, completion_tokens, total_tokens,
            cost_usd, created_at
        FROM usage_records
        WHERE username = ?
        ORDER BY created_at DESC, id DESC
        LIMIT ?
        """,
        # `id DESC` breaks ties between rows sharing a created_at value, so
        # ordering and the LIMIT cut-off are deterministic.
        (username, limit),
    )
    return [dict(row) for row in rows]