# Listing metadata: 752 lines, 25 KiB, Python.
"""Test suite for Stability AI endpoints."""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from fastapi.testclient import TestClient
|
|
from fastapi import FastAPI
|
|
import io
|
|
from PIL import Image
|
|
import json
|
|
import base64
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
|
|
from routers.stability import router
|
|
from services.stability_service import StabilityAIService
|
|
from models.stability_models import *
|
|
|
|
|
|
# Stand up a minimal FastAPI application that mounts only the Stability
# router, plus a synchronous test client bound to it. The endpoint tests
# below issue all of their HTTP calls through ``client``.
app = FastAPI()
app.include_router(router)
client = TestClient(app)
|
|
|
|
|
|
class TestStabilityEndpoints:
    """Exercise the Stability AI HTTP routes through the shared ``client``.

    NOTE(review): the mocked tests patch
    ``services.stability_service.StabilityAIService``. If the router holds its
    own reference to the class (``from ... import StabilityAIService``) the
    patch would not reach it -- confirm the patch target against the router.
    """

    def setup_method(self):
        """Regenerate the image and audio payloads before every test."""
        self.test_image = self._create_test_image()
        self.test_audio = self._create_test_audio()

    def _create_test_image(self) -> bytes:
        """Return a solid red 512x512 RGB image encoded as PNG bytes."""
        buffer = io.BytesIO()
        Image.new('RGB', (512, 512), color='red').save(buffer, format='PNG')
        return buffer.getvalue()

    def _create_test_audio(self) -> bytes:
        """Return placeholder bytes standing in for audio (not a real audio file)."""
        chunk = b"fake_audio_data"
        return chunk * 1000

    @patch('services.stability_service.StabilityAIService')
    def test_generate_ultra_success(self, mock_service):
        """Ultra generation answers 200 with an ``image/*`` content type."""
        service = mock_service.return_value.__aenter__.return_value
        service.generate_ultra = AsyncMock(return_value=self.test_image)

        form = {"prompt": "A beautiful landscape"}
        response = client.post(
            "/api/stability/generate/ultra", data=form, files={}
        )

        assert response.status_code == 200
        assert response.headers["content-type"].startswith("image/")

    @patch('services.stability_service.StabilityAIService')
    def test_generate_core_with_parameters(self, mock_service):
        """Core generation accepts aspect ratio, style preset and seed."""
        service = mock_service.return_value.__aenter__.return_value
        service.generate_core = AsyncMock(return_value=self.test_image)

        form = {
            "prompt": "A futuristic city",
            "aspect_ratio": "16:9",
            "style_preset": "digital-art",
            "seed": 42,
        }
        response = client.post("/api/stability/generate/core", data=form)

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_inpaint_with_mask(self, mock_service):
        """Inpainting accepts an image together with a mask upload."""
        service = mock_service.return_value.__aenter__.return_value
        service.inpaint = AsyncMock(return_value=self.test_image)

        uploads = {
            "image": ("test.png", self.test_image, "image/png"),
            "mask": ("mask.png", self.test_image, "image/png"),
        }
        response = client.post(
            "/api/stability/edit/inpaint",
            data={"prompt": "A cat"},
            files=uploads,
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_upscale_fast(self, mock_service):
        """Fast upscaling of an uploaded image answers 200."""
        service = mock_service.return_value.__aenter__.return_value
        service.upscale_fast = AsyncMock(return_value=self.test_image)

        uploads = {"image": ("test.png", self.test_image, "image/png")}
        response = client.post("/api/stability/upscale/fast", files=uploads)

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_control_sketch(self, mock_service):
        """Sketch control accepts a prompt, a strength and a sketch upload."""
        service = mock_service.return_value.__aenter__.return_value
        service.control_sketch = AsyncMock(return_value=self.test_image)

        form = {"prompt": "A medieval castle", "control_strength": 0.8}
        uploads = {"image": ("sketch.png", self.test_image, "image/png")}
        response = client.post(
            "/api/stability/control/sketch", data=form, files=uploads
        )

        assert response.status_code == 200

    @patch('services.stability_service.StabilityAIService')
    def test_3d_generation(self, mock_service):
        """3D generation answers 200 with the glTF binary content type."""
        mock_3d_data = b"fake_glb_data" * 100
        service = mock_service.return_value.__aenter__.return_value
        service.generate_3d_fast = AsyncMock(return_value=mock_3d_data)

        uploads = {"image": ("test.png", self.test_image, "image/png")}
        response = client.post(
            "/api/stability/3d/stable-fast-3d", files=uploads
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "model/gltf-binary"

    @patch('services.stability_service.StabilityAIService')
    def test_audio_generation(self, mock_service):
        """Text-to-audio answers 200 with an ``audio/*`` content type."""
        service = mock_service.return_value.__aenter__.return_value
        service.generate_audio_from_text = AsyncMock(
            return_value=self.test_audio
        )

        form = {"prompt": "Peaceful nature sounds", "duration": 30}
        response = client.post(
            "/api/stability/audio/text-to-audio", data=form
        )

        assert response.status_code == 200
        assert response.headers["content-type"].startswith("audio/")

    def test_health_check(self):
        """The health endpoint reports a ``healthy`` status."""
        response = client.get("/api/stability/health")

        assert response.status_code == 200
        assert response.json()["status"] == "healthy"

    def test_models_info(self):
        """The models-info payload lists generate, edit and upscale sections."""
        response = client.get("/api/stability/models/info")
        assert response.status_code == 200

        data = response.json()
        for section in ("generate", "edit", "upscale"):
            assert section in data

    def test_supported_formats(self):
        """The supported-formats payload has image and audio sections."""
        response = client.get("/api/stability/supported-formats")
        assert response.status_code == 200

        data = response.json()
        for section in ("image_input", "image_output", "audio_input"):
            assert section in data

    def test_image_info_analysis(self):
        """The image-info utility reports width, height and format."""
        uploads = {"image": ("test.png", self.test_image, "image/png")}
        response = client.post(
            "/api/stability/utils/image-info", files=uploads
        )

        assert response.status_code == 200
        data = response.json()
        for field in ("width", "height", "format"):
            assert field in data

    def test_prompt_validation(self):
        """The prompt validator returns a verdict and suggestions."""
        form = {"prompt": "A beautiful landscape with mountains and lakes"}
        response = client.post(
            "/api/stability/utils/validate-prompt", data=form
        )

        assert response.status_code == 200
        data = response.json()
        assert "is_valid" in data
        assert "suggestions" in data

    def test_invalid_image_format(self):
        """A non-image upload is rejected with a client error (400 or 422)."""
        uploads = {"image": ("test.txt", b"not an image", "text/plain")}
        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test prompt"},
            files=uploads,
        )

        # Either an explicit rejection or a validation failure is acceptable.
        assert response.status_code in (400, 422)

    def test_missing_required_parameters(self):
        """A request without any form fields fails validation."""
        response = client.post("/api/stability/generate/ultra")

        # 422 is the status the test expects for a validation error.
        assert response.status_code == 422

    def test_outpaint_validation(self):
        """Outpainting with every direction set to zero is rejected with 400."""
        directions = {"left": 0, "right": 0, "up": 0, "down": 0}
        uploads = {"image": ("test.png", self.test_image, "image/png")}
        response = client.post(
            "/api/stability/edit/outpaint", data=directions, files=uploads
        )

        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "at least one outpaint direction" in detail

    @patch('services.stability_service.StabilityAIService')
    def test_async_generation_response(self, mock_service):
        """Creative upscale returns a JSON body carrying a generation ``id``."""
        service = mock_service.return_value.__aenter__.return_value
        service.upscale_creative = AsyncMock(
            return_value={"id": "test_generation_id"}
        )

        uploads = {"image": ("test.png", self.test_image, "image/png")}
        response = client.post(
            "/api/stability/upscale/creative",
            data={"prompt": "High quality upscale"},
            files=uploads,
        )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data

    @patch('services.stability_service.StabilityAIService')
    def test_batch_comparison(self, mock_service):
        """The model-comparison endpoint returns ``comparison_results``."""
        service = mock_service.return_value.__aenter__.return_value
        service.generate_ultra = AsyncMock(return_value=self.test_image)
        service.generate_core = AsyncMock(return_value=self.test_image)

        form = {
            "prompt": "A test image",
            # The model list travels as a JSON-encoded form field.
            "models": json.dumps(["ultra", "core"]),
            "seed": 42,
        }
        response = client.post(
            "/api/stability/advanced/compare/models", data=form
        )

        assert response.status_code == 200
        data = response.json()
        assert "comparison_results" in data
|
|
|
|
|
|
class TestStabilityService:
    """Unit tests for ``StabilityAIService``: construction, requests, validators."""

    @pytest.mark.asyncio
    async def test_service_initialization(self):
        """With ``STABILITY_API_KEY`` set, the service exposes it as ``api_key``."""
        environment = {'STABILITY_API_KEY': 'test_key'}
        with patch.dict('os.environ', environment):
            service = StabilityAIService()
            assert service.api_key == 'test_key'

    def test_service_initialization_no_key(self):
        """With an emptied environment, construction raises ``ValueError``."""
        with patch.dict('os.environ', {}, clear=True), pytest.raises(ValueError):
            StabilityAIService()

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_make_request_success(self, mock_session):
        """``_make_request`` returns the body bytes of a mocked 200 response."""
        expected = b"test_image_data"

        # Fake HTTP response handed back by the patched client session.
        fake_response = AsyncMock()
        fake_response.status = 200
        fake_response.read.return_value = expected
        fake_response.headers = {"Content-Type": "image/png"}

        # Wire it through both async context managers: session, then request.
        session = mock_session.return_value.__aenter__.return_value
        session.request.return_value.__aenter__.return_value = fake_response

        service = StabilityAIService(api_key="test_key")
        async with service:
            result = await service._make_request(
                method="POST",
                endpoint="/test",
                data={"test": "data"},
            )

            assert result == expected

    @pytest.mark.asyncio
    async def test_image_preparation(self):
        """``_prepare_image_file`` yields the same bytes for raw and base64 input."""
        service = StabilityAIService(api_key="test_key")
        raw = b"test_image_bytes"

        # Raw bytes are expected to come back unchanged.
        from_bytes = await service._prepare_image_file(raw)
        assert from_bytes == raw

        # A base64 string is expected to come back as the decoded bytes.
        encoded = base64.b64encode(raw).decode()
        from_base64 = await service._prepare_image_file(encoded)
        assert from_base64 == raw

    def test_dimension_validation(self):
        """1024x1024 passes dimension validation; 32x32 raises ``ValueError``."""
        service = StabilityAIService(api_key="test_key")

        # Must not raise.
        service._validate_image_requirements(1024, 1024)

        # Too small.
        with pytest.raises(ValueError):
            service._validate_image_requirements(32, 32)

    def test_aspect_ratio_validation(self):
        """A square image passes aspect-ratio validation; 3000x500 raises."""
        service = StabilityAIService(api_key="test_key")

        # Must not raise.
        service._validate_aspect_ratio(1024, 1024)

        # Too wide.
        with pytest.raises(ValueError):
            service._validate_aspect_ratio(3000, 500)
|
|
|
|
|
|
class TestStabilityModels:
    """Validation tests for the request models in ``models.stability_models``."""

    def test_stable_image_ultra_request(self):
        """A well-formed Ultra request keeps the values it was built with."""
        fields = {
            "prompt": "A beautiful landscape",
            "aspect_ratio": "16:9",
            "seed": 42,
        }
        request = StableImageUltraRequest(**fields)

        assert request.prompt == fields["prompt"]
        assert request.aspect_ratio == fields["aspect_ratio"]
        assert request.seed == fields["seed"]

    def test_invalid_seed_range(self):
        """A seed of 5000000000 is rejected as too large."""
        oversized_seed = 5000000000
        with pytest.raises(ValueError):
            StableImageUltraRequest(prompt="Test", seed=oversized_seed)

    def test_prompt_length_validation(self):
        """Both a 10001-character prompt and an empty prompt are rejected."""
        # Longer than the model's maximum prompt length.
        overlong_prompt = "x" * 10001
        with pytest.raises(ValueError):
            StableImageUltraRequest(prompt=overlong_prompt)

        # Shorter than the model's minimum prompt length.
        with pytest.raises(ValueError):
            StableImageUltraRequest(prompt="")

    def test_outpaint_request(self):
        """An outpaint request keeps the ``left`` and ``right`` extents given."""
        extents = {"left": 100, "right": 200, "up": 50, "down": 150}
        request = OutpaintRequest(**extents)

        assert request.left == 100
        assert request.right == 200

    def test_audio_request_validation(self):
        """A text-to-audio request keeps its duration and model name."""
        request = TextToAudioRequest(
            prompt="Peaceful music",
            duration=60,
            model="stable-audio-2.5",
        )

        assert request.duration == 60
        assert request.model == "stable-audio-2.5"
|
|
|
|
|
|
class TestStabilityUtils:
    """Tests for the helper classes in ``utils.stability_utils``."""

    def setup_method(self):
        """Create the PNG bytes consumed by ``test_image_analysis``.

        This class previously had no fixture, so reading ``self.test_image``
        raised ``AttributeError``. The image is a solid red 512x512 RGB PNG.
        """
        buffer = io.BytesIO()
        Image.new('RGB', (512, 512), color='red').save(buffer, format='PNG')
        self.test_image = buffer.getvalue()

    def test_image_validator(self):
        """A mocked PNG upload is reported as valid by ``ImageValidator``."""
        from utils.stability_utils import ImageValidator

        # Stand-in for an upload object; only these two attributes are set.
        mock_file = Mock()
        mock_file.content_type = "image/png"
        mock_file.filename = "test.png"

        result = ImageValidator.validate_image_file(mock_file)
        assert result["is_valid"] is True

    def test_prompt_optimizer(self):
        """Optimizing a short prompt lengthens it and lists what was applied."""
        from utils.stability_utils import PromptOptimizer

        prompt = "A simple image"
        result = PromptOptimizer.optimize_prompt(
            prompt=prompt,
            target_model="ultra",
            target_style="photographic",
            quality_level="high",
        )

        assert len(result["optimized_prompt"]) > len(prompt)
        assert "optimizations_applied" in result

    def test_parameter_validator(self):
        """``validate_seed`` returns a valid seed and rejects an oversized one."""
        # HTTPException is not imported at module level, so bring it into
        # scope here, next to the other function-local imports.
        # NOTE(review): assumes the validator raises FastAPI's HTTPException.
        from fastapi import HTTPException
        from utils.stability_utils import ParameterValidator

        # Valid seed is returned as-is.
        seed = ParameterValidator.validate_seed(42)
        assert seed == 42

        # Oversized seed is rejected.
        with pytest.raises(HTTPException):
            ParameterValidator.validate_seed(5000000000)

    @pytest.mark.asyncio
    async def test_image_analysis(self):
        """Image analysis reports dimensions, pixel count and quality."""
        from utils.stability_utils import ImageValidator

        # NOTE(review): assumes analyze_image_content accepts raw image bytes.
        result = await ImageValidator.analyze_image_content(self.test_image)

        for field in ("width", "height", "total_pixels", "quality_assessment"):
            assert field in result
|
|
|
|
|
|
class TestStabilityConfig:
    """Tests for the objects and helpers in ``config.stability_config``."""

    def test_stability_config_creation(self):
        """A config built with only an API key keeps it and targets the public API."""
        from config.stability_config import StabilityConfig

        config = StabilityConfig(api_key="test_key")

        assert config.api_key == "test_key"
        assert config.base_url == "https://api.stability.ai"

    def test_model_recommendations(self):
        """Recommendations contain a primary and an alternative entry."""
        from config.stability_config import get_model_recommendations

        recommendations = get_model_recommendations(
            use_case="portrait", quality_preference="premium"
        )

        for key in ("primary", "alternative"):
            assert key in recommendations

    def test_image_validation_config(self):
        """For "generate", 1024x1024 is valid and 32x32 is not."""
        from config.stability_config import validate_image_requirements

        accepted = validate_image_requirements(1024, 1024, "generate")
        assert accepted["is_valid"] is True

        # Too small.
        rejected = validate_image_requirements(32, 32, "generate")
        assert rejected["is_valid"] is False

    def test_cost_calculation(self):
        """Estimated cost is 8 for Ultra generation and 2 for fast upscale."""
        from config.stability_config import calculate_estimated_cost

        ultra_cost = calculate_estimated_cost("generate", "ultra")
        assert ultra_cost == 8

        fast_upscale_cost = calculate_estimated_cost("upscale", "fast")
        assert fast_upscale_cost == 2
|
|
|
|
|
|
class TestStabilityMiddleware:
    """Tests for the private helpers of ``middleware.stability_middleware``."""

    def test_rate_limit_middleware(self):
        """A bearer-token request is mapped to an 8-character client id."""
        from middleware.stability_middleware import RateLimitMiddleware

        limiter = RateLimitMiddleware(requests_per_window=5, window_seconds=10)

        request = Mock()
        request.headers = {"authorization": "Bearer test_api_key"}

        client_id = limiter._get_client_id(request)
        # Presumably the first 8 characters of the API key -- confirm against
        # the middleware implementation.
        assert len(client_id) == 8

    def test_monitoring_middleware(self):
        """The Ultra generation path is reduced to the name ``generate_ultra``."""
        from middleware.stability_middleware import MonitoringMiddleware

        monitor = MonitoringMiddleware()

        operation = monitor._extract_operation("/api/stability/generate/ultra")
        assert operation == "generate_ultra"

    def test_caching_middleware(self):
        """Construct the caching middleware and a mock GET request.

        This test currently asserts nothing; it only checks that the
        middleware can be imported and instantiated.
        """
        from middleware.stability_middleware import CachingMiddleware

        middleware = CachingMiddleware()

        # Request that a cache key would be derived from.
        mock_request = Mock()
        mock_request.method = "GET"
        mock_request.url.path = "/api/stability/health"
        mock_request.query_params = {}

        # TODO: make this test async, await
        # middleware._generate_cache_key(mock_request) and assert that the
        # result is a str. The request mock needs proper async support first.
|
|
|
|
|
|
class TestErrorHandling:
    """Tests for how service failures and oversized uploads are reported."""

    @patch('services.stability_service.StabilityAIService')
    def test_api_error_handling(self, mock_service):
        """An ``HTTPException(400)`` raised by the service surfaces as HTTP 400."""
        # HTTPException is not imported at module level; without this import
        # the test died with NameError before issuing any request.
        from fastapi import HTTPException

        service = mock_service.return_value.__aenter__.return_value
        service.generate_ultra = AsyncMock(
            side_effect=HTTPException(status_code=400, detail="Invalid parameters")
        )

        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test"},
        )

        assert response.status_code == 400

    @patch('services.stability_service.StabilityAIService')
    def test_timeout_handling(self, mock_service):
        """An ``asyncio.TimeoutError`` from the service surfaces as HTTP 504."""
        service = mock_service.return_value.__aenter__.return_value
        service.generate_ultra = AsyncMock(side_effect=asyncio.TimeoutError())

        response = client.post(
            "/api/stability/generate/ultra",
            data={"prompt": "Test"},
        )

        assert response.status_code == 504

    def test_file_size_validation(self):
        """A 20 MB file checked against a 10 MB limit is rejected with 413."""
        # Same missing-import fix as in test_api_error_handling.
        # NOTE(review): assumes validate_file_size raises FastAPI's HTTPException.
        from fastapi import HTTPException
        from utils.stability_utils import validate_file_size

        megabyte = 1024 * 1024
        mock_file = Mock()
        mock_file.size = 20 * megabyte

        with pytest.raises(HTTPException) as exc_info:
            validate_file_size(mock_file, max_size=10 * megabyte)

        assert exc_info.value.status_code == 413
