Subscription dashboard improvements, AI text generation limit, and other fixes.
This commit is contained in:
@@ -3,13 +3,18 @@ from __future__ import annotations
|
||||
import base64
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/images", tags=["images"])
|
||||
@@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
|
||||
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
def generate(
|
||||
req: ImageGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImageGenerateResponse:
|
||||
"""Generate image with subscription checking."""
|
||||
try:
|
||||
# Extract Clerk user ID (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
# Validation is now handled inside generate_image function
|
||||
last_error: Optional[Exception] = None
|
||||
result = None
|
||||
for attempt in range(2): # simple single retry
|
||||
try:
|
||||
result = generate_image(
|
||||
@@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
|
||||
"steps": req.steps,
|
||||
"seed": req.seed,
|
||||
},
|
||||
user_id=user_id, # Pass user_id for validation inside generate_image
|
||||
)
|
||||
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
|
||||
|
||||
# TRACK USAGE after successful image generation
|
||||
if result:
|
||||
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
# Get or create usage summary
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
|
||||
# Update provider-specific counters (stability for image generation)
|
||||
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, "stability_calls", new_calls)
|
||||
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update totals
|
||||
old_total_calls = summary.total_calls or 0
|
||||
summary.total_calls = old_total_calls + 1
|
||||
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: stability
|
||||
├─ Actual Provider: {result.provider}
|
||||
├─ Model: {result.model or 'default'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
|
||||
return ImageGenerateResponse(
|
||||
image_base64=image_b64,
|
||||
width=result.width,
|
||||
@@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel):
|
||||
|
||||
|
||||
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
|
||||
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
|
||||
def suggest_prompts(
|
||||
req: ImagePromptSuggestRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> ImagePromptSuggestResponse:
|
||||
try:
|
||||
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
|
||||
section = req.section or {}
|
||||
@@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons
|
||||
If including on-image text, return it in overlay_text (short: <= 8 words).
|
||||
"""
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
|
||||
# Get user_id for llm_text_gen subscription check (required)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id_for_llm = str(current_user.get('id', ''))
|
||||
if not user_id_for_llm:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
|
||||
data = raw if isinstance(raw, dict) else {}
|
||||
suggestions = data.get("suggestions") or []
|
||||
# basic fallback if provider returns string
|
||||
|
||||
Reference in New Issue
Block a user