import re
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException
from loguru import logger

from middleware.auth_middleware import get_current_user
from models.story_models import (
    StoryStartRequest,
    StoryContentResponse,
    StoryScene,
    StoryContinueRequest,
    StoryContinueResponse,
)
from services.story_writer.story_service import StoryWriterService

from ..utils.auth import require_authenticated_user

router = APIRouter()
story_service = StoryWriterService()

# Marker checked for in generated text (case-insensitively) and appended by this
# module to signal that a story is finished.
_DONE_MARKER = "IAMDONE"
# Length label used when the request carries no (or an empty/None) story_length.
_DEFAULT_STORY_LENGTH = "Medium"
# A single-shot short story with at least this many words is reported as complete.
_SHORT_STORY_COMPLETE_WORDS = 900
# Matches "1000" only as a standalone number, so a label containing "10000"
# (which this module treats as a long story) is not mistaken for a short one.
_SHORT_LENGTH_NUMBER_RE = re.compile(r"(?<!\d)1000(?!\d)")


def _resolve_story_length(request: Any) -> str:
    """Return the request's story-length label, defaulting to "Medium" when missing, None or empty."""
    story_length = getattr(request, "story_length", None) or _DEFAULT_STORY_LENGTH
    # str subclasses (e.g. str-based enums) are kept as-is so `.lower()` sees their value.
    return story_length if isinstance(story_length, str) else str(story_length)


def _is_short_story(story_length_lower: str) -> bool:
    """Return True when the lower-cased length label denotes a short (~1000-word) story."""
    return "short" in story_length_lower or bool(
        _SHORT_LENGTH_NUMBER_RE.search(story_length_lower)
    )


def _target_word_count(story_length_lower: str) -> int:
    """Return the total word count a multi-part story should reach for the given length label."""
    if "long" in story_length_lower or "10000" in story_length_lower:
        return 10000
    return 4500


def _word_count(text: Optional[str]) -> int:
    """Return the whitespace-delimited word count of *text* (0 for None or empty)."""
    return len(text.split()) if text else 0


def _normalize_outline(outline: Any) -> Any:
    """Convert StoryScene items of a list outline to dicts; any other outline is returned unchanged."""
    if isinstance(outline, list):
        return [scene.dict() if isinstance(scene, StoryScene) else scene for scene in outline]
    return outline


def _format_outline(outline: Any) -> str:
    """Render an outline as text; each dict scene in a list becomes "Scene N: title\\n description"."""
    if not isinstance(outline, list):
        return str(outline)
    rendered = []
    for index, scene in enumerate(outline, start=1):
        if isinstance(scene, dict):
            rendered.append(
                f"Scene {scene.get('scene_number', index)}: "
                f"{scene.get('title', 'Untitled')}\n {scene.get('description', '')}"
            )
        else:
            # Non-dict items (e.g. plain strings) are rendered verbatim instead of crashing.
            rendered.append(str(scene))
    return "\n".join(rendered)


def _append_done_marker(text: str) -> str:
    """Return *text* with trailing whitespace stripped and the done marker appended on its own paragraph."""
    return text.rstrip() + "\n\n" + _DONE_MARKER


@router.post("/generate-start", response_model=StoryContentResponse)
async def generate_story_start(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContentResponse:
    """Generate the starting section of a story.

    For short stories (whole story produced in one call) the response is marked
    complete when the generated text has at least 900 words.

    Raises:
        HTTPException: 400 when the premise or outline is missing/blank; any
            HTTPException raised along the way is re-raised unchanged; any other
            error is logged and reported as a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")
        if not request.outline or (
            isinstance(request.outline, str) and not request.outline.strip()
        ):
            raise HTTPException(status_code=400, detail="Outline is required")

        logger.info(f"[StoryWriter] Generating story start for user {user_id}")

        outline_data = _normalize_outline(request.outline)
        # Resolve and classify the length before the (expensive) generation call so
        # a bad value cannot fail after the work is already done.
        story_length = _resolve_story_length(request)
        is_short_story = _is_short_story(story_length.lower())

        # NOTE(review): this synchronous service call runs inside an async handler;
        # if it performs network I/O it blocks the event loop — consider a threadpool.
        story_start = story_service.generate_story_start(
            premise=request.premise,
            outline=outline_data,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        is_complete = False
        if is_short_story:
            word_count = _word_count(story_start)
            if word_count >= _SHORT_STORY_COMPLETE_WORDS:
                is_complete = True
                logger.info(
                    f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
                )
            else:
                logger.warning(
                    f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
                )

        return StoryContentResponse(
            story=story_start,
            premise=request.premise,
            outline=_format_outline(outline_data),
            is_complete=is_complete,
            success=True,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception(f"[StoryWriter] Failed to generate story start: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@router.post("/continue", response_model=StoryContinueResponse)
async def continue_story(
    request: StoryContinueRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContinueResponse:
    """Continue writing a story.

    Short stories cannot be continued (they are produced in one call). Medium
    stories target 4500 words and long ones 10000. When the existing text
    already reaches the target, "IAMDONE" is returned without calling the
    service. When the new continuation pushes the story to the target, the
    marker is appended to it. In both cases `is_complete` is True.

    Raises:
        HTTPException: 400 when the story text is missing/blank or the story is
            short; any HTTPException raised along the way is re-raised
            unchanged; any other error is logged and reported as a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        if not request.story_text or not request.story_text.strip():
            raise HTTPException(status_code=400, detail="Story text is required")

        logger.info(f"[StoryWriter] Continuing story for user {user_id}")

        outline_data = _normalize_outline(request.outline)
        story_length = _resolve_story_length(request)
        story_length_lower = story_length.lower()

        if _is_short_story(story_length_lower):
            logger.warning(
                "[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
            )
            raise HTTPException(
                status_code=400,
                detail="Short stories are generated in a single call and should be complete. "
                "If the story is incomplete, please regenerate it from the beginning.",
            )

        current_word_count = _word_count(request.story_text)
        target_total_words = _target_word_count(story_length_lower)
        buffer_target = int(target_total_words * 1.05)

        # Any story that has reached the target is finished. (Previously stories
        # between target+50 and the 5% buffer were not, so they kept growing.)
        if current_word_count >= target_total_words:
            logger.info(
                f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
            )
            return StoryContinueResponse(continuation=_DONE_MARKER, is_complete=True, success=True)

        # NOTE(review): synchronous service call inside an async handler — see
        # generate_story_start.
        continuation = story_service.continue_story(
            premise=request.premise,
            outline=outline_data,
            story_text=request.story_text,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        is_complete = _DONE_MARKER in continuation.upper()
        if not is_complete and continuation:
            new_word_count = _word_count(request.story_text + "\n\n" + continuation)
            if new_word_count >= target_total_words:
                if new_word_count >= buffer_target:
                    logger.info(
                        f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
                    )
                else:
                    logger.info(
                        f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
                    )
                # The marker is known to be absent here (is_complete was False).
                continuation = _append_done_marker(continuation)
                is_complete = True

        return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception(f"[StoryWriter] Failed to continue story: {exc}")
        raise HTTPException(status_code=500, detail=str(exc)) from exc