|
|
|
|
|
|
class TestWorkflowProcessing:
    """Tests for ``WorkflowManager`` validation and step reordering."""

    @patch('services.stability_service.StabilityAIService')
    def test_workflow_validation(self, mock_service):
        """Known operations validate cleanly; an unknown operation yields errors.

        NOTE(review): ``mock_service`` is never used in the body; the patch is
        kept only to leave the test's behaviour unchanged.
        """
        from utils.stability_utils import WorkflowManager

        valid_steps = [
            {"operation": "generate_core", "parameters": {"prompt": "test"}},
            {"operation": "upscale_fast", "parameters": {}},
        ]
        errors = WorkflowManager.validate_workflow(valid_steps)
        assert len(errors) == 0

        invalid_steps = [{"operation": "invalid_operation"}]
        errors = WorkflowManager.validate_workflow(invalid_steps)
        assert len(errors) > 0

    def test_workflow_optimization(self):
        """Optimization moves the generate step to the front of the workflow."""
        from utils.stability_utils import WorkflowManager

        # The generate step is deliberately placed second.
        operations = ("upscale_fast", "generate_core", "inpaint")
        workflow = [{"operation": name} for name in operations]

        optimized = WorkflowManager.optimize_workflow(workflow)

        first_step = optimized[0]
        assert first_step["operation"] == "generate_core"
