Subscription dashboard improvements, AI text generation limit, and other fixes.

This commit is contained in:
ajaysi
2025-11-01 18:01:14 +05:30
parent cdb41aec1b
commit de4328175d
64 changed files with 5809 additions and 444 deletions

View File

@@ -3,13 +3,18 @@ from __future__ import annotations
import base64
import os
from typing import Optional, Dict, Any
from datetime import datetime
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field
from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from models.subscription_models import APIProvider, UsageSummary
router = APIRouter(prefix="/api/images", tags=["images"])
@@ -39,9 +44,23 @@ class ImageGenerateResponse(BaseModel):
@router.post("/generate", response_model=ImageGenerateResponse)
def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
def generate(
req: ImageGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImageGenerateResponse:
"""Generate image with subscription checking."""
try:
# Extract Clerk user ID (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
# Validation is now handled inside generate_image function
last_error: Optional[Exception] = None
result = None
for attempt in range(2): # simple single retry
try:
result = generate_image(
@@ -56,8 +75,79 @@ def generate(req: ImageGenerateRequest) -> ImageGenerateResponse:
"steps": req.steps,
"seed": req.seed,
},
user_id=user_id, # Pass user_id for validation inside generate_image
)
image_b64 = base64.b64encode(result.image_bytes).decode("utf-8")
# TRACK USAGE after successful image generation
if result:
logger.info(f"[images.generate] ✅ Image generation successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
# Get or create usage summary
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[images.generate] Looking for usage summary: user_id={user_id}, period={current_period}")
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.info(f"[images.generate] Creating new usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# Get "before" state for unified log
current_calls_before = getattr(summary, "stability_calls", 0) or 0
# Update provider-specific counters (stability for image generation)
# Note: All image generation goes through STABILITY provider enum regardless of actual provider
new_calls = current_calls_before + 1
setattr(summary, "stability_calls", new_calls)
logger.debug(f"[images.generate] Updated stability_calls: {current_calls_before} -> {new_calls}")
# Update totals
old_total_calls = summary.total_calls or 0
summary.total_calls = old_total_calls + 1
logger.debug(f"[images.generate] Updated totals: calls {old_total_calls} -> {summary.total_calls}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get("stability_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: stability
├─ Actual Provider: {result.provider}
├─ Model: {result.model or 'default'}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[images.generate] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[images.generate] ❌ Failed to track usage: {usage_error}", exc_info=True)
return ImageGenerateResponse(
image_base64=image_b64,
width=result.width,
@@ -106,7 +196,10 @@ class ImagePromptSuggestResponse(BaseModel):
@router.post("/suggest-prompts", response_model=ImagePromptSuggestResponse)
def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestResponse:
def suggest_prompts(
req: ImagePromptSuggestRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> ImagePromptSuggestResponse:
try:
provider = (req.provider or ("gemini" if (os.getenv("GPT_PROVIDER") or "").lower().startswith("gemini") else "huggingface")).lower()
section = req.section or {}
@@ -203,7 +296,15 @@ def suggest_prompts(req: ImagePromptSuggestRequest) -> ImagePromptSuggestRespons
If including on-image text, return it in overlay_text (short: <= 8 words).
"""
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema)
# Get user_id for llm_text_gen subscription check (required)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id_for_llm = str(current_user.get('id', ''))
if not user_id_for_llm:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
raw = llm_text_gen(prompt=prompt, system_prompt=system, json_struct=schema, user_id=user_id_for_llm)
data = raw if isinstance(raw, dict) else {}
suggestions = data.get("suggestions") or []
# basic fallback if provider returns string