Fix preflight model mapping when skipping invalid providers

This commit is contained in:
ي
2026-03-05 11:10:54 +05:30
parent 2318fd8a48
commit 81052d06b4

View File

@@ -75,7 +75,10 @@ async def preflight_check(
'provider': provider_enum, 'provider': provider_enum,
'tokens_requested': op.tokens_requested or 0, 'tokens_requested': op.tokens_requested or 0,
'actual_provider_name': op.actual_provider_name or op.provider, 'actual_provider_name': op.actual_provider_name or op.provider,
'operation_type': op.operation_type 'operation_type': op.operation_type,
# Keep the originating request fields together so model lookup
# cannot drift when invalid providers are skipped.
'model': op.model
}) })
except Exception as e: except Exception as e:
logger.warning(f"Error processing operation {op.operation_type}: {e}") logger.warning(f"Error processing operation {op.operation_type}: {e}")
@@ -94,7 +97,7 @@ async def preflight_check(
operation_results = [] operation_results = []
total_cost = 0.0 total_cost = 0.0
for i, op in enumerate(operations_to_validate): for op in operations_to_validate:
op_result = { op_result = {
'provider': op['actual_provider_name'], 'provider': op['actual_provider_name'],
'operation_type': op['operation_type'], 'operation_type': op['operation_type'],
@@ -105,7 +108,7 @@ async def preflight_check(
} }
# Get pricing for this operation # Get pricing for this operation
model_name = request.operations[i].model model_name = op.get('model')
if model_name: if model_name:
pricing_info = pricing_service.get_pricing_for_provider_model( pricing_info = pricing_service.get_pricing_for_provider_model(
op['provider'], op['provider'],
@@ -124,11 +127,15 @@ async def preflight_check(
chars = max(0, int(op.get('tokens_requested') or 0)) chars = max(0, int(op.get('tokens_requested') or 0))
cost = max(0.005, 0.005 * (chars / 100.0)) cost = max(0.005, 0.005 * (chars / 100.0))
else: else:
# Audio pricing is per character (every character is 1 token) # Audio pricing uses per-token/per-character unit pricing from DB.
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000.0) # Do not divide by 1000 here: pricing values are already normalized
# as per-unit costs in APIProviderPricing.
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * op['tokens_requested']
elif op['tokens_requested'] > 0: elif op['tokens_requested'] > 0:
# Token-based cost estimation (rough estimate) # Token-based cost estimation (rough estimate).
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (op['tokens_requested'] / 1000) # IMPORTANT: cost_per_input_token is stored as cost-per-token.
# Multiplying by tokens_requested gives the correct estimate.
cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * op['tokens_requested']
else: else:
cost = pricing_info.get('cost_per_request', 0.0) or 0.0 cost = pricing_info.get('cost_per_request', 0.0) or 0.0