|
|
|
|
|
|
# ==================== INTEGRATION TESTS ====================
|
|
|
|
class TestStabilityIntegration:
    """Service-level workflow tests with ``aiohttp.ClientSession`` patched out."""

    @staticmethod
    def _png_response(body):
        """Build a mocked 200 response whose ``read()`` resolves to ``body``."""
        response = AsyncMock()
        response.status = 200
        response.read.return_value = body
        response.headers = {"Content-Type": "image/png"}
        return response

    @staticmethod
    def _route_through(mock_session, response):
        """Make the patched session's ``request(...)`` context yield ``response``."""
        session = mock_session.return_value.__aenter__.return_value
        session.request.return_value.__aenter__.return_value = response

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_full_generation_workflow(self, mock_session):
        """``generate_ultra`` returns non-empty bytes for a mocked 200 response."""
        self._route_through(
            mock_session, self._png_response(b"test_image_data")
        )

        service = StabilityAIService(api_key="test_key")
        async with service:
            result = await service.generate_ultra(
                prompt="A beautiful landscape",
                aspect_ratio="16:9",
                seed=42,
            )

            assert isinstance(result, bytes)
            assert len(result) > 0

    @pytest.mark.asyncio
    @patch('aiohttp.ClientSession')
    async def test_full_edit_workflow(self, mock_session):
        """``inpaint`` returns non-empty bytes for a mocked 200 response."""
        self._route_through(
            mock_session, self._png_response(b"test_edited_image_data")
        )

        service = StabilityAIService(api_key="test_key")
        async with service:
            result = await service.inpaint(
                image=b"test_image_data",
                prompt="A cat in the scene",
                grow_mask=10,
            )

            assert isinstance(result, bytes)
            assert len(result) > 0
