fix: WYSIWYG editor, content generation, and writing assistant bug fixes

- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField
- Fix blog title not truncating: add min-w-0 for flex item overflow
- Fix outline generation 500: escape curly braces in f-string prompt template
- Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager
- Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient
- Fix hallucination detector 404: explicitly include router in main.py and app.py
- Fix missing error_data in task failure responses
- Hide CopilotKit web inspector button
- Remove hardcoded fallback suggestions from SmartTypingAssist
- Fix stale closure refs in SmartTypingAssist handleTypingChange
- Add two-column editor layout, stats bar, section hover menu
- Various subscription, billing, and research module improvements
This commit is contained in:
ajaysi
2026-05-14 09:11:30 +05:30
parent 7385100017
commit 928c2f20aa
113 changed files with 4344 additions and 10064 deletions

View File

@@ -7,12 +7,11 @@ The onboarding endpoints are re-exported from a stable module
import os
# Check podcast mode early
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
# In podcast mode, don't import heavy onboarding endpoints
# In feature-only modes, don't import heavy onboarding endpoints
# They trigger heavy dependencies (exa_py, etc.)
if _is_podcast:
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
if not _is_full_mode:
__all__ = []
else:
from .onboarding_endpoints import (

View File

@@ -1195,3 +1195,68 @@ async def generate_introductions(
except Exception as e:
logger.error(f"Failed to generate introductions: {e}")
raise HTTPException(status_code=500, detail=str(e))
# ---------------------------
# Save Complete Blog Asset
# ---------------------------
class SaveCompleteBlogAssetRequest(BaseModel):
title: str
content: str
seo_title: Optional[str] = None
meta_description: Optional[str] = None
focus_keyword: Optional[str] = None
tags: List[str] = Field(default_factory=list)
categories: List[str] = Field(default_factory=list)
@router.post("/save-complete-asset")
async def save_complete_blog_asset(
request: SaveCompleteBlogAssetRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> Dict[str, Any]:
"""Save the complete blog content as a single asset in the asset library."""
try:
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get('id', ''))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
full_content = f"# {request.title}\n\n{request.content}"
asset_id = save_and_track_text_content(
db=db,
user_id=user_id,
content=full_content,
source_module="blog_writer",
title=f"Published Blog: {request.title[:60]}",
description=request.meta_description or f"Complete published blog post: {request.title}",
prompt=f"SEO Title: {request.seo_title or request.title}\nFocus Keyword: {request.focus_keyword or ''}",
tags=["blog", "published"] + [t for t in (request.tags or []) if t],
asset_metadata={
"status": "published",
"focus_keyword": request.focus_keyword,
"categories": request.categories,
"word_count": len(full_content.split()),
},
subdirectory="published",
file_extension=".md"
)
if asset_id:
logger.info(f"✅ Complete blog asset saved to library: ID={asset_id}")
return {"success": True, "asset_id": asset_id}
else:
logger.warning("save_and_track_text_content returned None for published blog")
return {"success": False, "error": "Failed to save blog asset"}
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to save complete blog asset: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -13,7 +13,7 @@ from typing import Any, Dict, List
from fastapi import HTTPException
from loguru import logger
from sqlalchemy.orm import Session
from services.database import SessionLocal, get_session_for_user
from services.database import get_session_for_user
from models.blog_models import (
BlogResearchRequest,
@@ -264,7 +264,7 @@ class TaskManager:
raise ValueError("Global target words exceed 1000; medium generation not allowed")
# Create a sync session for asset saving
db_session = SessionLocal()
db_session = get_session_for_user(user_id)
try:
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
request,
@@ -326,6 +326,7 @@ class TaskManager:
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
self.task_storage[task_id]["status"] = "failed"
self.task_storage[task_id]["error"] = str(e)
self.task_storage[task_id]["error_data"] = {"error_message": str(e), "error_type": type(e).__name__}
# Global task manager instance

View File

@@ -202,6 +202,26 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
interests = ", ".join(audience_dna.get("interests", []))
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
# Preflight subscription check for Exa
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.EXA,
tokens_requested=0,
actual_provider_name="exa",
)
if not can_proceed:
raise HTTPException(status_code=429, detail={
'error': message, 'message': message,
'provider': 'exa', 'usage_info': usage_info or {}
})
logger.info(f"[Podcast Research] Preflight check passed for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.warning(f"[Podcast Research] Preflight check failed: {e}")
try:
# 1. RUN EXA SEARCH
logger.warning(f"[Podcast Research] Calling Exa search with topic: {request.topic[:100]}...")

View File

@@ -9,10 +9,13 @@ from typing import Dict, Any, List, Optional
from pydantic import BaseModel
from loguru import logger
from types import SimpleNamespace
from sqlalchemy import text
from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user
from services.research.tavily_service import TavilyService
from services.blog_writer.research.exa_provider import ExaResearchProvider
from services.subscription import PricingService
from models.subscription_models import APIProvider
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
@@ -29,6 +32,75 @@ EXA_CATEGORY_MAP = {
}
def _preflight_check(user_id: str, provider: APIProvider, provider_name: str):
"""Check subscription limits before making a research API call."""
from services.database import get_session_for_user
db = get_session_for_user(user_id)
if not db:
return
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=0,
actual_provider_name=provider_name,
)
if not can_proceed:
raise HTTPException(status_code=429, detail={
'error': message, 'message': message,
'provider': provider_name, 'usage_info': usage_info or {}
})
except HTTPException:
raise
except Exception as e:
logger.warning(f"[CategoryResearch] Preflight check failed for {provider_name}: {e}")
finally:
db.close()
def _track_research_usage(user_id: str, provider_name: str, cost: float, calls_column: str, cost_column: str):
"""Track research API usage after successful call."""
from services.database import get_session_for_user
db = get_session_for_user(user_id)
if not db:
logger.warning(f"[CategoryResearch] Could not get DB session for user {user_id}")
return
try:
pricing_service = PricingService(db)
current_period = pricing_service.get_current_billing_period(user_id)
update_query = text(f"""
UPDATE usage_summaries
SET {calls_column} = COALESCE({calls_column}, 0) + 1,
{cost_column} = COALESCE({cost_column}, 0) + :cost,
total_calls = COALESCE(total_calls, 0) + 1,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db.execute(update_query, {
'cost': cost,
'user_id': user_id,
'period': current_period,
})
db.commit()
logger.info(f"[CategoryResearch] Tracked {provider_name} usage: user={user_id}, cost=${cost}")
# Clear dashboard cache so header stats update immediately
try:
from api.subscription.cache import clear_dashboard_cache
clear_dashboard_cache(user_id)
except Exception as cache_err:
logger.warning(f"[CategoryResearch] Failed to clear dashboard cache: {cache_err}")
except Exception as e:
logger.error(f"[CategoryResearch] Failed to track {provider_name} usage: {e}")
db.rollback()
finally:
db.close()
class CategoryResearchRequest(BaseModel):
category: str
keyword: Optional[str] = None
@@ -80,9 +152,12 @@ def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopi
return topics
async def _search_tavily(category: str, keyword: str, max_results: int) -> CategoryResearchResponse:
async def _search_tavily(category: str, keyword: str, max_results: int, user_id: str) -> CategoryResearchResponse:
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
# Preflight subscription check
_preflight_check(user_id, APIProvider.TAVILY, "tavily")
try:
tavily = TavilyService()
result = await tavily.search(
@@ -102,6 +177,10 @@ async def _search_tavily(category: str, keyword: str, max_results: int) -> Categ
topics = _normalize_tavily_results(result.get("results", []))
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
# Track usage
cost = 0.001 # basic search = 1 credit
_track_research_usage(user_id, "tavily", cost, "tavily_calls", "tavily_cost")
return CategoryResearchResponse(
success=True,
category=category,
@@ -117,7 +196,7 @@ async def _search_tavily(category: str, keyword: str, max_results: int) -> Categ
raise HTTPException(status_code=500, detail=str(e))
async def _search_exa(category: str, keyword: str, max_results: int, website_url: Optional[str] = None) -> CategoryResearchResponse:
async def _search_exa(category: str, keyword: str, max_results: int, user_id: str, website_url: Optional[str] = None) -> CategoryResearchResponse:
exa_category = EXA_CATEGORY_MAP.get(category, category)
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
@@ -133,6 +212,9 @@ async def _search_exa(category: str, keyword: str, max_results: int, website_url
from exa_py import Exa
exa = Exa(exa_api_key)
logger.info(f"[CategoryResearch] Exa client initialized")
# Preflight subscription check
_preflight_check(user_id, APIProvider.EXA, "exa")
# Build search parameters
search_params = {
@@ -189,6 +271,10 @@ async def _search_exa(category: str, keyword: str, max_results: int, website_url
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
# Track usage
cost = 0.005 # Default Exa cost for 1-25 results
_track_research_usage(user_id, "exa", cost, "exa_calls", "exa_cost")
return CategoryResearchResponse(
success=True,
category=category,
@@ -218,6 +304,7 @@ async def research_by_category(
- news, finance: Uses Tavily
- research-paper, personal-site: Uses Exa
"""
user_id = require_authenticated_user(current_user)
category = request.category.lower()
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
@@ -241,9 +328,9 @@ async def research_by_category(
try:
if provider == "tavily":
return await _search_tavily(category, keyword, max_results)
return await _search_tavily(category, keyword, max_results, user_id)
elif provider == "exa":
return await _search_exa(category, keyword, max_results, website_url)
return await _search_exa(category, keyword, max_results, user_id, website_url)
else:
raise HTTPException(status_code=500, detail="Unknown provider")
except Exception as e:

View File

@@ -4,6 +4,7 @@ Podcast Trends Handler
Endpoints for fetching Google Trends data relevant to podcast topics.
"""
import asyncio
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any, List, Optional
from pydantic import BaseModel, Field
@@ -13,6 +14,25 @@ from middleware.auth_middleware import get_current_user
router = APIRouter(prefix="/trends", tags=["Podcast Trends"])
# Module-level shared instance (singleton pattern)
_trends_service_instance = None
_trends_service_lock = None
def get_trends_service():
"""Get or create shared GoogleTrendsService instance."""
global _trends_service_instance, _trends_service_lock
if _trends_service_instance is None:
try:
from services.research.trends import GoogleTrendsService
_trends_service_instance = GoogleTrendsService()
_trends_service_lock = asyncio.Lock()
logger.info("[Podcast Trends] Created shared GoogleTrendsService instance")
except (ImportError, RuntimeError) as e:
logger.error(f"[Podcast Trends] Failed to create GoogleTrendsService: {e}")
raise
return _trends_service_instance
class PodcastTrendsRequest(BaseModel):
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
@@ -38,7 +58,7 @@ async def get_podcast_trends(
raise HTTPException(status_code=401, detail="User ID not found")
try:
from services.research.trends import GoogleTrendsService
service = get_trends_service()
except (ImportError, RuntimeError) as e:
logger.error(f"[Podcast Trends] GoogleTrendsService unavailable: {e}")
raise HTTPException(
@@ -47,11 +67,10 @@ async def get_podcast_trends(
)
try:
service = GoogleTrendsService()
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
gprop = gprop_map.get(request.source, "")
result = await service.analyze_trends(
keywords=request.keywords,
timeframe=request.timeframe,
@@ -73,7 +92,15 @@ async def get_podcast_trends(
# Return error if: has error OR no data (meaning blocked/empty)
if has_error and not has_data:
error_msg = result.get("error", "")
cooldown_active = result.get("cooldown_active", False)
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
# Provide helpful message during cooldown
if cooldown_active:
return PodcastTrendsResponse(
success=False,
data=result,
error="Google is rate limiting requests. Try using 'Get Trending Topics' instead, or wait 30 minutes."
)
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
# Even if no error but empty data - return error

View File

@@ -12,7 +12,7 @@ import sqlite3
from services.database import get_db
from services.subscription import UsageTrackingService, PricingService
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
from models.subscription_models import UsageAlert
from models.subscription_models import UsageAlert, UserSubscription
from middleware.auth_middleware import get_current_user
from ..dependencies import verify_user_access
from ..cache import get_cached_dashboard, set_cached_dashboard
@@ -27,7 +27,9 @@ async def get_dashboard_data(
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""Get comprehensive dashboard data for usage monitoring."""
"""Get comprehensive dashboard data for usage monitoring.
Returns all-time total + current period usage by default.
When billing_period is specified, returns that period's data only."""
verify_user_access(user_id, current_user)
@@ -35,17 +37,23 @@ async def get_dashboard_data(
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
# Check cache first (skip if billing_period is specified)
if not billing_period:
cached_data = get_cached_dashboard(user_id)
if cached_data:
return cached_data
# Check cache first (only for default view, skip when a specific period is requested)
cached_data = get_cached_dashboard(user_id)
if cached_data and not billing_period:
return cached_data
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
# Get current usage stats (for the requested period)
current_usage = usage_service.get_user_usage_stats(user_id, billing_period)
# When a specific billing_period is requested, show only that period's data
# Otherwise show all-time total + current period usage
if billing_period:
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
total_usage = period_usage
current_period_usage = period_usage
else:
total_usage = usage_service.get_user_usage_stats(user_id, None)
current_period_usage = usage_service.get_current_period_usage(user_id)
# Get usage trends (last 6 months)
trends = usage_service.get_usage_trends(user_id, 6)
@@ -76,13 +84,44 @@ async def get_dashboard_data(
]
# Calculate cost projections (only relevant for current month)
current_cost = current_usage.get('total_cost', 0)
current_cost = total_usage.get('total_cost', 0)
days_in_period = 30
current_day = datetime.now().day
# Only project costs if viewing current month
is_current_month = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
if is_current_month:
# Determine if viewing current period based on subscription, not calendar
subscription = db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
# Use subscription's billing period or fallback to calendar
if subscription and subscription.current_period_start:
sub_period = subscription.current_period_start.strftime("%Y-%m")
calendar_period = datetime.now().strftime("%Y-%m")
# Check if we have data for subscription period or calendar period
from models.subscription_models import UsageSummary
sub_data_exists = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == sub_period
).first()
# Determine which period to use for "current"
if sub_data_exists:
effective_period = sub_period
else:
# Check calendar period for backward compatibility
cal_data_exists = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == calendar_period
).first()
effective_period = calendar_period if cal_data_exists else sub_period
is_current_period = not billing_period or billing_period == effective_period
else:
is_current_period = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
if is_current_period:
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
else:
projected_cost = current_cost # For past months, projected is actual
@@ -90,7 +129,8 @@ async def get_dashboard_data(
response_payload = {
"success": True,
"data": {
"current_usage": current_usage,
"total_usage": total_usage,
"current_period_usage": current_period_usage,
"trends": trends,
"limits": limits,
"alerts": alerts_data,
@@ -100,9 +140,9 @@ async def get_dashboard_data(
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
},
"summary": {
"total_api_calls_this_month": current_usage.get('total_calls', 0),
"total_cost_this_month": current_usage.get('total_cost', 0),
"usage_status": current_usage.get('usage_status', 'active'),
"total_api_calls_this_month": total_usage.get('total_calls', 0),
"total_cost_this_month": total_usage.get('total_cost', 0),
"usage_status": total_usage.get('usage_status', 'active'),
"unread_alerts": len(alerts_data)
}
}
@@ -131,7 +171,13 @@ async def get_dashboard_data(
usage_service = UsageTrackingService(db)
pricing_service = PricingService(db)
current_usage = usage_service.get_user_usage_stats(user_id)
if billing_period:
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
total_usage = period_usage
current_period_usage = period_usage
else:
total_usage = usage_service.get_user_usage_stats(user_id, None)
current_period_usage = usage_service.get_current_period_usage(user_id)
trends = usage_service.get_usage_trends(user_id, 6)
limits = pricing_service.get_user_limits(user_id)
@@ -152,7 +198,7 @@ async def get_dashboard_data(
for alert in alerts
]
current_cost = current_usage.get('total_cost', 0)
current_cost = total_usage.get('total_cost', 0)
days_in_period = 30
current_day = datetime.now().day
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
@@ -160,7 +206,8 @@ async def get_dashboard_data(
response_payload = {
"success": True,
"data": {
"current_usage": current_usage,
"total_usage": total_usage,
"current_period_usage": current_period_usage,
"trends": trends,
"limits": limits,
"alerts": alerts_data,
@@ -170,16 +217,17 @@ async def get_dashboard_data(
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
},
"summary": {
"total_api_calls_this_month": current_usage.get('total_calls', 0),
"total_cost_this_month": current_usage.get('total_cost', 0),
"usage_status": current_usage.get('usage_status', 'active'),
"total_api_calls_this_month": total_usage.get('total_calls', 0),
"total_cost_this_month": total_usage.get('total_cost', 0),
"usage_status": total_usage.get('usage_status', 'active'),
"unread_alerts": len(alerts_data)
}
}
}
# Cache the response after successful retry
set_cached_dashboard(user_id, response_payload)
# Cache the response after successful retry (only for default view)
if not billing_period:
set_cached_dashboard(user_id, response_payload)
return response_payload
except Exception as retry_err:
logger.error(f"Schema fix and retry failed: {retry_err}")
@@ -187,7 +235,8 @@ async def get_dashboard_data(
"success": False,
"error": str(retry_err),
"data": {
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
"trends": [],
"limits": {"limits": {"monthly_cost": 0}},
"alerts": [],
@@ -201,7 +250,8 @@ async def get_dashboard_data(
"success": False,
"error": str(e),
"data": {
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
"trends": [],
"limits": {"limits": {"monthly_cost": 0}},
"alerts": [],

View File

@@ -14,13 +14,21 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
"""
Format subscription plan limits for API response.
Includes _zero_means metadata per field to disambiguate:
- 'disabled': 0 means the feature is not available (Free tier)
- 'unlimited': 0 means unlimited usage (Enterprise tier)
- 'limited': >0 means numerical limit applies
Args:
plan: SubscriptionPlan model instance
Returns:
Dictionary with formatted limits
Dictionary with formatted limits and _zero_means metadata
"""
return {
tier = plan.tier.value if hasattr(plan.tier, 'value') else str(plan.tier)
is_enterprise = tier == 'enterprise'
limit_fields = {
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
"gemini_calls": plan.gemini_calls_limit,
"openai_calls": plan.openai_calls_limit,
@@ -35,11 +43,43 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
"wavespeed_calls": getattr(plan, 'wavespeed_calls_limit', 0) or 0,
"gemini_tokens": plan.gemini_tokens_limit,
"openai_tokens": plan.openai_tokens_limit,
"anthropic_tokens": plan.anthropic_tokens_limit,
"mistral_tokens": plan.mistral_tokens_limit,
"monthly_cost": plan.monthly_cost_limit
"monthly_cost": plan.monthly_cost_limit,
}
# Build _zero_means metadata: indicates whether 0 means 'disabled' or 'unlimited'
zero_means = {}
for field, value in limit_fields.items():
if field == "monthly_cost":
zero_means[field] = "disabled"
elif is_enterprise:
# Enterprise: 0 means unlimited for all call/token fields
zero_means[field] = "unlimited"
else:
# Free/Basic/Pro: determine per-field
# Fields that are 0=disabled on Free tier but 0=unlimited on Basic/Pro
call_and_token_fields = {
"gemini_calls", "openai_calls", "anthropic_calls", "mistral_calls",
"tavily_calls", "serper_calls", "metaphor_calls", "firecrawl_calls",
"stability_calls", "video_calls", "image_edit_calls", "audio_calls",
"exa_calls", "wavespeed_calls", "ai_text_generation_calls",
"gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens",
}
if field in call_and_token_fields:
if value == 0:
zero_means[field] = "disabled" if tier == "free" else "unlimited"
else:
zero_means[field] = "limited"
else:
zero_means[field] = "limited" if value > 0 else "disabled"
return {
**limit_fields,
"_zero_means": zero_means,
}

View File

@@ -1,9 +1,10 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from typing import List, Any, Dict
from loguru import logger
from services.writing_assistant import WritingAssistantService
from middleware.auth_middleware import get_current_user
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
@@ -11,7 +12,6 @@ router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
class SuggestRequest(BaseModel):
text: str
max_results: int | None = 1
class SourceModel(BaseModel):
@@ -38,9 +38,10 @@ assistant_service = WritingAssistantService()
@router.post("/suggest", response_model=SuggestResponse)
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> SuggestResponse:
try:
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
user_id = current_user.get("id")
suggestions = await assistant_service.suggest(req.text, user_id=user_id)
return SuggestResponse(
success=True,
suggestions=[