Subscription implementation complete, Renewal system implemented
This commit is contained in:
@@ -30,12 +30,16 @@ class RateLimiter:
|
||||
"/calendar-events",
|
||||
"/calendar-generation/progress",
|
||||
"/health",
|
||||
"/health/database"
|
||||
"/health/database",
|
||||
]
|
||||
# Prefixes to exempt entire route families (keep empty; rely on specific exemptions only)
|
||||
self.exempt_prefixes = []
|
||||
|
||||
def is_exempt_path(self, path: str) -> bool:
|
||||
"""Check if a path is exempt from rate limiting."""
|
||||
return any(exempt_path in path for exempt_path in self.exempt_paths)
|
||||
return any(exempt_path == path or exempt_path in path for exempt_path in self.exempt_paths) or any(
|
||||
path.startswith(prefix) for prefix in self.exempt_prefixes
|
||||
)
|
||||
|
||||
def clean_old_requests(self, client_ip: str, current_time: float) -> None:
|
||||
"""Clean old requests from the tracking dictionary."""
|
||||
@@ -77,7 +81,6 @@ class RateLimiter:
|
||||
|
||||
# Check if path is exempt from rate limiting
|
||||
if self.is_exempt_path(path):
|
||||
# Allow streaming endpoints without rate limiting
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ class OnboardingCompletionService:
|
||||
"""Service for handling onboarding completion logic."""
|
||||
|
||||
def __init__(self):
|
||||
# Only pre-requisite steps; step 6 is the finalization itself
|
||||
self.required_steps = [1, 2, 3]
|
||||
# Pre-requisite steps; step 6 is the finalization itself
|
||||
self.required_steps = [1, 2, 3, 4, 5]
|
||||
|
||||
async def complete_onboarding(self, current_user: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Complete the onboarding process with full validation."""
|
||||
@@ -73,9 +73,15 @@ class OnboardingCompletionService:
|
||||
db = None
|
||||
db_service = None
|
||||
|
||||
logger.info(f"OnboardingCompletionService: Validating steps for user {user_id}")
|
||||
logger.info(f"OnboardingCompletionService: Current step: {progress.current_step}")
|
||||
logger.info(f"OnboardingCompletionService: Required steps: {self.required_steps}")
|
||||
|
||||
for step_num in self.required_steps:
|
||||
step = progress.get_step_data(step_num)
|
||||
logger.info(f"OnboardingCompletionService: Step {step_num} - status: {step.status if step else 'None'}")
|
||||
if step and step.status in [StepStatus.COMPLETED, StepStatus.SKIPPED]:
|
||||
logger.info(f"OnboardingCompletionService: Step {step_num} already completed/skipped")
|
||||
continue
|
||||
|
||||
# DB-aware fallbacks for migration period
|
||||
@@ -129,6 +135,30 @@ class OnboardingCompletionService:
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
if step_num == 4:
|
||||
# Treat as completed if persona data exists in DB
|
||||
persona = None
|
||||
try:
|
||||
persona = db_service.get_persona_data(user_id, db)
|
||||
except Exception:
|
||||
persona = None
|
||||
if persona and persona.get('corePersona'):
|
||||
try:
|
||||
progress.mark_step_completed(4, {'source': 'db-fallback'})
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
if step_num == 5:
|
||||
# Treat as completed if integrations data exists in DB
|
||||
# For now, we'll consider step 5 completed if the user has reached the final step
|
||||
# This is a simplified approach - in the future, we could check for specific integration data
|
||||
try:
|
||||
# Check if user has completed previous steps and is on final step
|
||||
if progress.current_step >= 6: # FinalStep is step 6
|
||||
progress.mark_step_completed(5, {'source': 'final-step-fallback'})
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
# If DB check fails, fall back to progress status only
|
||||
pass
|
||||
|
||||
@@ -134,11 +134,11 @@ async def generate_writing_personas_async(
|
||||
"request_data": (PersonaGenerationRequest(**(request if isinstance(request, dict) else request.dict())).dict()) if request else {}
|
||||
}
|
||||
logger.info(f"Cache hit for user {user_id} - returning completed task without regeneration: {task_id}")
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "completed",
|
||||
"message": "Persona loaded from cache"
|
||||
}
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "completed",
|
||||
"message": "Persona loaded from cache"
|
||||
}
|
||||
|
||||
# Generate unique task ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
@@ -380,6 +380,13 @@ async def subscribe_to_plan(
|
||||
|
||||
db.commit()
|
||||
|
||||
# Reset usage status for current billing period so new plan takes effect immediately
|
||||
try:
|
||||
usage_service = UsageTrackingService(db)
|
||||
await usage_service.reset_current_billing_period(user_id)
|
||||
except Exception as reset_err:
|
||||
logger.error(f"Failed to reset usage after subscribe: {reset_err}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully subscribed to {plan.name}",
|
||||
|
||||
@@ -35,25 +35,25 @@ class DatabaseAPIMonitor:
|
||||
# API provider detection patterns - Updated to match actual endpoints
|
||||
self.provider_patterns = {
|
||||
APIProvider.GEMINI: [
|
||||
r'/api/blog-writer', r'/api/content-planning', r'/api/strategy-copilot',
|
||||
r'/api/brainstorm', r'/api/writing-assistant', r'/api/seo-dashboard',
|
||||
r'/api/onboarding', r'/api/user-data', r'/api/component-logic',
|
||||
r'gemini', r'google.*ai', r'blog.*writer', r'content.*planning'
|
||||
r'gemini', r'google.*ai'
|
||||
],
|
||||
APIProvider.OPENAI: [r'/openai', r'openai', r'gpt', r'chatgpt'],
|
||||
APIProvider.ANTHROPIC: [r'/anthropic', r'claude', r'anthropic'],
|
||||
APIProvider.MISTRAL: [r'/mistral', r'mistral'],
|
||||
APIProvider.TAVILY: [r'/tavily', r'tavily', r'research', r'search'],
|
||||
APIProvider.SERPER: [r'/serper', r'serper', r'google.*search', r'seo'],
|
||||
APIProvider.METAPHOR: [r'/metaphor', r'/exa', r'metaphor', r'exa'],
|
||||
APIProvider.FIRECRAWL: [r'/firecrawl', r'firecrawl', r'crawl'],
|
||||
APIProvider.STABILITY: [r'/stability', r'stable.*diffusion', r'stability', r'image.*generation']
|
||||
APIProvider.OPENAI: [r'openai', r'gpt', r'chatgpt'],
|
||||
APIProvider.ANTHROPIC: [r'anthropic', r'claude'],
|
||||
APIProvider.MISTRAL: [r'mistral'],
|
||||
APIProvider.TAVILY: [r'tavily'],
|
||||
APIProvider.SERPER: [r'serper'],
|
||||
APIProvider.METAPHOR: [r'metaphor', r'/exa'],
|
||||
APIProvider.FIRECRAWL: [r'firecrawl']
|
||||
}
|
||||
|
||||
def detect_api_provider(self, path: str, user_agent: str = None) -> Optional[APIProvider]:
|
||||
"""Detect which API provider is being used based on request details."""
|
||||
path_lower = path.lower()
|
||||
user_agent_lower = (user_agent or '').lower()
|
||||
|
||||
# Permanently ignore internal route families that must not accrue or check provider usage
|
||||
if path_lower.startswith('/api/onboarding/') or path_lower.startswith('/api/subscription/'):
|
||||
return None
|
||||
|
||||
for provider, patterns in self.provider_patterns.items():
|
||||
for pattern in patterns:
|
||||
@@ -384,16 +384,26 @@ EXCLUDED_ENDPOINTS = [
|
||||
"/api/content-planning/monitoring/cache-stats",
|
||||
"/api/content-planning/monitoring/health"
|
||||
]
|
||||
# Also exclude whole route families by prefix (e.g., subscription/billing must never be blocked)
|
||||
EXCLUDED_PREFIXES = [
|
||||
]
|
||||
|
||||
|
||||
def should_monitor_endpoint(path: str) -> bool:
|
||||
"""Check if an endpoint should be monitored."""
|
||||
return not any(path.endswith(excluded) for excluded in EXCLUDED_ENDPOINTS)
|
||||
return not any(path.endswith(excluded) for excluded in EXCLUDED_ENDPOINTS) and not any(path.startswith(prefix) for prefix in EXCLUDED_PREFIXES)
|
||||
|
||||
async def check_usage_limits_middleware(request: Request, user_id: str, request_body: str = None) -> Optional[JSONResponse]:
|
||||
"""Check usage limits before processing request."""
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
# No special whitelist; onboarding/subscription are ignored by provider detection
|
||||
try:
|
||||
path = request.url.path
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
|
||||
@@ -157,8 +157,8 @@ class BlogSEOMetadataGenerator:
|
||||
|
||||
# Get structured response from Gemini
|
||||
ai_response = self.gemini_provider(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
prompt,
|
||||
schema,
|
||||
temperature=0.3,
|
||||
max_tokens=2048
|
||||
)
|
||||
@@ -167,6 +167,8 @@ class BlogSEOMetadataGenerator:
|
||||
if not ai_response or not isinstance(ai_response, dict):
|
||||
logger.error("Core metadata generation failed: Invalid response from Gemini")
|
||||
# Return fallback response
|
||||
primary_keywords = ', '.join(keywords_data.get('primary_keywords', ['content']))
|
||||
word_count = len(blog_content.split())
|
||||
return {
|
||||
'seo_title': blog_title,
|
||||
'meta_description': f'Learn about {primary_keywords.split(", ")[0] if primary_keywords else "this topic"}.',
|
||||
@@ -246,8 +248,8 @@ class BlogSEOMetadataGenerator:
|
||||
|
||||
# Get structured response from Gemini
|
||||
ai_response = self.gemini_provider(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
prompt,
|
||||
schema,
|
||||
temperature=0.3,
|
||||
max_tokens=2048
|
||||
)
|
||||
|
||||
@@ -348,6 +348,11 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
try:
|
||||
# Get API key with proper error handling
|
||||
api_key = get_gemini_api_key()
|
||||
logger.info(f"🔑 Gemini API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("GEMINI_API_KEY not found in environment variables")
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("✅ Gemini client initialized for structured JSON response")
|
||||
|
||||
@@ -383,11 +388,18 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
system_instruction=system_prompt,
|
||||
)
|
||||
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=prompt,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.info("🚀 Making Gemini API call...")
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=prompt,
|
||||
config=generation_config,
|
||||
)
|
||||
logger.info("✅ Gemini API call completed successfully")
|
||||
except Exception as api_error:
|
||||
logger.error(f"❌ Gemini API call failed: {api_error}")
|
||||
logger.error(f"❌ API Error type: {type(api_error).__name__}")
|
||||
raise api_error
|
||||
|
||||
# Check for parsed content first (primary method for structured output)
|
||||
if hasattr(response, 'parsed'):
|
||||
|
||||
@@ -485,4 +485,27 @@ class UsageTrackingService:
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
)
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status for the current billing period (after plan change)."""
|
||||
try:
|
||||
billing_period = datetime.now().strftime("%Y-%m")
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
# Nothing to reset
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
# Clear LIMIT_REACHED so the user can resume; keep counters intact
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
summary.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
return {"reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
return {"reset": False, "error": str(e)}
|
||||
Reference in New Issue
Block a user