403 lines
16 KiB
Python
403 lines
16 KiB
Python
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from loguru import logger
|
|
from pydantic import BaseModel, Field
|
|
|
|
from middleware.auth_middleware import get_current_user
|
|
from models.story_models import (
|
|
StoryStartRequest,
|
|
StoryContentResponse,
|
|
StoryScene,
|
|
StoryContinueRequest,
|
|
StoryContinueResponse,
|
|
AnimeSceneTextRequest,
|
|
AnimeSceneTextResponse,
|
|
AnimeSceneGenerateRequest,
|
|
AnimeSceneGenerateResponse,
|
|
)
|
|
from services.story_writer.story_service import StoryWriterService
|
|
|
|
from ..utils.auth import require_authenticated_user
|
|
|
|
|
|
router = APIRouter()
|
|
story_service = StoryWriterService()
|
|
scene_approval_store: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {}
|
|
APPROVAL_TTL_SECONDS = 60 * 60 * 24
|
|
MAX_APPROVALS_PER_USER = 200
|
|
|
|
|
|
def _cleanup_user_approvals(user_id: str) -> None:
|
|
user_store = scene_approval_store.get(user_id)
|
|
if not user_store:
|
|
return
|
|
now = datetime.utcnow()
|
|
for project_id in list(user_store.keys()):
|
|
scenes = user_store.get(project_id, {})
|
|
for scene_id in list(scenes.keys()):
|
|
timestamp = scenes[scene_id].get("timestamp")
|
|
if isinstance(timestamp, datetime):
|
|
if (now - timestamp).total_seconds() > APPROVAL_TTL_SECONDS:
|
|
scenes.pop(scene_id, None)
|
|
if not scenes:
|
|
user_store.pop(project_id, None)
|
|
if not user_store:
|
|
scene_approval_store.pop(user_id, None)
|
|
|
|
|
|
def _enforce_capacity(user_id: str) -> None:
|
|
user_store = scene_approval_store.get(user_id)
|
|
if not user_store:
|
|
return
|
|
entries: List[tuple[datetime, str, str]] = []
|
|
for project_id, scenes in user_store.items():
|
|
for scene_id, meta in scenes.items():
|
|
timestamp = meta.get("timestamp")
|
|
if isinstance(timestamp, datetime):
|
|
entries.append((timestamp, project_id, scene_id))
|
|
if len(entries) <= MAX_APPROVALS_PER_USER:
|
|
return
|
|
entries.sort(key=lambda item: item[0])
|
|
to_remove = len(entries) - MAX_APPROVALS_PER_USER
|
|
for i in range(to_remove):
|
|
_, project_id, scene_id = entries[i]
|
|
scenes = user_store.get(project_id)
|
|
if not scenes:
|
|
continue
|
|
scenes.pop(scene_id, None)
|
|
if not scenes:
|
|
user_store.pop(project_id, None)
|
|
|
|
|
|
def _get_user_store(user_id: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
|
_cleanup_user_approvals(user_id)
|
|
return scene_approval_store.setdefault(user_id, {})
|
|
|
|
|
|
@router.post("/generate-start", response_model=StoryContentResponse)
|
|
async def generate_story_start(
|
|
request: StoryStartRequest,
|
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
|
) -> StoryContentResponse:
|
|
"""Generate the starting section of a story."""
|
|
try:
|
|
user_id = require_authenticated_user(current_user)
|
|
|
|
if not request.premise or not request.premise.strip():
|
|
raise HTTPException(status_code=400, detail="Premise is required")
|
|
if not request.outline or (isinstance(request.outline, str) and not request.outline.strip()):
|
|
raise HTTPException(status_code=400, detail="Outline is required")
|
|
|
|
logger.info(f"[StoryWriter] Generating story start for user {user_id}")
|
|
|
|
outline_data: Any = request.outline
|
|
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
|
|
outline_data = [scene.dict() for scene in outline_data]
|
|
|
|
story_length = getattr(request, "story_length", "Medium")
|
|
story_start = story_service.generate_story_start(
|
|
premise=request.premise,
|
|
outline=outline_data,
|
|
persona=request.persona,
|
|
story_setting=request.story_setting,
|
|
character_input=request.character_input,
|
|
plot_elements=request.plot_elements,
|
|
writing_style=request.writing_style,
|
|
story_tone=request.story_tone,
|
|
narrative_pov=request.narrative_pov,
|
|
audience_age_group=request.audience_age_group,
|
|
content_rating=request.content_rating,
|
|
ending_preference=request.ending_preference,
|
|
story_length=story_length,
|
|
anime_bible=getattr(request, "anime_bible", None),
|
|
user_id=user_id,
|
|
)
|
|
|
|
story_length_lower = story_length.lower()
|
|
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
|
is_complete = False
|
|
if is_short_story:
|
|
word_count = len(story_start.split()) if story_start else 0
|
|
if word_count >= 900:
|
|
is_complete = True
|
|
logger.info(
|
|
f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
|
|
)
|
|
else:
|
|
logger.warning(
|
|
f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
|
|
)
|
|
|
|
outline_response = outline_data
|
|
if isinstance(outline_response, list):
|
|
outline_response = "\n".join(
|
|
[
|
|
f"Scene {scene.get('scene_number', i + 1)}: "
|
|
f"{scene.get('title', 'Untitled')}\n {scene.get('description', '')}"
|
|
for i, scene in enumerate(outline_response)
|
|
]
|
|
)
|
|
|
|
return StoryContentResponse(
|
|
story=story_start,
|
|
premise=request.premise,
|
|
outline=str(outline_response),
|
|
is_complete=is_complete,
|
|
success=True,
|
|
)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"[StoryWriter] Failed to generate story start: {exc}")
|
|
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.post("/continue", response_model=StoryContinueResponse)
|
|
async def continue_story(
|
|
request: StoryContinueRequest,
|
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
|
) -> StoryContinueResponse:
|
|
"""Continue writing a story."""
|
|
try:
|
|
user_id = require_authenticated_user(current_user)
|
|
|
|
if not request.story_text or not request.story_text.strip():
|
|
raise HTTPException(status_code=400, detail="Story text is required")
|
|
|
|
logger.info(f"[StoryWriter] Continuing story for user {user_id}")
|
|
|
|
outline_data: Any = request.outline
|
|
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
|
|
outline_data = [scene.dict() for scene in outline_data]
|
|
|
|
story_length = getattr(request, "story_length", "Medium")
|
|
story_length_lower = story_length.lower()
|
|
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
|
if is_short_story:
|
|
logger.warning(
|
|
"[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
|
|
)
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="Short stories are generated in a single call and should be complete. "
|
|
"If the story is incomplete, please regenerate it from the beginning.",
|
|
)
|
|
|
|
current_word_count = len(request.story_text.split()) if request.story_text else 0
|
|
if "long" in story_length_lower or "10000" in story_length_lower:
|
|
target_total_words = 10000
|
|
else:
|
|
target_total_words = 4500
|
|
buffer_target = int(target_total_words * 1.05)
|
|
|
|
if current_word_count >= buffer_target or (
|
|
current_word_count >= target_total_words
|
|
and (current_word_count - target_total_words) < 50
|
|
):
|
|
logger.info(
|
|
f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
|
|
)
|
|
return StoryContinueResponse(continuation="IAMDONE", is_complete=True, success=True)
|
|
|
|
continuation = story_service.continue_story(
|
|
premise=request.premise,
|
|
outline=outline_data,
|
|
story_text=request.story_text,
|
|
persona=request.persona,
|
|
story_setting=request.story_setting,
|
|
character_input=request.character_input,
|
|
plot_elements=request.plot_elements,
|
|
writing_style=request.writing_style,
|
|
story_tone=request.story_tone,
|
|
narrative_pov=request.narrative_pov,
|
|
audience_age_group=request.audience_age_group,
|
|
content_rating=request.content_rating,
|
|
ending_preference=request.ending_preference,
|
|
anime_bible=getattr(request, "anime_bible", None),
|
|
story_length=story_length,
|
|
user_id=user_id,
|
|
)
|
|
|
|
is_complete = "IAMDONE" in continuation.upper()
|
|
if not is_complete and continuation:
|
|
new_story_text = request.story_text + "\n\n" + continuation
|
|
new_word_count = len(new_story_text.split())
|
|
if new_word_count >= buffer_target:
|
|
logger.info(
|
|
f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
|
|
)
|
|
if "IAMDONE" not in continuation.upper():
|
|
continuation = continuation.rstrip() + "\n\nIAMDONE"
|
|
is_complete = True
|
|
elif new_word_count >= target_total_words and (
|
|
new_word_count - target_total_words
|
|
) < 100:
|
|
logger.info(
|
|
f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
|
|
)
|
|
if "IAMDONE" not in continuation.upper():
|
|
continuation = continuation.rstrip() + "\n\nIAMDONE"
|
|
is_complete = True
|
|
|
|
return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"[StoryWriter] Failed to continue story: {exc}")
|
|
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.post("/anime/scene-text", response_model=AnimeSceneTextResponse)
|
|
async def refine_anime_scene_text(
|
|
request: AnimeSceneTextRequest,
|
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
|
) -> AnimeSceneTextResponse:
|
|
try:
|
|
user_id = require_authenticated_user(current_user)
|
|
|
|
scene_dict = request.scene.dict()
|
|
if not scene_dict.get("title") and not scene_dict.get("description"):
|
|
raise HTTPException(status_code=400, detail="Scene title or description is required")
|
|
|
|
refined = story_service.refine_anime_scene_text(
|
|
scene=scene_dict,
|
|
persona=request.persona,
|
|
story_setting=request.story_setting,
|
|
character_input=request.character_input,
|
|
plot_elements=request.plot_elements,
|
|
writing_style=request.writing_style,
|
|
story_tone=request.story_tone,
|
|
narrative_pov=request.narrative_pov,
|
|
audience_age_group=request.audience_age_group,
|
|
content_rating=request.content_rating,
|
|
anime_bible=request.anime_bible,
|
|
user_id=user_id,
|
|
)
|
|
|
|
refined_scene = StoryScene(
|
|
scene_number=refined.get("scene_number", request.scene.scene_number),
|
|
title=refined.get("title", request.scene.title),
|
|
description=refined.get("description", request.scene.description),
|
|
image_prompt=refined.get("image_prompt", request.scene.image_prompt),
|
|
audio_narration=refined.get("audio_narration", request.scene.audio_narration),
|
|
character_descriptions=refined.get(
|
|
"character_descriptions", request.scene.character_descriptions
|
|
),
|
|
key_events=refined.get("key_events", request.scene.key_events),
|
|
)
|
|
|
|
return AnimeSceneTextResponse(scene=refined_scene, success=True)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"[StoryWriter] Failed to refine anime scene text: {exc}")
|
|
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@router.post("/anime/scene-generate", response_model=AnimeSceneGenerateResponse)
|
|
async def generate_anime_scene_from_bible(
|
|
request: AnimeSceneGenerateRequest,
|
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
|
) -> AnimeSceneGenerateResponse:
|
|
try:
|
|
user_id = require_authenticated_user(current_user)
|
|
|
|
if not request.anime_bible:
|
|
raise HTTPException(status_code=400, detail="Anime story bible is required")
|
|
|
|
previous_scenes_payload: Optional[List[Dict[str, Any]]] = None
|
|
if request.previous_scenes:
|
|
previous_scenes_payload = [scene.dict() for scene in request.previous_scenes]
|
|
|
|
generated = story_service.generate_anime_scene_from_bible(
|
|
premise=request.premise,
|
|
persona=request.persona,
|
|
story_setting=request.story_setting,
|
|
character_input=request.character_input,
|
|
plot_elements=request.plot_elements,
|
|
writing_style=request.writing_style,
|
|
story_tone=request.story_tone,
|
|
narrative_pov=request.narrative_pov,
|
|
audience_age_group=request.audience_age_group,
|
|
content_rating=request.content_rating,
|
|
anime_bible=request.anime_bible,
|
|
previous_scenes=previous_scenes_payload,
|
|
target_scene_number=request.target_scene_number,
|
|
user_id=user_id,
|
|
)
|
|
|
|
scene = StoryScene(
|
|
scene_number=generated.get("scene_number"),
|
|
title=generated.get("title", ""),
|
|
description=generated.get("description", ""),
|
|
image_prompt=generated.get("image_prompt", ""),
|
|
audio_narration=generated.get("audio_narration", ""),
|
|
character_descriptions=generated.get("character_descriptions") or [],
|
|
key_events=generated.get("key_events") or [],
|
|
)
|
|
|
|
return AnimeSceneGenerateResponse(scene=scene, success=True)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
|
|
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
class SceneApprovalRequest(BaseModel):
|
|
project_id: str = Field(..., min_length=1)
|
|
scene_id: str = Field(..., min_length=1)
|
|
approved: bool = True
|
|
notes: Optional[str] = None
|
|
|
|
|
|
@router.post("/script/approve")
|
|
async def approve_script_scene(
|
|
request: SceneApprovalRequest,
|
|
current_user: Dict[str, Any] = Depends(get_current_user),
|
|
) -> Dict[str, Any]:
|
|
"""Persist scene approval metadata for auditing."""
|
|
try:
|
|
user_id = require_authenticated_user(current_user)
|
|
if not request.project_id.strip() or not request.scene_id.strip():
|
|
raise HTTPException(status_code=400, detail="project_id and scene_id are required")
|
|
|
|
notes = request.notes.strip() if request.notes else None
|
|
user_store = _get_user_store(user_id)
|
|
project_store = user_store.setdefault(request.project_id, {})
|
|
timestamp = datetime.utcnow()
|
|
project_store[request.scene_id] = {
|
|
"approved": request.approved,
|
|
"notes": notes,
|
|
"user_id": user_id,
|
|
"timestamp": timestamp,
|
|
}
|
|
_enforce_capacity(user_id)
|
|
logger.info(
|
|
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
|
|
user_id,
|
|
request.project_id,
|
|
request.scene_id,
|
|
request.approved,
|
|
)
|
|
return {
|
|
"success": True,
|
|
"project_id": request.project_id,
|
|
"scene_id": request.scene_id,
|
|
"approved": request.approved,
|
|
"timestamp": timestamp.isoformat(),
|
|
}
|
|
except HTTPException:
|
|
raise
|
|
except Exception as exc:
|
|
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
|
|
raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|