|
|
|
|
|
|
# ==================== PERFORMANCE TESTS ====================
|
|
|
|
class TestStabilityPerformance:
    """Concurrency and file-size tests for the Stability AI helpers."""

    @pytest.mark.asyncio
    async def test_concurrent_requests(self):
        """Ten concurrently scheduled fake requests all report success."""

        async def mock_request():
            # Build a service per task so constructor failures surface here;
            # the instance itself is not needed afterwards.
            StabilityAIService(api_key="test_key")
            # Stand-in for a short asynchronous operation.
            await asyncio.sleep(0.1)
            return "success"

        task_count = 10
        results = await asyncio.gather(
            *(mock_request() for _ in range(task_count)),
            return_exceptions=True,
        )

        # Comparing against the expected list puts any captured exception in
        # the assertion diff, rather than reporting a bare ``all(...)`` failure.
        assert results == ["success"] * task_count

    def test_large_file_handling(self):
        """A 5 MB file passes the default size check; a 15 MB file is rejected."""
        # HTTPException is not imported at module level; without this import
        # the rejection branch died with NameError.
        # NOTE(review): assumes validate_file_size raises FastAPI's HTTPException.
        from fastapi import HTTPException
        from utils.stability_utils import validate_file_size

        megabyte = 1024 * 1024
        mock_file = Mock()

        # Within the default limit: must not raise.
        mock_file.size = 5 * megabyte
        validate_file_size(mock_file)

        # Over the default limit: must raise.
        mock_file.size = 15 * megabyte
        with pytest.raises(HTTPException):
            validate_file_size(mock_file)
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code. Hand it to the
    # interpreter so that running this file directly exits non-zero when
    # tests fail, instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))