Base code
752 backend/test/test_stability_endpoints.py Normal file
@@ -0,0 +1,752 @@
"""Test suite for Stability AI endpoints."""

import pytest
import asyncio
from fastapi.testclient import TestClient
from fastapi import FastAPI, HTTPException
import io
from PIL import Image
import json
import base64
from unittest.mock import Mock, AsyncMock, patch

from routers.stability import router
from services.stability_service import StabilityAIService
from models.stability_models import *


# Create test app
app = FastAPI()
app.include_router(router)
client = TestClient(app)

class TestStabilityEndpoints:
    """Test cases for Stability AI endpoints."""

    def setup_method(self):
        """Set up test environment."""
        self.test_image = self._create_test_image()
        self.test_audio = self._create_test_audio()

    def _create_test_image(self) -> bytes:
        """Create test image data."""
        img = Image.new('RGB', (512, 512), color='red')
        img_bytes = io.BytesIO()
        img.save(img_bytes, format='PNG')
        return img_bytes.getvalue()

    def _create_test_audio(self) -> bytes:
        """Create test audio data."""
        # Mock audio data
        return b"fake_audio_data" * 1000

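    # Hypothetical alternative helper (a sketch, not used by the tests below):
    # if the audio endpoints ever validate real audio headers rather than raw
    # bytes, a minimal one-second silent WAV can be built with the standard
    # library instead of the fake byte string above.
    def _create_test_wav(self) -> bytes:
        """Create a minimal valid WAV file (sketch; assumes WAV input is accepted)."""
        import wave

        buf = io.BytesIO()
        with wave.open(buf, 'wb') as wav:
            wav.setnchannels(1)       # mono
            wav.setsampwidth(2)       # 16-bit samples
            wav.setframerate(8000)    # 8 kHz
            wav.writeframes(b"\x00\x00" * 8000)  # one second of silence
        return buf.getvalue()
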
    @patch('services.stability_service.StabilityAIService')
    def test_generate_ultra_success(self, mock_service):
        """Test successful Ultra generation."""
        # Mock service response
        mock_service.return_value.__aenter__.return_value.generate_ultra = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "A beautiful landscape"},
            files={}
        )

        assert response.status_code == 200
        assert response.headers["content-type"].startswith("image/")

    @patch('services.stability_service.StabilityAIService')
    def test_generate_core_with_parameters(self, mock_service):
        """Test Core generation with various parameters."""
        mock_service.return_value.__aenter__.return_value.generate_core = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/generate/core",
            data={
                "prompt": "A futuristic city",
                "aspect_ratio": "16:9",
                "style_preset": "digital-art",
                "seed": 42
            }
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_inpaint_with_mask(self, mock_service):
        """Test inpainting with mask."""
        mock_service.return_value.__aenter__.return_value.inpaint = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/edit/inpaint",
            data={"prompt": "A cat"},
            files={
                "image": ("test.png", self.test_image, "image/png"),
                "mask": ("mask.png", self.test_image, "image/png")
            }
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_upscale_fast(self, mock_service):
        """Test fast upscaling."""
        mock_service.return_value.__aenter__.return_value.upscale_fast = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/upscale/fast",
            files={"image": ("test.png", self.test_image, "image/png")}
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_control_sketch(self, mock_service):
        """Test sketch control."""
        mock_service.return_value.__aenter__.return_value.control_sketch = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/control/sketch",
            data={
                "prompt": "A medieval castle",
                "control_strength": 0.8
            },
            files={"image": ("sketch.png", self.test_image, "image/png")}
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_3d_generation(self, mock_service):
        """Test 3D model generation."""
        mock_3d_data = b"fake_glb_data" * 100
        mock_service.return_value.__aenter__.return_value.generate_3d_fast = AsyncMock(
            return_value=mock_3d_data
        )

        response = client.post(
            "/api/stability/3d/stable-fast-3d",
            files={"image": ("test.png", self.test_image, "image/png")}
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "model/gltf-binary"

    @patch('services.stability_service.StabilityAIService')
    def test_audio_generation(self, mock_service):
        """Test audio generation."""
        mock_service.return_value.__aenter__.return_value.generate_audio_from_text = AsyncMock(
            return_value=self.test_audio
        )

        response = client.post(
            "/api/stability/audio/text-to-audio",
            data={
                "prompt": "Peaceful nature sounds",
                "duration": 30
            }
        )

        assert response.status_code == 200
        assert response.headers["content-type"].startswith("audio/")

    def test_health_check(self):
        """Test health check endpoint."""
        response = client.get("/api/stability/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"

    def test_models_info(self):
        """Test models info endpoint."""
        response = client.get("/api/stability/models/info")
        assert response.status_code == 200

        data = response.json()
        assert "generate" in data
        assert "edit" in data
        assert "upscale" in data

    def test_supported_formats(self):
        """Test supported formats endpoint."""
        response = client.get("/api/stability/supported-formats")
        assert response.status_code == 200

        data = response.json()
        assert "image_input" in data
        assert "image_output" in data
        assert "audio_input" in data

    def test_image_info_analysis(self):
        """Test image info utility endpoint."""
        response = client.post(
            "/api/stability/utils/image-info",
            files={"image": ("test.png", self.test_image, "image/png")}
        )

        assert response.status_code == 200
        data = response.json()
        assert "width" in data
        assert "height" in data
        assert "format" in data

    def test_prompt_validation(self):
        """Test prompt validation endpoint."""
        response = client.post(
            "/api/stability/utils/validate-prompt",
            data={"prompt": "A beautiful landscape with mountains and lakes"}
        )

        assert response.status_code == 200
        data = response.json()
        assert "is_valid" in data
        assert "suggestions" in data

    def test_invalid_image_format(self):
        """Test error handling for invalid image format."""
        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test prompt"},
            files={"image": ("test.txt", b"not an image", "text/plain")}
        )

        # Should handle gracefully or return appropriate error
        assert response.status_code in [400, 422]

    def test_missing_required_parameters(self):
        """Test error handling for missing required parameters."""
        response = client.post("/api/stability/generate/ultra")

        assert response.status_code == 422  # Validation error

    def test_outpaint_validation(self):
        """Test outpaint direction validation."""
        response = client.post(
            "/api/stability/edit/outpaint",
            data={
                "left": 0,
                "right": 0,
                "up": 0,
                "down": 0
            },
            files={"image": ("test.png", self.test_image, "image/png")}
        )

        assert response.status_code == 400
        assert "at least one outpaint direction" in response.json()["detail"]

    @patch('services.stability_service.StabilityAIService')
    def test_async_generation_response(self, mock_service):
        """Test async generation response format."""
        mock_service.return_value.__aenter__.return_value.upscale_creative = AsyncMock(
            return_value={"id": "test_generation_id"}
        )

        response = client.post(
            "/api/stability/upscale/creative",
            data={"prompt": "High quality upscale"},
            files={"image": ("test.png", self.test_image, "image/png")}
        )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data

    @patch('services.stability_service.StabilityAIService')
    def test_batch_comparison(self, mock_service):
        """Test model comparison endpoint."""
        mock_service.return_value.__aenter__.return_value.generate_ultra = AsyncMock(
            return_value=self.test_image
        )
        mock_service.return_value.__aenter__.return_value.generate_core = AsyncMock(
            return_value=self.test_image
        )

        response = client.post(
            "/api/stability/advanced/compare/models",
            data={
                "prompt": "A test image",
                "models": json.dumps(["ultra", "core"]),
                "seed": 42
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert "comparison_results" in data

class TestStabilityService:
    """Test cases for StabilityAIService class."""

    @pytest.mark.asyncio
    async def test_service_initialization(self):
        """Test service initialization."""
        with patch.dict('os.environ', {'STABILITY_API_KEY': 'test_key'}):
            service = StabilityAIService()
            assert service.api_key == 'test_key'

    def test_service_initialization_no_key(self):
        """Test service initialization without API key."""
        with patch.dict('os.environ', {}, clear=True):
            with pytest.raises(ValueError):
                StabilityAIService()

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_make_request_success(self, mock_session):
        """Test successful API request."""
        # Mock response
        mock_response = AsyncMock()
        mock_response.status = 200
        mock_response.read.return_value = b"test_image_data"
        mock_response.headers = {"Content-Type": "image/png"}

        mock_session.return_value.__aenter__.return_value.request.return_value.__aenter__.return_value = mock_response

        service = StabilityAIService(api_key="test_key")

        async with service:
            result = await service._make_request(
                method="POST",
                endpoint="/test",
                data={"test": "data"}
            )

            assert result == b"test_image_data"

    @pytest.mark.asyncio
    async def test_image_preparation(self):
        """Test image preparation methods."""
        service = StabilityAIService(api_key="test_key")

        # Test bytes input
        test_bytes = b"test_image_bytes"
        result = await service._prepare_image_file(test_bytes)
        assert result == test_bytes

        # Test base64 input
        test_b64 = base64.b64encode(test_bytes).decode()
        result = await service._prepare_image_file(test_b64)
        assert result == test_bytes

    def test_dimension_validation(self):
        """Test image dimension validation."""
        service = StabilityAIService(api_key="test_key")

        # Valid dimensions
        service._validate_image_requirements(1024, 1024)

        # Invalid dimensions (too small)
        with pytest.raises(ValueError):
            service._validate_image_requirements(32, 32)

    def test_aspect_ratio_validation(self):
        """Test aspect ratio validation."""
        service = StabilityAIService(api_key="test_key")

        # Valid aspect ratio
        service._validate_aspect_ratio(1024, 1024)

        # Invalid aspect ratio (too wide)
        with pytest.raises(ValueError):
            service._validate_aspect_ratio(3000, 500)

class TestStabilityModels:
    """Test cases for Pydantic models."""

    def test_stable_image_ultra_request(self):
        """Test StableImageUltraRequest validation."""
        # Valid request
        request = StableImageUltraRequest(
            prompt="A beautiful landscape",
            aspect_ratio="16:9",
            seed=42
        )
        assert request.prompt == "A beautiful landscape"
        assert request.aspect_ratio == "16:9"
        assert request.seed == 42

    def test_invalid_seed_range(self):
        """Test invalid seed range validation."""
        with pytest.raises(ValueError):
            StableImageUltraRequest(
                prompt="Test",
                seed=5000000000  # Too large
            )

    def test_prompt_length_validation(self):
        """Test prompt length validation."""
        # Too long prompt
        with pytest.raises(ValueError):
            StableImageUltraRequest(
                prompt="x" * 10001  # Exceeds max length
            )

        # Empty prompt
        with pytest.raises(ValueError):
            StableImageUltraRequest(
                prompt=""  # Below min length
            )

    def test_outpaint_request(self):
        """Test OutpaintRequest validation."""
        request = OutpaintRequest(
            left=100,
            right=200,
            up=50,
            down=150
        )
        assert request.left == 100
        assert request.right == 200

    def test_audio_request_validation(self):
        """Test audio request validation."""
        request = TextToAudioRequest(
            prompt="Peaceful music",
            duration=60,
            model="stable-audio-2.5"
        )
        assert request.duration == 60
        assert request.model == "stable-audio-2.5"

class TestStabilityUtils:
    """Test cases for utility functions."""

    def test_image_validator(self):
        """Test image validation utilities."""
        from utils.stability_utils import ImageValidator

        # Mock UploadFile
        mock_file = Mock()
        mock_file.content_type = "image/png"
        mock_file.filename = "test.png"

        result = ImageValidator.validate_image_file(mock_file)
        assert result["is_valid"] is True

    def test_prompt_optimizer(self):
        """Test prompt optimization utilities."""
        from utils.stability_utils import PromptOptimizer

        prompt = "A simple image"
        result = PromptOptimizer.optimize_prompt(
            prompt=prompt,
            target_model="ultra",
            target_style="photographic",
            quality_level="high"
        )

        assert len(result["optimized_prompt"]) > len(prompt)
        assert "optimizations_applied" in result

    def test_parameter_validator(self):
        """Test parameter validation utilities."""
        from utils.stability_utils import ParameterValidator

        # Valid seed
        seed = ParameterValidator.validate_seed(42)
        assert seed == 42

        # Invalid seed
        with pytest.raises(HTTPException):
            ParameterValidator.validate_seed(5000000000)

    @pytest.mark.asyncio
    async def test_image_analysis(self):
        """Test image content analysis."""
        from utils.stability_utils import ImageValidator

        # This class has no setup_method, so build a small PNG here instead of
        # relying on self.test_image (which only exists on TestStabilityEndpoints).
        img_bytes = io.BytesIO()
        Image.new('RGB', (512, 512), color='red').save(img_bytes, format='PNG')

        result = await ImageValidator.analyze_image_content(img_bytes.getvalue())

        assert "width" in result
        assert "height" in result
        assert "total_pixels" in result
        assert "quality_assessment" in result

class TestStabilityConfig:
    """Test cases for configuration."""

    def test_stability_config_creation(self):
        """Test StabilityConfig creation."""
        from config.stability_config import StabilityConfig

        config = StabilityConfig(api_key="test_key")
        assert config.api_key == "test_key"
        assert config.base_url == "https://api.stability.ai"

    def test_model_recommendations(self):
        """Test model recommendation logic."""
        from config.stability_config import get_model_recommendations

        recommendations = get_model_recommendations(
            use_case="portrait",
            quality_preference="premium"
        )

        assert "primary" in recommendations
        assert "alternative" in recommendations

    def test_image_validation_config(self):
        """Test image validation configuration."""
        from config.stability_config import validate_image_requirements

        # Valid image
        result = validate_image_requirements(1024, 1024, "generate")
        assert result["is_valid"] is True

        # Invalid image (too small)
        result = validate_image_requirements(32, 32, "generate")
        assert result["is_valid"] is False

    def test_cost_calculation(self):
        """Test cost calculation."""
        from config.stability_config import calculate_estimated_cost

        cost = calculate_estimated_cost("generate", "ultra")
        assert cost == 8  # Ultra model cost

        cost = calculate_estimated_cost("upscale", "fast")
        assert cost == 2  # Fast upscale cost

class TestStabilityMiddleware:
    """Test cases for middleware."""

    def test_rate_limit_middleware(self):
        """Test rate limiting middleware."""
        from middleware.stability_middleware import RateLimitMiddleware

        middleware = RateLimitMiddleware(requests_per_window=5, window_seconds=10)

        # Test client identification
        mock_request = Mock()
        mock_request.headers = {"authorization": "Bearer test_api_key"}

        client_id = middleware._get_client_id(mock_request)
        assert len(client_id) == 8  # First 8 chars of API key

    def test_monitoring_middleware(self):
        """Test monitoring middleware."""
        from middleware.stability_middleware import MonitoringMiddleware

        middleware = MonitoringMiddleware()

        # Test operation extraction
        operation = middleware._extract_operation("/api/stability/generate/ultra")
        assert operation == "generate_ultra"

    def test_caching_middleware(self):
        """Test caching middleware."""
        from middleware.stability_middleware import CachingMiddleware

        middleware = CachingMiddleware()

        # Test cache key generation
        mock_request = Mock()
        mock_request.method = "GET"
        mock_request.url.path = "/api/stability/health"
        mock_request.query_params = {}

        # This would need to be properly mocked for async
        # cache_key = await middleware._generate_cache_key(mock_request)
        # assert isinstance(cache_key, str)
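
    # A sketch of the async variant hinted at above, assuming CachingMiddleware
    # exposes an async _generate_cache_key(request) coroutine; the method name
    # and signature are taken from the commented-out call, not verified.
    @pytest.mark.asyncio
    async def test_caching_cache_key_generation(self):
        """Exercise cache key generation for a simple GET request (sketch)."""
        from middleware.stability_middleware import CachingMiddleware

        middleware = CachingMiddleware()

        mock_request = Mock()
        mock_request.method = "GET"
        mock_request.url.path = "/api/stability/health"
        mock_request.query_params = {}

        cache_key = await middleware._generate_cache_key(mock_request)
        assert isinstance(cache_key, str)
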
class TestErrorHandling:
    """Test error handling scenarios."""

    @patch('services.stability_service.StabilityAIService')
    def test_api_error_handling(self, mock_service):
        """Test API error response handling."""
        mock_service.return_value.__aenter__.return_value.generate_ultra = AsyncMock(
            side_effect=HTTPException(status_code=400, detail="Invalid parameters")
        )

        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test"}
        )

        assert response.status_code == 400

    @patch('services.stability_service.StabilityAIService')
    def test_timeout_handling(self, mock_service):
        """Test timeout error handling."""
        mock_service.return_value.__aenter__.return_value.generate_ultra = AsyncMock(
            side_effect=asyncio.TimeoutError()
        )

        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test"}
        )

        assert response.status_code == 504

    def test_file_size_validation(self):
        """Test file size validation."""
        from utils.stability_utils import validate_file_size

        # Mock large file
        mock_file = Mock()
        mock_file.size = 20 * 1024 * 1024  # 20MB

        with pytest.raises(HTTPException) as exc_info:
            validate_file_size(mock_file, max_size=10 * 1024 * 1024)

        assert exc_info.value.status_code == 413

class TestWorkflowProcessing:
    """Test workflow and batch processing."""

    @patch('services.stability_service.StabilityAIService')
    def test_workflow_validation(self, mock_service):
        """Test workflow validation."""
        from utils.stability_utils import WorkflowManager

        # Valid workflow
        workflow = [
            {"operation": "generate_core", "parameters": {"prompt": "test"}},
            {"operation": "upscale_fast", "parameters": {}}
        ]

        errors = WorkflowManager.validate_workflow(workflow)
        assert len(errors) == 0

        # Invalid workflow
        invalid_workflow = [
            {"operation": "invalid_operation"}
        ]

        errors = WorkflowManager.validate_workflow(invalid_workflow)
        assert len(errors) > 0

    def test_workflow_optimization(self):
        """Test workflow optimization."""
        from utils.stability_utils import WorkflowManager

        workflow = [
            {"operation": "upscale_fast"},
            {"operation": "generate_core"},  # Should be moved to front
            {"operation": "inpaint"}
        ]

        optimized = WorkflowManager.optimize_workflow(workflow)

        # Generate operation should be first
        assert optimized[0]["operation"] == "generate_core"


# ==================== INTEGRATION TESTS ====================

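# Helper sketch (not used by the tests below): both integration tests build the
# same nested aiohttp mock by hand. The chained __aenter__ attributes mirror
# `async with ClientSession() as session` followed by
# `async with session.request(...) as resp`. Factoring that boilerplate out as
# below is an assumption about a possible refactor, not part of the original suite.
def make_mock_client_session(body: bytes = b"test_image_data",
                             status: int = 200,
                             content_type: str = "image/png") -> AsyncMock:
    """Return a mock suitable for @patch('aiohttp.ClientSession')."""
    mock_response = AsyncMock()
    mock_response.status = status
    mock_response.read.return_value = body
    mock_response.headers = {"Content-Type": content_type}

    mock_session = AsyncMock()
    mock_session.return_value.__aenter__.return_value.request.return_value.__aenter__.return_value = mock_response
    return mock_session
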
class TestStabilityIntegration:
    """Integration tests for full workflow."""

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_full_generation_workflow(self, mock_session):
        """Test complete generation workflow."""
        # Mock successful API responses
        mock_response = AsyncMock()
        mock_response.status = 200
        mock_response.read.return_value = b"test_image_data"
        mock_response.headers = {"Content-Type": "image/png"}

        mock_session.return_value.__aenter__.return_value.request.return_value.__aenter__.return_value = mock_response

        service = StabilityAIService(api_key="test_key")

        async with service:
            # Test generation
            result = await service.generate_ultra(
                prompt="A beautiful landscape",
                aspect_ratio="16:9",
                seed=42
            )

            assert isinstance(result, bytes)
            assert len(result) > 0

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_full_edit_workflow(self, mock_session):
        """Test complete edit workflow."""
        # Mock successful API responses
        mock_response = AsyncMock()
        mock_response.status = 200
        mock_response.read.return_value = b"test_edited_image_data"
        mock_response.headers = {"Content-Type": "image/png"}

        mock_session.return_value.__aenter__.return_value.request.return_value.__aenter__.return_value = mock_response

        service = StabilityAIService(api_key="test_key")

        async with service:
            # Test inpainting
            result = await service.inpaint(
                image=b"test_image_data",
                prompt="A cat in the scene",
                grow_mask=10
            )

            assert isinstance(result, bytes)
            assert len(result) > 0


# ==================== PERFORMANCE TESTS ====================

class TestStabilityPerformance:
    """Performance tests for Stability AI endpoints."""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self):
        """Test handling of concurrent requests."""
        from services.stability_service import StabilityAIService

        async def mock_request():
            service = StabilityAIService(api_key="test_key")
            # Mock a quick operation
            await asyncio.sleep(0.1)
            return "success"

        # Run multiple concurrent requests
        tasks = [mock_request() for _ in range(10)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # All should succeed
        assert all(result == "success" for result in results)

    def test_large_file_handling(self):
        """Test handling of large files."""
        from utils.stability_utils import validate_file_size

        # Test with various file sizes
        mock_file = Mock()

        # Valid size
        mock_file.size = 5 * 1024 * 1024  # 5MB
        validate_file_size(mock_file)  # Should not raise

        # Invalid size
        mock_file.size = 15 * 1024 * 1024  # 15MB
        with pytest.raises(HTTPException):
            validate_file_size(mock_file)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])