Fix generating pytest for training service 2

Urtzi Alfaro
2025-07-25 14:46:45 +02:00
parent 499d6a1db0
commit 7995429454
10 changed files with 13 additions and 5936 deletions


@@ -220,6 +220,19 @@ async def get_metrics():
        return app.state.metrics_collector.get_metrics()
    return {"status": "metrics not available"}

@app.get("/health/live")
async def liveness_check():
    return {"status": "alive"}

@app.get("/health/ready")
async def readiness_check():
    ready = getattr(app.state, 'ready', True)
    return {"status": "ready" if ready else "not ready"}

@app.get("/")
async def root():
    return {"service": "training-service", "version": "1.0.0"}

if __name__ == "__main__":
    uvicorn.run(
        "app.main:app",

File diff suppressed because it is too large.


@@ -1,673 +0,0 @@
# ================================================================
# services/training/tests/run_tests.py
# ================================================================
"""
Main test runner script for Training Service
Executes comprehensive test suite and generates reports
"""
import os
import sys
import asyncio
import subprocess
import json
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any
import logging
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class TrainingTestRunner:
"""Main test runner for training service"""
def __init__(self):
self.test_dir = Path(__file__).parent
self.results_dir = self.test_dir / "results"
self.results_dir.mkdir(exist_ok=True)
# Test configuration
self.test_suites = {
"unit": {
"files": ["test_api.py", "test_ml.py", "test_service.py"],
"description": "Unit tests for individual components",
"timeout": 300 # 5 minutes
},
"integration": {
"files": ["test_ml_pipeline_integration.py"],
"description": "Integration tests for ML pipeline with external data",
"timeout": 600 # 10 minutes
},
"performance": {
"files": ["test_performance.py"],
"description": "Performance and load testing",
"timeout": 900 # 15 minutes
},
"end_to_end": {
"files": ["test_end_to_end.py"],
"description": "End-to-end workflow testing",
"timeout": 800 # 13 minutes
}
}
self.test_results = {}
async def setup_test_environment(self):
"""Setup test environment and dependencies"""
logger.info("Setting up test environment...")
# Check if we're running in Docker
if os.path.exists("/.dockerenv"):
logger.info("Running in Docker environment")
else:
logger.info("Running in local environment")
# Verify required files exist
required_files = [
"conftest.py",
"test_ml_pipeline_integration.py",
"test_performance.py"
]
for file in required_files:
file_path = self.test_dir / file
if not file_path.exists():
logger.warning(f"Required test file missing: {file}")
# Create test data if needed
await self.create_test_data()
# Verify external services (mock or real)
await self.verify_external_services()
async def create_test_data(self):
"""Create or verify test data exists"""
logger.info("Creating/verifying test data...")
test_data_dir = self.test_dir / "fixtures" / "test_data"
test_data_dir.mkdir(parents=True, exist_ok=True)
# Create bakery sales sample if it doesn't exist
sales_file = test_data_dir / "bakery_sales_sample.csv"
if not sales_file.exists():
logger.info("Creating sample sales data...")
await self.generate_sample_sales_data(sales_file)
# Create weather data sample
weather_file = test_data_dir / "madrid_weather_sample.json"
if not weather_file.exists():
logger.info("Creating sample weather data...")
await self.generate_sample_weather_data(weather_file)
# Create traffic data sample
traffic_file = test_data_dir / "madrid_traffic_sample.json"
if not traffic_file.exists():
logger.info("Creating sample traffic data...")
await self.generate_sample_traffic_data(traffic_file)
async def generate_sample_sales_data(self, file_path: Path):
"""Generate sample sales data for testing"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
# Generate 6 months of sample data
start_date = datetime(2023, 6, 1)
dates = [start_date + timedelta(days=i) for i in range(180)]
products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]
data = []
for date in dates:
for product in products:
base_quantity = np.random.randint(10, 100)
# Weekend boost
if date.weekday() >= 5:
base_quantity *= 1.2
# Seasonal variation
temp = 15 + 10 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
data.append({
"date": date.strftime("%Y-%m-%d"),
"product": product,
"quantity": int(base_quantity),
"revenue": round(base_quantity * np.random.uniform(2.5, 8.0), 2),
"temperature": round(temp + np.random.normal(0, 3), 1),
"precipitation": max(0, np.random.exponential(0.5)),
"is_weekend": date.weekday() >= 5,
"is_holiday": False
})
df = pd.DataFrame(data)
df.to_csv(file_path, index=False)
logger.info(f"Created sample sales data: {len(df)} records")
async def generate_sample_weather_data(self, file_path: Path):
"""Generate sample weather data"""
import json
from datetime import datetime, timedelta
import numpy as np
start_date = datetime(2023, 6, 1)
weather_data = []
for i in range(180):
date = start_date + timedelta(days=i)
day_of_year = date.timetuple().tm_yday
base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
weather_data.append({
"date": date.isoformat(),
"temperature": round(base_temp + np.random.normal(0, 5), 1),
"precipitation": max(0, np.random.exponential(1.0)),
"humidity": np.random.uniform(30, 80),
"wind_speed": np.random.uniform(5, 25),
"pressure": np.random.uniform(1000, 1025),
"description": np.random.choice(["Soleado", "Nuboso", "Lluvioso"]),
"source": "aemet_test"
})
with open(file_path, 'w') as f:
json.dump(weather_data, f, indent=2)
logger.info(f"Created sample weather data: {len(weather_data)} records")
async def generate_sample_traffic_data(self, file_path: Path):
"""Generate sample traffic data"""
import json
from datetime import datetime, timedelta
import numpy as np
start_date = datetime(2023, 6, 1)
traffic_data = []
for i in range(180):
date = start_date + timedelta(days=i)
for hour in [8, 12, 18]: # Three measurements per day
measurement_time = date.replace(hour=hour)
if hour in [8, 18]: # Rush hours
volume = np.random.randint(800, 1500)
congestion = "high"
else: # Lunch time
volume = np.random.randint(400, 800)
congestion = "medium"
traffic_data.append({
"date": measurement_time.isoformat(),
"traffic_volume": volume,
"occupation_percentage": np.random.randint(10, 90),
"load_percentage": np.random.randint(20, 95),
"average_speed": np.random.randint(15, 50),
"congestion_level": congestion,
"pedestrian_count": np.random.randint(50, 500),
"measurement_point_id": "TEST_POINT_001",
"measurement_point_name": "Plaza Mayor",
"road_type": "URB",
"source": "madrid_opendata_test"
})
with open(file_path, 'w') as f:
json.dump(traffic_data, f, indent=2)
logger.info(f"Created sample traffic data: {len(traffic_data)} records")
async def verify_external_services(self):
"""Verify external services are available (mock or real)"""
logger.info("Verifying external services...")
# Check if mock services are available
mock_services = [
("Mock AEMET", "http://localhost:8080/health"),
("Mock Madrid OpenData", "http://localhost:8081/health"),
("Mock Auth Service", "http://localhost:8082/health"),
("Mock Data Service", "http://localhost:8083/health")
]
try:
import httpx
async with httpx.AsyncClient(timeout=5.0) as client:
for service_name, url in mock_services:
try:
response = await client.get(url)
if response.status_code == 200:
logger.info(f"{service_name} is available")
else:
logger.warning(f"{service_name} returned status {response.status_code}")
except Exception as e:
logger.warning(f"{service_name} is not available: {e}")
except ImportError:
logger.warning("httpx not available, skipping service checks")
def run_test_suite(self, suite_name: str) -> Dict[str, Any]:
"""Run a specific test suite"""
suite_config = self.test_suites[suite_name]
logger.info(f"Running {suite_name} test suite: {suite_config['description']}")
start_time = time.time()
# Prepare pytest command
pytest_args = [
"python", "-m", "pytest",
"-v",
"--tb=short",
"--capture=no",
f"--junitxml={self.results_dir}/junit_{suite_name}.xml",
f"--cov=app",
f"--cov-report=html:{self.results_dir}/coverage_{suite_name}_html",
f"--cov-report=xml:{self.results_dir}/coverage_{suite_name}.xml",
"--cov-report=term-missing"
]
# Add test files
for test_file in suite_config["files"]:
test_path = self.test_dir / test_file
if test_path.exists():
pytest_args.append(str(test_path))
else:
logger.warning(f"Test file not found: {test_file}")
# Run the tests
try:
result = subprocess.run(
pytest_args,
cwd=self.test_dir.parent, # Run from training service root
capture_output=True,
text=True,
timeout=suite_config["timeout"]
)
duration = time.time() - start_time
return {
"suite": suite_name,
"status": "passed" if result.returncode == 0 else "failed",
"return_code": result.returncode,
"duration": duration,
"stdout": result.stdout,
"stderr": result.stderr,
"timestamp": datetime.now().isoformat()
}
except subprocess.TimeoutExpired:
duration = time.time() - start_time
logger.error(f"Test suite {suite_name} timed out after {duration:.2f}s")
return {
"suite": suite_name,
"status": "timeout",
"return_code": -1,
"duration": duration,
"stdout": "",
"stderr": f"Test suite timed out after {suite_config['timeout']}s",
"timestamp": datetime.now().isoformat()
}
except Exception as e:
duration = time.time() - start_time
logger.error(f"Error running test suite {suite_name}: {e}")
return {
"suite": suite_name,
"status": "error",
"return_code": -1,
"duration": duration,
"stdout": "",
"stderr": str(e),
"timestamp": datetime.now().isoformat()
}
def generate_test_report(self):
"""Generate comprehensive test report"""
logger.info("Generating test report...")
# Calculate summary statistics
total_suites = len(self.test_results)
passed_suites = sum(1 for r in self.test_results.values() if r["status"] == "passed")
failed_suites = sum(1 for r in self.test_results.values() if r["status"] == "failed")
error_suites = sum(1 for r in self.test_results.values() if r["status"] == "error")
timeout_suites = sum(1 for r in self.test_results.values() if r["status"] == "timeout")
total_duration = sum(r["duration"] for r in self.test_results.values())
# Create detailed report
report = {
"test_run_summary": {
"timestamp": datetime.now().isoformat(),
"total_suites": total_suites,
"passed_suites": passed_suites,
"failed_suites": failed_suites,
"error_suites": error_suites,
"timeout_suites": timeout_suites,
"success_rate": (passed_suites / total_suites * 100) if total_suites > 0 else 0,
"total_duration_seconds": total_duration
},
"suite_results": self.test_results,
"recommendations": self.generate_recommendations()
}
# Save JSON report
report_file = self.results_dir / "test_report.json"
with open(report_file, 'w') as f:
json.dump(report, f, indent=2)
# Generate HTML report
self.generate_html_report(report)
# Print summary to console
self.print_test_summary(report)
return report
def generate_recommendations(self) -> List[str]:
"""Generate recommendations based on test results"""
recommendations = []
failed_suites = [name for name, result in self.test_results.items() if result["status"] == "failed"]
timeout_suites = [name for name, result in self.test_results.items() if result["status"] == "timeout"]
if failed_suites:
recommendations.append(f"Failed test suites: {', '.join(failed_suites)}. Check logs for detailed error messages.")
if timeout_suites:
recommendations.append(f"Timeout in suites: {', '.join(timeout_suites)}. Consider increasing timeout or optimizing performance.")
# Performance recommendations
slow_suites = [
name for name, result in self.test_results.items()
if result["duration"] > 300 # 5 minutes
]
if slow_suites:
recommendations.append(f"Slow test suites: {', '.join(slow_suites)}. Consider performance optimization.")
if not recommendations:
recommendations.append("All tests passed successfully! Consider adding more edge case tests.")
return recommendations
def generate_html_report(self, report: Dict[str, Any]):
"""Generate HTML test report"""
html_template = """
<!DOCTYPE html>
<html>
<head>
<title>Training Service Test Report</title>
<style>
/* Braces are doubled so str.format() below leaves the CSS intact */
body {{ font-family: Arial, sans-serif; margin: 40px; }}
.header {{ background-color: #f8f9fa; padding: 20px; border-radius: 5px; }}
.summary {{ display: flex; gap: 20px; margin: 20px 0; }}
.metric {{ background: white; border: 1px solid #dee2e6; padding: 15px; border-radius: 5px; text-align: center; }}
.metric-value {{ font-size: 24px; font-weight: bold; }}
.passed {{ color: #28a745; }}
.failed {{ color: #dc3545; }}
.timeout {{ color: #fd7e14; }}
.error {{ color: #6c757d; }}
.suite-result {{ margin: 20px 0; padding: 15px; border: 1px solid #dee2e6; border-radius: 5px; }}
.recommendations {{ background-color: #e7f3ff; padding: 15px; border-radius: 5px; margin: 20px 0; }}
pre {{ background-color: #f8f9fa; padding: 10px; border-radius: 3px; overflow-x: auto; }}
</style>
</head>
<body>
<div class="header">
<h1>Training Service Test Report</h1>
<p>Generated: {timestamp}</p>
</div>
<div class="summary">
<div class="metric">
<div class="metric-value">{total_suites}</div>
<div>Total Suites</div>
</div>
<div class="metric">
<div class="metric-value passed">{passed_suites}</div>
<div>Passed</div>
</div>
<div class="metric">
<div class="metric-value failed">{failed_suites}</div>
<div>Failed</div>
</div>
<div class="metric">
<div class="metric-value timeout">{timeout_suites}</div>
<div>Timeout</div>
</div>
<div class="metric">
<div class="metric-value">{success_rate:.1f}%</div>
<div>Success Rate</div>
</div>
<div class="metric">
<div class="metric-value">{duration:.1f}s</div>
<div>Total Duration</div>
</div>
</div>
<div class="recommendations">
<h3>Recommendations</h3>
<ul>
{recommendations_html}
</ul>
</div>
<h2>Suite Results</h2>
{suite_results_html}
</body>
</html>
"""
# Format recommendations
recommendations_html = '\n'.join(
f"<li>{rec}</li>" for rec in report["recommendations"]
)
# Format suite results
suite_results_html = ""
for suite_name, result in report["suite_results"].items():
status_class = result["status"]
suite_results_html += f"""
<div class="suite-result">
<h3>{suite_name.title()} Tests <span class="{status_class}">({result["status"].upper()})</span></h3>
<p><strong>Duration:</strong> {result["duration"]:.2f}s</p>
<p><strong>Return Code:</strong> {result["return_code"]}</p>
{f'<h4>Output:</h4><pre>{result["stdout"][:1000]}{"..." if len(result["stdout"]) > 1000 else ""}</pre>' if result["stdout"] else ""}
{f'<h4>Errors:</h4><pre>{result["stderr"][:1000]}{"..." if len(result["stderr"]) > 1000 else ""}</pre>' if result["stderr"] else ""}
</div>
"""
# Fill template
html_content = html_template.format(
timestamp=report["test_run_summary"]["timestamp"],
total_suites=report["test_run_summary"]["total_suites"],
passed_suites=report["test_run_summary"]["passed_suites"],
failed_suites=report["test_run_summary"]["failed_suites"],
timeout_suites=report["test_run_summary"]["timeout_suites"],
success_rate=report["test_run_summary"]["success_rate"],
duration=report["test_run_summary"]["total_duration_seconds"],
recommendations_html=recommendations_html,
suite_results_html=suite_results_html
)
# Save HTML report
html_file = self.results_dir / "test_report.html"
with open(html_file, 'w') as f:
f.write(html_content)
logger.info(f"HTML report saved to: {html_file}")
def print_test_summary(self, report: Dict[str, Any]):
"""Print test summary to console"""
summary = report["test_run_summary"]
print("\n" + "=" * 80)
print("TRAINING SERVICE TEST RESULTS SUMMARY")
print("=" * 80)
print(f"Timestamp: {summary['timestamp']}")
print(f"Total Suites: {summary['total_suites']}")
print(f"Passed: {summary['passed_suites']}")
print(f"Failed: {summary['failed_suites']}")
print(f"Errors: {summary['error_suites']}")
print(f"Timeouts: {summary['timeout_suites']}")
print(f"Success Rate: {summary['success_rate']:.1f}%")
print(f"Total Duration: {summary['total_duration_seconds']:.2f}s")
print("\nSUITE DETAILS:")
print("-" * 50)
for suite_name, result in report["suite_results"].items():
status_icon = "✅" if result["status"] == "passed" else "❌"
print(f"{status_icon} {suite_name.ljust(15)}: {result['status'].upper().ljust(10)} ({result['duration']:.2f}s)")
print("\nRECOMMENDATIONS:")
print("-" * 50)
for i, rec in enumerate(report["recommendations"], 1):
print(f"{i}. {rec}")
print("\nFILES GENERATED:")
print("-" * 50)
print(f"📄 JSON Report: {self.results_dir}/test_report.json")
print(f"🌐 HTML Report: {self.results_dir}/test_report.html")
print(f"📊 Coverage Reports: {self.results_dir}/coverage_*_html/")
print(f"📋 JUnit XML: {self.results_dir}/junit_*.xml")
print("=" * 80)
async def run_all_tests(self):
"""Run all test suites"""
logger.info("Starting comprehensive test run...")
# Setup environment
await self.setup_test_environment()
# Run each test suite
for suite_name in self.test_suites.keys():
logger.info(f"Starting {suite_name} test suite...")
result = self.run_test_suite(suite_name)
self.test_results[suite_name] = result
if result["status"] == "passed":
logger.info(f"{suite_name} tests PASSED ({result['duration']:.2f}s)")
elif result["status"] == "failed":
logger.error(f"{suite_name} tests FAILED ({result['duration']:.2f}s)")
elif result["status"] == "timeout":
logger.error(f"{suite_name} tests TIMED OUT ({result['duration']:.2f}s)")
else:
logger.error(f"💥 {suite_name} tests ERROR ({result['duration']:.2f}s)")
# Generate final report
report = self.generate_test_report()
return report
def run_specific_suite(self, suite_name: str):
"""Run a specific test suite"""
if suite_name not in self.test_suites:
logger.error(f"Unknown test suite: {suite_name}")
logger.info(f"Available suites: {', '.join(self.test_suites.keys())}")
return None
logger.info(f"Running {suite_name} test suite only...")
result = self.run_test_suite(suite_name)
self.test_results[suite_name] = result
# Generate report for single suite
report = self.generate_test_report()
return report
# ================================================================
# MAIN EXECUTION
# ================================================================
async def main():
"""Main execution function"""
import argparse
parser = argparse.ArgumentParser(description="Training Service Test Runner")
parser.add_argument(
"--suite",
choices=list(TrainingTestRunner().test_suites.keys()) + ["all"],
default="all",
help="Test suite to run (default: all)"
)
parser.add_argument(
"--verbose", "-v",
action="store_true",
help="Verbose output"
)
parser.add_argument(
"--quick",
action="store_true",
help="Run quick tests only (skip performance tests)"
)
args = parser.parse_args()
# Setup logging level
if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
# Create test runner
runner = TrainingTestRunner()
# Modify test suites for quick run
if args.quick:
# Skip performance tests in quick mode
if "performance" in runner.test_suites:
del runner.test_suites["performance"]
logger.info("Quick mode: Skipping performance tests")
try:
if args.suite == "all":
report = await runner.run_all_tests()
else:
report = runner.run_specific_suite(args.suite)
# Exit with appropriate code
if report and report["test_run_summary"]["failed_suites"] == 0 and report["test_run_summary"]["error_suites"] == 0:
logger.info("All tests completed successfully!")
sys.exit(0)
else:
logger.error("Some tests failed!")
sys.exit(1)
except KeyboardInterrupt:
logger.info("Test run interrupted by user")
sys.exit(130)
except Exception as e:
logger.error(f"Test run failed with error: {e}")
sys.exit(1)
if __name__ == "__main__":
# Handle both direct execution and pytest discovery
if len(sys.argv) > 1 and sys.argv[1] in ["--suite", "-h", "--help"]:
# Running as main script with arguments
asyncio.run(main())
else:
# Running as pytest discovery or direct execution without args
print("Training Service Test Runner")
print("=" * 50)
print("Usage:")
print(" python run_tests.py --suite all # Run all test suites")
print(" python run_tests.py --suite unit # Run unit tests only")
print(" python run_tests.py --suite integration # Run integration tests only")
print(" python run_tests.py --suite performance # Run performance tests only")
print(" python run_tests.py --quick # Run quick tests (skip performance)")
print(" python run_tests.py -v # Verbose output")
print()
print("Available test suites:")
runner = TrainingTestRunner()
for suite_name, config in runner.test_suites.items():
print(f" {suite_name.ljust(15)}: {config['description']}")
print()
# If no arguments provided, run all tests
if len(sys.argv) == 1:
print("No arguments provided. Running all tests...")
asyncio.run(TrainingTestRunner().run_all_tests())
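
With this custom runner deleted, the same suites can be driven directly by pytest. A minimal sketch of an equivalent programmatic invocation is shown below; it assumes pytest-cov is installed and mirrors the flags the deleted runner passed, while the test path and report locations are illustrative.

# Hypothetical replacement for run_tests.py: drive pytest programmatically
import sys
import pytest

if __name__ == "__main__":
    exit_code = pytest.main([
        "services/training/tests",      # adjust to the actual test directory
        "-v",
        "--tb=short",
        "--cov=app",                    # coverage flags require pytest-cov
        "--cov-report=term-missing",
        "--junitxml=results/junit.xml",
    ])
    sys.exit(int(exit_code))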


@@ -1,687 +0,0 @@
# services/training/tests/test_api.py
"""
Tests for training service API endpoints
"""
import pytest
from unittest.mock import AsyncMock, patch
from fastapi import status
from httpx import AsyncClient
from app.schemas.training import TrainingJobRequest
class TestTrainingAPI:
"""Test training API endpoints"""
@pytest.mark.asyncio
async def test_health_check(self, test_client: AsyncClient):
"""Test health check endpoint"""
response = await test_client.get("/health")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["service"] == "training-service"
assert data["version"] == "1.0.0"
assert "status" in data
@pytest.mark.asyncio
async def test_readiness_check_ready(self, test_client: AsyncClient):
"""Test readiness check when service is ready"""
# Mock app state as ready
from app.main import app # Add import at top
with patch.object(app.state, 'ready', True, create=True):
response = await test_client.get("/health/ready")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "ready"
@pytest.mark.asyncio
async def test_readiness_check_not_ready(self, test_client: AsyncClient):
"""Test readiness check when service is not ready"""
with patch('app.main.app.state.ready', False):
response = await test_client.get("/health/ready")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json()
assert data["status"] == "not_ready"
@pytest.mark.asyncio
async def test_liveness_check_healthy(self, test_client: AsyncClient):
"""Test liveness check when service is healthy"""
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=True)):
response = await test_client.get("/health/live")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["status"] == "alive"
@pytest.mark.asyncio
async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
"""Test liveness check when database is unhealthy"""
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=False)):
response = await test_client.get("/health/live")
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
data = response.json()
assert data["status"] == "unhealthy"
assert data["reason"] == "database_unavailable"
@pytest.mark.asyncio
async def test_metrics_endpoint(self, test_client: AsyncClient):
"""Test metrics endpoint"""
response = await test_client.get("/metrics")
assert response.status_code == status.HTTP_200_OK
data = response.json()
expected_metrics = [
"training_jobs_active",
"training_jobs_completed",
"training_jobs_failed",
"models_trained_total",
"uptime_seconds"
]
for metric in expected_metrics:
assert metric in data
@pytest.mark.asyncio
async def test_root_endpoint(self, test_client: AsyncClient):
"""Test root endpoint"""
response = await test_client.get("/")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["service"] == "training-service"
assert data["version"] == "1.0.0"
assert "description" in data
class TestTrainingJobsAPI:
"""Test training jobs API endpoints"""
@pytest.mark.asyncio
async def test_start_training_job_success(
self,
test_client: AsyncClient,
mock_messaging,
mock_ml_trainer,
mock_data_service
):
"""Test starting a training job successfully"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"seasonality_mode": "additive"
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "job_id" in data
assert data["status"] == "started"
assert data["tenant_id"] == "test-tenant"
assert "estimated_duration_minutes" in data
@pytest.mark.asyncio
async def test_start_training_job_validation_error(self, test_client: AsyncClient):
"""Test starting training job with validation error"""
request_data = {
"seasonality_mode": "invalid_mode", # Invalid value
"min_data_points": 5 # Too low
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_get_training_status_existing_job(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test getting status of existing training job"""
job_id = training_job_in_db.job_id
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["job_id"] == job_id
assert data["status"] == "pending"
assert "progress" in data
assert "started_at" in data
@pytest.mark.asyncio
async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
"""Test getting status of non-existent training job"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs/nonexistent-job/status")
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_list_training_jobs(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test listing training jobs"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) >= 1
# Check first job structure
job = data[0]
assert "job_id" in job
assert "status" in job
assert "started_at" in job
@pytest.mark.asyncio
async def test_list_training_jobs_with_status_filter(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test listing training jobs with status filter"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get("/training/jobs?status=pending")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
# All jobs should have status "pending"
for job in data:
assert job["status"] == "pending"
@pytest.mark.asyncio
async def test_cancel_training_job_success(
self,
test_client: AsyncClient,
training_job_in_db,
mock_messaging
):
"""Test cancelling a training job successfully"""
job_id = training_job_in_db.job_id
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "message" in data
assert "cancelled" in data["message"].lower()
@pytest.mark.asyncio
async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
"""Test cancelling a non-existent training job"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs/nonexistent-job/cancel")
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_get_training_logs(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test getting training logs"""
job_id = training_job_in_db.job_id
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/logs")
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "job_id" in data
assert "logs" in data
assert isinstance(data["logs"], list)
@pytest.mark.asyncio
async def test_validate_training_data_valid(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test validating valid training data"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/validate", json=request_data)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "is_valid" in data
assert "issues" in data
assert "recommendations" in data
assert "estimated_training_time" in data
class TestSingleProductTrainingAPI:
"""Test single product training API endpoints"""
@pytest.mark.asyncio
async def test_train_single_product_success(
self,
test_client: AsyncClient,
mock_messaging,
mock_ml_trainer,
mock_data_service
):
"""Test training a single product successfully"""
product_name = "Pan Integral"
request_data = {
"include_weather": True,
"include_traffic": True,
"seasonality_mode": "additive"
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "job_id" in data
assert data["status"] == "started"
assert data["tenant_id"] == "test-tenant"
assert f"training started for {product_name}" in data["message"].lower()
@pytest.mark.asyncio
async def test_train_single_product_validation_error(self, test_client: AsyncClient):
"""Test single product training with validation error"""
product_name = "Pan Integral"
request_data = {
"seasonality_mode": "invalid_mode" # Invalid value
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_train_single_product_special_characters(
self,
test_client: AsyncClient,
mock_messaging,
mock_ml_trainer,
mock_data_service
):
"""Test training product with special characters in name"""
product_name = "Pan Francés" # With accent
request_data = {
"include_weather": True,
"seasonality_mode": "additive"
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "job_id" in data
class TestModelsAPI:
"""Test models API endpoints"""
@pytest.mark.asyncio
async def test_list_models(
self,
test_client: AsyncClient,
trained_model_in_db
):
"""Test listing trained models"""
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get("/models")
# This endpoint might not exist yet, so we expect either 200 or 404
assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]
if response.status_code == status.HTTP_200_OK:
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_get_model_details(
self,
test_client: AsyncClient,
trained_model_in_db
):
"""Test getting model details"""
model_id = trained_model_in_db.model_id
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
response = await test_client.get(f"/models/{model_id}")
# This endpoint might not exist yet
assert response.status_code in [
status.HTTP_200_OK,
status.HTTP_404_NOT_FOUND,
status.HTTP_501_NOT_IMPLEMENTED
]
class TestErrorHandling:
"""Test error handling in API endpoints"""
@pytest.mark.asyncio
async def test_database_error_handling(self, test_client: AsyncClient):
"""Test handling of database errors"""
with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
mock_create.side_effect = Exception("Database connection failed")
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_missing_tenant_id(self, test_client: AsyncClient):
"""Test handling when tenant ID is missing"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Don't mock get_current_tenant_id to simulate missing auth
response = await test_client.post("/training/jobs", json=request_data)
# Should fail due to missing authentication
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
@pytest.mark.asyncio
async def test_invalid_job_id_format(self, test_client: AsyncClient):
"""Test handling of invalid job ID format"""
invalid_job_id = "invalid-job-id-with-special-chars@#$"
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
# Should handle gracefully
assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]
@pytest.mark.asyncio
async def test_messaging_failure_handling(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test handling when messaging fails"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still succeed even if messaging fails
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert "job_id" in data
@pytest.mark.asyncio
async def test_invalid_json_payload(self, test_client: AsyncClient):
"""Test handling of invalid JSON payload"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/jobs",
content="invalid json {{{",
headers={"Content-Type": "application/json"}
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_unsupported_content_type(self, test_client: AsyncClient):
"""Test handling of unsupported content type"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/jobs",
content="some text data",
headers={"Content-Type": "text/plain"}
)
assert response.status_code in [
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
status.HTTP_422_UNPROCESSABLE_ENTITY
]
class TestAuthenticationIntegration:
"""Test authentication integration"""
@pytest.mark.asyncio
async def test_endpoints_require_auth(self, test_client: AsyncClient):
"""Test that endpoints require authentication in production"""
# This test would be more meaningful in a production environment
# where authentication is actually enforced
endpoints_to_test = [
("POST", "/training/jobs"),
("GET", "/training/jobs"),
("POST", "/training/products/Pan Integral"),
("POST", "/training/validate")
]
for method, endpoint in endpoints_to_test:
if method == "POST":
response = await test_client.post(endpoint, json={})
else:
response = await test_client.get(endpoint)
# In test environment with mocked auth, should work
# In production, would require valid authentication
assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_tenant_isolation_in_api(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test tenant isolation at API level"""
job_id = training_job_in_db.job_id
# Try to access job with different tenant
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# Should not find job for different tenant
assert response.status_code == status.HTTP_404_NOT_FOUND
class TestAPIValidation:
"""Test API validation and input handling"""
@pytest.mark.asyncio
async def test_training_request_validation(self, test_client: AsyncClient):
"""Test comprehensive training request validation"""
# Test valid request
valid_request = {
"include_weather": True,
"include_traffic": False,
"min_data_points": 30,
"seasonality_mode": "additive",
"daily_seasonality": True,
"weekly_seasonality": True,
"yearly_seasonality": True
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=valid_request)
assert response.status_code == status.HTTP_200_OK
# Test invalid seasonality mode
invalid_request = valid_request.copy()
invalid_request["seasonality_mode"] = "invalid_mode"
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
# Test invalid min_data_points
invalid_request = valid_request.copy()
invalid_request["min_data_points"] = 5 # Too low
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_single_product_request_validation(self, test_client: AsyncClient):
"""Test single product training request validation"""
product_name = "Pan Integral"
# Test valid request
valid_request = {
"include_weather": True,
"include_traffic": True,
"seasonality_mode": "multiplicative"
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=valid_request
)
assert response.status_code == status.HTTP_200_OK
# Test empty product name
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
"/training/products/",
json=valid_request
)
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_query_parameter_validation(self, test_client: AsyncClient):
"""Test query parameter validation"""
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
# Test valid limit parameter
response = await test_client.get("/training/jobs?limit=5")
assert response.status_code == status.HTTP_200_OK
# Test invalid limit parameter
response = await test_client.get("/training/jobs?limit=invalid")
assert response.status_code in [
status.HTTP_422_UNPROCESSABLE_ENTITY,
status.HTTP_400_BAD_REQUEST
]
# Test negative limit
response = await test_client.get("/training/jobs?limit=-1")
assert response.status_code in [
status.HTTP_422_UNPROCESSABLE_ENTITY,
status.HTTP_400_BAD_REQUEST
]
class TestAPIPerformance:
"""Test API performance characteristics"""
@pytest.mark.asyncio
async def test_concurrent_requests(self, test_client: AsyncClient):
"""Test handling of concurrent requests"""
import asyncio
# Create multiple concurrent requests
tasks = []
for i in range(10):
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
task = test_client.get("/health")
tasks.append(task)
responses = await asyncio.gather(*tasks)
# All requests should succeed
for response in responses:
assert response.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_large_payload_handling(self, test_client: AsyncClient):
"""Test handling of large request payloads"""
# Create large request payload
large_request = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=large_request)
# Should handle large payload gracefully
assert response.status_code in [
status.HTTP_200_OK,
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
]
@pytest.mark.asyncio
async def test_rapid_successive_requests(self, test_client: AsyncClient):
"""Test rapid successive requests to same endpoint"""
# Make rapid requests
responses = []
for _ in range(20):
response = await test_client.get("/health")
responses.append(response)
# All should succeed
for response in responses:
assert response.status_code == status.HTTP_200_OK


@@ -1,311 +0,0 @@
# ================================================================
# services/training/tests/test_end_to_end.py
# ================================================================
"""
End-to-End Testing for Training Service
Tests complete workflows from API to ML pipeline to results
"""
import pytest
import asyncio
import httpx
import pandas as pd
import numpy as np
import json
import tempfile
import time
from datetime import datetime, timedelta
from typing import Dict, List, Any
from unittest.mock import patch, AsyncMock
import uuid
from app.main import app
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
class TestTrainingServiceEndToEnd:
"""End-to-end tests for complete training workflows"""
@pytest.fixture
async def test_client(self):
"""Create test client for the training service"""
from httpx import AsyncClient
async with AsyncClient(app=app, base_url="http://test") as client:
yield client
@pytest.fixture
def real_bakery_data(self):
"""Use the actual bakery sales data from the uploaded CSV"""
# This fixture would load the real bakery_sales_2023_2024.csv data
# For testing, we'll simulate the structure based on the document description
# Generate realistic data matching the CSV structure
start_date = datetime(2023, 1, 1)
dates = [start_date + timedelta(days=i) for i in range(365)]
products = [
"Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
"Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras"
]
data = []
for date in dates:
for product in products:
# Realistic sales patterns for Madrid bakery
base_quantity = {
"Pan Integral": 80, "Pan Blanco": 120, "Croissant": 45,
"Magdalenas": 30, "Empanadas": 25, "Tarta Chocolate": 15,
"Roscon Reyes": 8, "Palmeras": 12
}.get(product, 20)
# Seasonal variations
if date.month == 12 and product == "Roscon Reyes":
base_quantity *= 5 # Christmas specialty
elif date.month in [6, 7, 8]: # Summer
base_quantity *= 0.85
elif date.month in [11, 12, 1]: # Winter
base_quantity *= 1.15
# Weekly patterns
if date.weekday() >= 5: # Weekends
base_quantity *= 1.3
elif date.weekday() == 0: # Monday slower
base_quantity *= 0.8
# Weather influence
temp = 15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
if temp > 30: # Very hot days
if product in ["Pan Integral", "Pan Blanco"]:
base_quantity *= 0.7
elif temp < 5: # Cold days
base_quantity *= 1.1
# Add realistic noise (numpy is imported at module level above)
quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.15)))
# Calculate revenue (realistic Spanish bakery prices)
price_per_unit = {
"Pan Integral": 2.80, "Pan Blanco": 2.50, "Croissant": 1.50,
"Magdalenas": 1.20, "Empanadas": 3.50, "Tarta Chocolate": 18.00,
"Roscon Reyes": 25.00, "Palmeras": 1.80
}.get(product, 2.00)
revenue = round(quantity * price_per_unit, 2)
data.append({
"date": date.strftime("%Y-%m-%d"),
"product": product,
"quantity": quantity,
"revenue": revenue,
"temperature": round(temp + np.random.normal(0, 3), 1),
"precipitation": max(0, np.random.exponential(0.8)),
"is_weekend": date.weekday() >= 5,
"is_holiday": self._is_spanish_holiday(date)
})
return pd.DataFrame(data)
def _is_spanish_holiday(self, date: datetime) -> bool:
"""Check if date is a Spanish holiday"""
spanish_holidays = [
(1, 1), # Año Nuevo
(1, 6), # Reyes Magos
(5, 1), # Día del Trabajo
(8, 15), # Asunción de la Virgen
(10, 12), # Fiesta Nacional de España
(11, 1), # Todos los Santos
(12, 6), # Día de la Constitución
(12, 8), # Inmaculada Concepción
(12, 25), # Navidad
]
return (date.month, date.day) in spanish_holidays
@pytest.fixture
async def mock_external_apis(self):
"""Mock external APIs (AEMET and Madrid OpenData)"""
with patch('app.external.aemet.AEMETClient') as mock_aemet, \
patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_madrid:
# Mock AEMET weather data
mock_aemet_instance = AsyncMock()
mock_aemet.return_value = mock_aemet_instance
# Generate realistic Madrid weather data
weather_data = []
for i in range(365):
date = datetime(2023, 1, 1) + timedelta(days=i)
day_of_year = date.timetuple().tm_yday
# Madrid climate: hot summers, mild winters
base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
weather_data.append({
"date": date,
"temperature": round(base_temp + np.random.normal(0, 4), 1),
"precipitation": max(0, np.random.exponential(1.2)),
"humidity": np.random.uniform(25, 75),
"wind_speed": np.random.uniform(3, 20),
"pressure": np.random.uniform(995, 1025),
"description": np.random.choice([
"Soleado", "Parcialmente nublado", "Nublado",
"Lluvia ligera", "Despejado"
]),
"source": "aemet"
})
mock_aemet_instance.get_historical_weather.return_value = weather_data
mock_aemet_instance.get_current_weather.return_value = weather_data[-1]
# Mock Madrid traffic data
mock_madrid_instance = AsyncMock()
mock_madrid.return_value = mock_madrid_instance
traffic_data = []
for i in range(365):
date = datetime(2023, 1, 1) + timedelta(days=i)
# Multiple measurements per day
for hour in range(6, 22, 2): # Every 2 hours from 6 AM to 10 PM
measurement_time = date.replace(hour=hour)
# Realistic Madrid traffic patterns
if hour in [7, 8, 9, 18, 19, 20]: # Rush hours
volume = np.random.randint(1200, 2000)
congestion = "high"
speed = np.random.randint(10, 25)
elif hour in [12, 13, 14]: # Lunch time
volume = np.random.randint(800, 1200)
congestion = "medium"
speed = np.random.randint(20, 35)
else: # Off-peak
volume = np.random.randint(300, 800)
congestion = "low"
speed = np.random.randint(30, 50)
traffic_data.append({
"date": measurement_time,
"traffic_volume": volume,
"occupation_percentage": np.random.randint(15, 85),
"load_percentage": np.random.randint(25, 90),
"average_speed": speed,
"congestion_level": congestion,
"pedestrian_count": np.random.randint(100, 800),
"measurement_point_id": "MADRID_CENTER_001",
"measurement_point_name": "Puerta del Sol",
"road_type": "URB",
"source": "madrid_opendata"
})
mock_madrid_instance.get_historical_traffic.return_value = traffic_data
mock_madrid_instance.get_current_traffic.return_value = traffic_data[-1]
yield {
'aemet': mock_aemet_instance,
'madrid': mock_madrid_instance
}
@pytest.mark.asyncio
async def test_complete_training_workflow_api(
self,
test_client,
real_bakery_data,
mock_external_apis
):
"""Test complete training workflow through API endpoints"""
# Step 1: Check service health
health_response = await test_client.get("/health")
assert health_response.status_code == 200
health_data = health_response.json()
assert health_data["status"] == "healthy"
# Step 2: Validate training data quality
with patch('app.services.training_service.TrainingService._fetch_sales_data',
return_value=real_bakery_data):
validation_response = await test_client.post(
"/training/validate",
json={
"tenant_id": "test_bakery_001",
"include_weather": True,
"include_traffic": True
}
)
assert validation_response.status_code == 200
validation_data = validation_response.json()
assert validation_data["is_valid"] is True
assert validation_data["data_points"] > 1000 # Sufficient data
assert validation_data["missing_percentage"] < 10
# Step 3: Start training job for multiple products
training_request = {
"products": ["Pan Integral", "Croissant", "Magdalenas"],
"include_weather": True,
"include_traffic": True,
"config": {
"seasonality_mode": "additive",
"changepoint_prior_scale": 0.05,
"seasonality_prior_scale": 10.0,
"validation_enabled": True
}
}
with patch('app.services.training_service.TrainingService._fetch_sales_data',
return_value=real_bakery_data):
start_response = await test_client.post(
"/training/jobs",
json=training_request,
headers={"X-Tenant-ID": "test_bakery_001"}
)
assert start_response.status_code == 201
job_data = start_response.json()
job_id = job_data["job_id"]
assert job_data["status"] == "pending"
# Step 4: Monitor job progress
max_wait_time = 300 # 5 minutes
start_time = time.time()
while time.time() - start_time < max_wait_time:
status_response = await test_client.get(f"/training/jobs/{job_id}/status")
assert status_response.status_code == 200
status_data = status_response.json()
if status_data["status"] == "completed":
# Training completed successfully
assert "models_trained" in status_data
assert len(status_data["models_trained"]) == 3 # Three products
# Check model quality
for model_info in status_data["models_trained"]:
assert "product_name" in model_info
assert "model_id" in model_info
assert "metrics" in model_info
metrics = model_info["metrics"]
assert "mape" in metrics
assert "rmse" in metrics
assert "mae" in metrics
# Quality thresholds for bakery data
assert metrics["mape"] < 50, f"MAPE too high for {model_info['product_name']}: {metrics['mape']}"
assert metrics["rmse"] > 0
break
elif status_data["status"] == "failed":
pytest.fail(f"Training job failed: {status_data.get('error_message', 'Unknown error')}")
# Wait before checking again
await asyncio.sleep(10)
else:
pytest.fail(f"Training job did not complete within {max_wait_time} seconds")
# Step 5: Get detailed job logs
logs_response = await test_client.get(f"/training/jobs/{job_id}/logs")
assert logs_response.status_code == 200
logs_data = logs_response.json()
assert "logs" in logs_data
assert len(logs_data["logs"]) > 0


@@ -1,848 +0,0 @@
# services/training/tests/test_integration.py
"""
Integration tests for training service
Tests complete workflows and service interactions
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, Mock, patch
from httpx import AsyncClient
from datetime import datetime, timedelta
from app.main import app
from app.schemas.training import TrainingJobRequest
class TestTrainingWorkflowIntegration:
"""Test complete training workflows end-to-end"""
@pytest.mark.asyncio
async def test_complete_training_workflow(
self,
test_client: AsyncClient,
test_db_session,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test complete training workflow from API to completion"""
# Step 1: Start training job
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"seasonality_mode": "additive"
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
job_data = response.json()
job_id = job_data["job_id"]
# Step 2: Check initial status
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
status_data = response.json()
assert status_data["status"] in ["pending", "started"]
# Step 3: Simulate background task completion
# In real scenario, this would be handled by background tasks
await asyncio.sleep(0.1) # Allow background task to start
# Step 4: Check completion status
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# The job should exist in database even if not completed yet
assert response.status_code == 200
@pytest.mark.asyncio
async def test_single_product_training_workflow(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test single product training complete workflow"""
product_name = "Pan Integral"
request_data = {
"include_weather": True,
"include_traffic": False,
"seasonality_mode": "additive"
}
# Start single product training
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(
f"/training/products/{product_name}",
json=request_data
)
assert response.status_code == 200
job_data = response.json()
job_id = job_data["job_id"]
assert f"training started for {product_name}" in job_data["message"].lower()
# Check job status
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
status_data = response.json()
assert status_data["job_id"] == job_id
@pytest.mark.asyncio
async def test_training_validation_workflow(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test training data validation workflow"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Validate training data
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/validate", json=request_data)
assert response.status_code == 200
validation_data = response.json()
assert "is_valid" in validation_data
assert "issues" in validation_data
assert "recommendations" in validation_data
assert "estimated_training_time" in validation_data
# If validation passes, start actual training
if validation_data["is_valid"]:
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_job_cancellation_workflow(
self,
test_client: AsyncClient,
training_job_in_db,
mock_messaging
):
"""Test job cancellation workflow"""
job_id = training_job_in_db.job_id
# Check initial status
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
initial_status = response.json()
assert initial_status["status"] == "pending"
# Cancel the job
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
assert response.status_code == 200
cancel_response = response.json()
assert "cancelled" in cancel_response["message"].lower()
# Verify cancellation
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
assert response.status_code == 200
final_status = response.json()
assert final_status["status"] == "cancelled"
class TestServiceInteractionIntegration:
"""Test interactions between training service and external services"""
@pytest.mark.asyncio
async def test_data_service_integration(self, training_service, mock_data_service):
"""Test integration with data service"""
from app.schemas.training import TrainingJobRequest
request = TrainingJobRequest(
include_weather=True,
include_traffic=True,
min_data_points=30
)
# Test sales data fetching
sales_data = await training_service._fetch_sales_data("test-tenant", request)
assert isinstance(sales_data, list)
# Test weather data fetching
weather_data = await training_service._fetch_weather_data("test-tenant", request)
assert isinstance(weather_data, list)
# Test traffic data fetching
traffic_data = await training_service._fetch_traffic_data("test-tenant", request)
assert isinstance(traffic_data, list)
@pytest.mark.asyncio
async def test_messaging_integration(self, mock_messaging):
"""Test integration with messaging system"""
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_model_trained
)
# Test various message types
result1 = await publish_job_started("job-123", "tenant-123", {})
result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"})
result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0})
assert result1 is True
assert result2 is True
assert result3 is True
@pytest.mark.asyncio
async def test_database_integration(self, test_db_session, training_service):
"""Test database operations integration"""
# Create a training job
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="integration-test-job",
config={"test": True}
)
assert job.job_id == "integration-test-job"
# Update job status
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=50,
current_step="Processing data"
)
# Retrieve updated job
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert updated_job.status == "running"
assert updated_job.progress == 50
class TestErrorHandlingIntegration:
"""Test error handling across service boundaries"""
@pytest.mark.asyncio
async def test_data_service_failure_handling(
self,
test_client: AsyncClient,
mock_messaging
):
"""Test handling when data service is unavailable"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Mock data service failure
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still create job but might fail during execution
assert response.status_code == 200
@pytest.mark.asyncio
async def test_messaging_failure_handling(
self,
test_client: AsyncClient,
mock_data_service
):
"""Test handling when messaging fails"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Mock messaging failure
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Should still succeed even if messaging fails
assert response.status_code == 200
@pytest.mark.asyncio
async def test_ml_training_failure_handling(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service
):
"""Test handling when ML training fails"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Mock ML training failure
with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=request_data)
# Job should be created successfully
assert response.status_code == 200
# Background task would handle the failure
class TestPerformanceIntegration:
"""Test performance characteristics of integrated workflows"""
@pytest.mark.asyncio
async def test_concurrent_training_jobs(
self,
test_client: AsyncClient,
mock_messaging,
mock_data_service,
mock_ml_trainer
):
"""Test handling multiple concurrent training jobs"""
request_data = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30
}
# Start the jobs concurrently; keep the tenant override active while each
# request is actually awaited, not only while the coroutine is created
async def submit_job(i: int):
    with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
        return await test_client.post("/training/jobs", json=request_data)
responses = await asyncio.gather(*(submit_job(i) for i in range(5)))
# All jobs should be created successfully
for response in responses:
assert response.status_code == 200
data = response.json()
assert "job_id" in data
@pytest.mark.asyncio
async def test_large_dataset_handling(
self,
training_service,
test_db_session
):
"""Test handling of large datasets"""
# Simulate large dataset
large_config = {
"include_weather": True,
"include_traffic": True,
"min_data_points": 1000, # Large minimum
"products": [f"Product-{i}" for i in range(100)] # Many products
}
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="large-dataset-job",
config=large_config
)
assert job.config == large_config
assert job.job_id == "large-dataset-job"
@pytest.mark.asyncio
async def test_rapid_status_checks(
self,
test_client: AsyncClient,
training_job_in_db
):
"""Test rapid successive status checks"""
job_id = training_job_in_db.job_id
# Make many rapid status requests with the tenant override active for the whole batch
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
    tasks = [
        test_client.get(f"/training/jobs/{job_id}/status")
        for _ in range(20)
    ]
    responses = await asyncio.gather(*tasks)
# All requests should succeed
for response in responses:
assert response.status_code == 200
class TestSecurityIntegration:
"""Test security aspects of service integration"""
@pytest.mark.asyncio
async def test_tenant_isolation(
self,
test_client: AsyncClient,
training_job_in_db,
mock_messaging
):
"""Test that tenants cannot access each other's jobs"""
job_id = training_job_in_db.job_id
# Try to access job with different tenant ID
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
response = await test_client.get(f"/training/jobs/{job_id}/status")
# Should not find the job (belongs to different tenant)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_input_validation_integration(
self,
test_client: AsyncClient
):
"""Test input validation across API boundaries"""
# Test invalid seasonality mode
invalid_request = {
"seasonality_mode": "invalid_mode",
"min_data_points": -5 # Invalid negative value
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=invalid_request)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_sql_injection_protection(
self,
test_client: AsyncClient
):
"""Test protection against SQL injection attempts"""
# Try SQL injection in job ID
malicious_job_id = "job'; DROP TABLE model_training_logs; --"
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
# Should return 404, not cause database error
assert response.status_code == 404
class TestRecoveryIntegration:
"""Test recovery and resilience scenarios"""
@pytest.mark.asyncio
async def test_service_restart_recovery(
self,
test_db_session,
training_service,
training_job_in_db
):
"""Test service recovery after restart"""
# Simulate service restart by creating new service instance
new_training_service = training_service.__class__()
# Should be able to access existing jobs
existing_job = await new_training_service.get_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert existing_job is not None
assert existing_job.job_id == training_job_in_db.job_id
@pytest.mark.asyncio
async def test_partial_failure_recovery(
self,
training_service,
test_db_session
):
"""Test recovery from partial failures"""
# Create job that might fail partway through
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="partial-failure-job",
config={"simulate_failure": True}
)
# Simulate partial progress
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=50,
current_step="Halfway through training"
)
# Simulate failure
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="failed",
progress=50,
current_step="Training failed",
error_message="Simulated failure"
)
# Verify failure was recorded
failed_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert failed_job.status == "failed"
assert failed_job.error_message == "Simulated failure"
assert failed_job.progress == 50
class TestComplianceIntegration:
"""Test compliance and audit requirements"""
@pytest.mark.asyncio
async def test_audit_trail_creation(
self,
training_service,
test_db_session
):
"""Test that audit trail is properly created"""
# Create and update job
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="audit-test-job",
config={"audit_test": True}
)
# Multiple status updates
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=25,
current_step="Started processing"
)
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running",
progress=75,
current_step="Almost complete"
)
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="completed",
progress=100,
current_step="Completed successfully"
)
# Verify audit trail
logs = await training_service.get_training_logs(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert logs is not None
assert len(logs) > 0
# Check final status
final_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert final_job.status == "completed"
assert final_job.progress == 100
@pytest.mark.asyncio
async def test_data_retention_compliance(
self,
training_service,
test_db_session
):
"""Test data retention and cleanup compliance"""
from datetime import datetime, timedelta
# Create old job (simulate old data)
old_job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="old-job",
config={"created_long_ago": True}
)
# Manually set old timestamp
from sqlalchemy import update
from app.models.training import ModelTrainingLog
old_timestamp = datetime.now() - timedelta(days=400)
await test_db_session.execute(
update(ModelTrainingLog)
.where(ModelTrainingLog.job_id == old_job.job_id)
.values(start_time=old_timestamp, created_at=old_timestamp)
)
await test_db_session.commit()
# Verify old job exists
retrieved_job = await training_service.get_job_status(
db=test_db_session,
job_id=old_job.job_id,
tenant_id="test-tenant"
)
assert retrieved_job is not None
# In a real implementation, there would be cleanup procedures
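# A hypothetical retention sweep (not part of this service) might look like:
#   cutoff = datetime.now() - timedelta(days=365)
#   await test_db_session.execute(
#       delete(ModelTrainingLog).where(ModelTrainingLog.created_at < cutoff)
#   )
# with `delete` imported from sqlalchemy; until such a policy exists, this test only
# asserts that old records remain retrievable.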
@pytest.mark.asyncio
async def test_gdpr_compliance_features(
self,
training_service,
test_db_session
):
"""Test GDPR compliance features"""
# Create job with tenant data
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="gdpr-test-tenant",
job_id="gdpr-test-job",
config={"gdpr_test": True}
)
# Verify job is associated with tenant
assert job.tenant_id == "gdpr-test-tenant"
# Test data access (right to access)
tenant_jobs = await training_service.list_training_jobs(
db=test_db_session,
tenant_id="gdpr-test-tenant"
)
assert len(tenant_jobs) >= 1
assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs)
@pytest.mark.slow
class TestLongRunningIntegration:
"""Test long-running integration scenarios (marked as slow)"""
@pytest.mark.asyncio
async def test_extended_training_simulation(
self,
training_service,
test_db_session,
mock_messaging
):
"""Test extended training process simulation"""
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="long-running-job",
config={"extended_test": True}
)
# Simulate progress over time
progress_steps = [
(10, "Initializing"),
(25, "Loading data"),
(50, "Training models"),
(75, "Validating results"),
(90, "Storing models"),
(100, "Completed")
]
for progress, step in progress_steps:
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="running" if progress < 100 else "completed",
progress=progress,
current_step=step
)
# Small delay to simulate real progression
await asyncio.sleep(0.01)
# Verify final state
final_job = await training_service.get_job_status(
db=test_db_session,
job_id=job.job_id,
tenant_id="test-tenant"
)
assert final_job.status == "completed"
assert final_job.progress == 100
assert final_job.current_step == "Completed"
@pytest.mark.asyncio
async def test_memory_usage_stability(
self,
training_service,
test_db_session
):
"""Test memory usage stability over many operations"""
# Create many jobs to test memory stability
for i in range(50):
job = await training_service.create_training_job(
db=test_db_session,
tenant_id=f"tenant-{i % 5}", # 5 different tenants
job_id=f"memory-test-job-{i}",
config={"iteration": i}
)
# Update status
await training_service._update_job_status(
db=test_db_session,
job_id=job.job_id,
status="completed",
progress=100,
current_step="Completed"
)
# List jobs for each tenant
for tenant_i in range(5):
tenant_id = f"tenant-{tenant_i}"
jobs = await training_service.list_training_jobs(
db=test_db_session,
tenant_id=tenant_id,
limit=20
)
# Should have 10 jobs per tenant (50 total / 5 tenants)
assert len(jobs) == 10
class TestBackwardCompatibility:
"""Test backward compatibility with existing systems"""
@pytest.mark.asyncio
async def test_legacy_config_handling(
self,
training_service,
test_db_session
):
"""Test handling of legacy configuration formats"""
# Test with old-style configuration
legacy_config = {
"weather_enabled": True, # Old key
"traffic_enabled": True, # Old key
"minimum_samples": 30, # Old key
"prophet_config": { # Old nested structure
"seasonality": "additive"
}
}
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="legacy-config-job",
config=legacy_config
)
assert job.config == legacy_config
assert job.job_id == "legacy-config-job"
@pytest.mark.asyncio
async def test_api_version_compatibility(
self,
test_client: AsyncClient
):
"""Test API version compatibility"""
# Test with minimal request (old API style)
minimal_request = {
"include_weather": True
}
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
response = await test_client.post("/training/jobs", json=minimal_request)
# Should work with defaults for missing fields
assert response.status_code == 200
data = response.json()
assert "job_id" in data
# Utility functions for integration tests
async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
"""Wait for a condition to become true"""
import time
start_time = time.time()
while time.time() - start_time < timeout:
if await condition_func():
return True
await asyncio.sleep(interval)
return False
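# Illustrative usage inside an async test (fixture names as used elsewhere in this module):
#   async def job_finished():
#       job = await training_service.get_job_status(
#           db=test_db_session, job_id=job_id, tenant_id="test-tenant"
#       )
#       return job.status in ("completed", "failed", "cancelled")
#   assert await wait_for_condition(job_finished, timeout=10.0)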
def assert_job_progression(job_updates):
"""Assert that job updates show proper progression"""
assert len(job_updates) > 0
# Check progress is non-decreasing
for i in range(1, len(job_updates)):
assert job_updates[i]["progress"] >= job_updates[i-1]["progress"]
# Check final status
final_update = job_updates[-1]
assert final_update["status"] in ["completed", "failed", "cancelled"]
def assert_valid_job_structure(job_data):
"""Assert job data has valid structure"""
required_fields = ["job_id", "status", "tenant_id"]
for field in required_fields:
assert field in job_data
assert isinstance(job_data["progress"], int)
assert 0 <= job_data["progress"] <= 100
assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]


@@ -1,467 +0,0 @@
# services/training/tests/test_messaging.py
"""
Tests for training service messaging functionality
"""
import pytest
from unittest.mock import AsyncMock, Mock, patch
import json
from app.services import messaging
class TestTrainingMessaging:
"""Test training service messaging functions"""
@pytest.fixture
def mock_publisher(self):
"""Mock the RabbitMQ publisher"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
mock_pub.connect = AsyncMock(return_value=True)
mock_pub.disconnect = AsyncMock(return_value=None)
yield mock_pub
@pytest.mark.asyncio
async def test_setup_messaging_success(self, mock_publisher):
"""Test successful messaging setup"""
await messaging.setup_messaging()
mock_publisher.connect.assert_called_once()
@pytest.mark.asyncio
async def test_setup_messaging_failure(self, mock_publisher):
"""Test messaging setup failure"""
mock_publisher.connect.return_value = False
await messaging.setup_messaging()
mock_publisher.connect.assert_called_once()
@pytest.mark.asyncio
async def test_cleanup_messaging(self, mock_publisher):
"""Test messaging cleanup"""
await messaging.cleanup_messaging()
mock_publisher.disconnect.assert_called_once()
@pytest.mark.asyncio
async def test_publish_job_started(self, mock_publisher):
"""Test publishing job started event"""
job_id = "test-job-123"
tenant_id = "test-tenant"
config = {"include_weather": True}
result = await messaging.publish_job_started(job_id, tenant_id, config)
assert result is True
mock_publisher.publish_event.assert_called_once()
# Check call arguments
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["exchange_name"] == "training.events"
assert call_args[1]["routing_key"] == "training.started"
event_data = call_args[1]["event_data"]
assert event_data["service_name"] == "training-service"
assert event_data["data"]["job_id"] == job_id
assert event_data["data"]["tenant_id"] == tenant_id
assert event_data["data"]["config"] == config
@pytest.mark.asyncio
async def test_publish_job_progress(self, mock_publisher):
"""Test publishing job progress event"""
job_id = "test-job-123"
tenant_id = "test-tenant"
progress = 50
step = "Training models"
result = await messaging.publish_job_progress(job_id, tenant_id, progress, step)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.progress"
event_data = call_args[1]["event_data"]
assert event_data["data"]["progress"] == progress
assert event_data["data"]["current_step"] == step
@pytest.mark.asyncio
async def test_publish_job_completed(self, mock_publisher):
"""Test publishing job completed event"""
job_id = "test-job-123"
tenant_id = "test-tenant"
results = {
"products_trained": 3,
"summary": {"success_rate": 100.0}
}
result = await messaging.publish_job_completed(job_id, tenant_id, results)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.completed"
event_data = call_args[1]["event_data"]
assert event_data["data"]["results"] == results
assert event_data["data"]["models_trained"] == 3
assert event_data["data"]["success_rate"] == 100.0
@pytest.mark.asyncio
async def test_publish_job_failed(self, mock_publisher):
"""Test publishing job failed event"""
job_id = "test-job-123"
tenant_id = "test-tenant"
error = "Data service unavailable"
result = await messaging.publish_job_failed(job_id, tenant_id, error)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.failed"
event_data = call_args[1]["event_data"]
assert event_data["data"]["error"] == error
@pytest.mark.asyncio
async def test_publish_job_cancelled(self, mock_publisher):
"""Test publishing job cancelled event"""
job_id = "test-job-123"
tenant_id = "test-tenant"
result = await messaging.publish_job_cancelled(job_id, tenant_id)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.cancelled"
@pytest.mark.asyncio
async def test_publish_product_training_started(self, mock_publisher):
"""Test publishing product training started event"""
job_id = "test-product-job-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
result = await messaging.publish_product_training_started(job_id, tenant_id, product_name)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.product.started"
event_data = call_args[1]["event_data"]
assert event_data["data"]["product_name"] == product_name
@pytest.mark.asyncio
async def test_publish_product_training_completed(self, mock_publisher):
"""Test publishing product training completed event"""
job_id = "test-product-job-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
model_id = "test-model-123"
result = await messaging.publish_product_training_completed(
job_id, tenant_id, product_name, model_id
)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.product.completed"
event_data = call_args[1]["event_data"]
assert event_data["data"]["model_id"] == model_id
assert event_data["data"]["product_name"] == product_name
@pytest.mark.asyncio
async def test_publish_model_trained(self, mock_publisher):
"""Test publishing model trained event"""
model_id = "test-model-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
metrics = {"mae": 5.2, "rmse": 7.8, "mape": 12.5}
result = await messaging.publish_model_trained(model_id, tenant_id, product_name, metrics)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.model.trained"
event_data = call_args[1]["event_data"]
assert event_data["data"]["training_metrics"] == metrics
@pytest.mark.asyncio
async def test_publish_model_updated(self, mock_publisher):
"""Test publishing model updated event"""
model_id = "test-model-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
version = 2
result = await messaging.publish_model_updated(model_id, tenant_id, product_name, version)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.model.updated"
event_data = call_args[1]["event_data"]
assert event_data["data"]["version"] == version
@pytest.mark.asyncio
async def test_publish_model_validated(self, mock_publisher):
"""Test publishing model validated event"""
model_id = "test-model-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
validation_results = {"is_valid": True, "accuracy": 0.95}
result = await messaging.publish_model_validated(
model_id, tenant_id, product_name, validation_results
)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.model.validated"
event_data = call_args[1]["event_data"]
assert event_data["data"]["validation_results"] == validation_results
@pytest.mark.asyncio
async def test_publish_model_saved(self, mock_publisher):
"""Test publishing model saved event"""
model_id = "test-model-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
model_path = "/models/test-model-123.pkl"
result = await messaging.publish_model_saved(model_id, tenant_id, product_name, model_path)
assert result is True
mock_publisher.publish_event.assert_called_once()
call_args = mock_publisher.publish_event.call_args
assert call_args[1]["routing_key"] == "training.model.saved"
event_data = call_args[1]["event_data"]
assert event_data["data"]["model_path"] == model_path
class TestMessagingErrorHandling:
"""Test error handling in messaging"""
@pytest.fixture
def failing_publisher(self):
"""Mock publisher that fails"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=False)
mock_pub.connect = AsyncMock(return_value=False)
yield mock_pub
@pytest.mark.asyncio
async def test_publish_event_failure(self, failing_publisher):
"""Test handling of publish event failure"""
result = await messaging.publish_job_started("job-123", "tenant-123", {})
assert result is False
failing_publisher.publish_event.assert_called_once()
@pytest.mark.asyncio
async def test_setup_messaging_connection_failure(self, failing_publisher):
"""Test setup with connection failure"""
await messaging.setup_messaging()
failing_publisher.connect.assert_called_once()
@pytest.mark.asyncio
async def test_publish_with_exception(self):
"""Test publishing with exception"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event.side_effect = Exception("Connection lost")
result = await messaging.publish_job_started("job-123", "tenant-123", {})
assert result is False
class TestMessagingIntegration:
"""Test messaging integration with shared components"""
@pytest.mark.asyncio
async def test_event_structure_consistency(self):
"""Test that events follow consistent structure"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Test different event types
await messaging.publish_job_started("job-123", "tenant-123", {})
await messaging.publish_job_completed("job-123", "tenant-123", {})
await messaging.publish_model_trained("model-123", "tenant-123", "Pan", {})
# Verify all calls have consistent structure
assert mock_pub.publish_event.call_count == 3
for call in mock_pub.publish_event.call_args_list:
event_data = call[1]["event_data"]
# All events should have these fields
assert "service_name" in event_data
assert "event_type" in event_data
assert "data" in event_data
assert event_data["service_name"] == "training-service"
@pytest.mark.asyncio
async def test_shared_event_classes_usage(self):
"""Test that shared event classes are used properly"""
with patch('shared.messaging.events.TrainingStartedEvent') as mock_event_class:
mock_event = Mock()
mock_event.to_dict.return_value = {
"service_name": "training-service",
"event_type": "training.started",
"data": {"job_id": "test-job"}
}
mock_event_class.return_value = mock_event
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
await messaging.publish_job_started("test-job", "test-tenant", {})
# Verify shared event class was used
mock_event_class.assert_called_once()
mock_event.to_dict.assert_called_once()
@pytest.mark.asyncio
async def test_routing_key_consistency(self):
"""Test that routing keys follow consistent patterns"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Test various event types
events_and_keys = [
(messaging.publish_job_started, "training.started"),
(messaging.publish_job_progress, "training.progress"),
(messaging.publish_job_completed, "training.completed"),
(messaging.publish_job_failed, "training.failed"),
(messaging.publish_job_cancelled, "training.cancelled"),
(messaging.publish_product_training_started, "training.product.started"),
(messaging.publish_product_training_completed, "training.product.completed"),
(messaging.publish_model_trained, "training.model.trained"),
(messaging.publish_model_updated, "training.model.updated"),
(messaging.publish_model_validated, "training.model.validated"),
(messaging.publish_model_saved, "training.model.saved")
]
for event_func, expected_key in events_and_keys:
mock_pub.reset_mock()
# Call event function with appropriate parameters
if "progress" in expected_key:
await event_func("job-123", "tenant-123", 50, "step")
elif "model" in expected_key and "trained" in expected_key:
await event_func("model-123", "tenant-123", "product", {})
elif "model" in expected_key and "updated" in expected_key:
await event_func("model-123", "tenant-123", "product", 1)
elif "model" in expected_key and "validated" in expected_key:
await event_func("model-123", "tenant-123", "product", {})
elif "model" in expected_key and "saved" in expected_key:
await event_func("model-123", "tenant-123", "product", "/path")
elif "product" in expected_key and "completed" in expected_key:
await event_func("job-123", "tenant-123", "product", "model-123")
elif "product" in expected_key:
await event_func("job-123", "tenant-123", "product")
elif "failed" in expected_key:
await event_func("job-123", "tenant-123", "error")
elif "cancelled" in expected_key:
await event_func("job-123", "tenant-123")
else:
await event_func("job-123", "tenant-123", {})
# Verify routing key
call_args = mock_pub.publish_event.call_args
assert call_args[1]["routing_key"] == expected_key
@pytest.mark.asyncio
async def test_exchange_consistency(self):
"""Test that all events use the same exchange"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Test multiple events
await messaging.publish_job_started("job-123", "tenant-123", {})
await messaging.publish_model_trained("model-123", "tenant-123", "product", {})
await messaging.publish_product_training_started("job-123", "tenant-123", "product")
# Verify all use same exchange
for call in mock_pub.publish_event.call_args_list:
assert call[1]["exchange_name"] == "training.events"
class TestMessagingPerformance:
"""Test messaging performance and reliability"""
@pytest.mark.asyncio
async def test_concurrent_publishing(self):
"""Test concurrent event publishing"""
import asyncio
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Create multiple concurrent publishing tasks
tasks = []
for i in range(10):
task = messaging.publish_job_progress(f"job-{i}", "tenant-123", i * 10, f"step-{i}")
tasks.append(task)
# Execute all tasks concurrently
results = await asyncio.gather(*tasks)
# Verify all succeeded
assert all(results)
assert mock_pub.publish_event.call_count == 10
@pytest.mark.asyncio
async def test_large_event_data(self):
"""Test publishing events with large data payloads"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Create large config data
large_config = {
"products": [f"Product-{i}" for i in range(1000)],
"features": [f"feature-{i}" for i in range(100)],
"hyperparameters": {f"param-{i}": i for i in range(50)}
}
result = await messaging.publish_job_started("job-123", "tenant-123", large_config)
assert result is True
mock_pub.publish_event.assert_called_once()
@pytest.mark.asyncio
async def test_rapid_sequential_publishing(self):
"""Test rapid sequential event publishing"""
with patch('app.services.messaging.training_publisher') as mock_pub:
mock_pub.publish_event = AsyncMock(return_value=True)
# Publish many events in sequence
for i in range(100):
await messaging.publish_job_progress("job-123", "tenant-123", i, f"step-{i}")
assert mock_pub.publish_event.call_count == 100


@@ -1,630 +0,0 @@
# ================================================================
# services/training/tests/test_performance.py
# ================================================================
"""
Performance and Load Testing for Training Service
Tests training performance with real-world data volumes
"""
import pytest
import asyncio
import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor
import psutil
import gc
from typing import List, Dict, Any
import logging
from unittest.mock import patch  # patch.object is used in test_training_service_throughput
from app.ml.trainer import BakeryMLTrainer
from app.ml.data_processor import BakeryDataProcessor
from app.services.training_service import TrainingService
class TestTrainingPerformance:
"""Performance tests for training service components"""
@pytest.fixture
def large_sales_dataset(self):
"""Generate large dataset for performance testing (2 years of data)"""
start_date = datetime(2022, 1, 1)
end_date = datetime(2024, 1, 1)
date_range = pd.date_range(start=start_date, end=end_date, freq='D')
products = [
"Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
"Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras",
"Donuts", "Berlinas", "Napolitanas", "Ensaimadas"
]
data = []
for date in date_range:
for product in products:
# Realistic sales simulation
base_quantity = np.random.randint(5, 150)
# Seasonal patterns
if date.month in [12, 1]: # Winter/Holiday season
base_quantity *= 1.4
elif date.month in [6, 7, 8]: # Summer
base_quantity *= 0.8
# Weekly patterns
if date.weekday() >= 5: # Weekends
base_quantity *= 1.2
elif date.weekday() == 0: # Monday
base_quantity *= 0.7
# Add noise
quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.1)))
data.append({
"date": date.strftime("%Y-%m-%d"),
"product": product,
"quantity": quantity,
"revenue": round(quantity * np.random.uniform(1.5, 8.0), 2),
"temperature": round(15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi) + np.random.normal(0, 3), 1),
"precipitation": max(0, np.random.exponential(0.8)),
"is_weekend": date.weekday() >= 5,
"is_holiday": self._is_spanish_holiday(date)
})
return pd.DataFrame(data)
def _is_spanish_holiday(self, date: datetime) -> bool:
"""Check if date is a Spanish holiday"""
holidays = [
(1, 1), # New Year
(1, 6), # Epiphany
(5, 1), # Labor Day
(8, 15), # Assumption
(10, 12), # National Day
(11, 1), # All Saints
(12, 6), # Constitution Day
(12, 8), # Immaculate Conception
(12, 25), # Christmas
]
return (date.month, date.day) in holidays
@pytest.mark.asyncio
async def test_single_product_training_performance(self, large_sales_dataset):
"""Test performance of single product training with large dataset"""
trainer = BakeryMLTrainer()
product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()
# Measure memory before training
process = psutil.Process()
memory_before = process.memory_info().rss / 1024 / 1024 # MB
start_time = time.time()
result = await trainer.train_single_product(
tenant_id="perf_test_tenant",
product_name="Pan Integral",
sales_data=product_data,
config={
"include_weather": True,
"include_traffic": False, # Skip traffic for performance
"seasonality_mode": "additive"
}
)
end_time = time.time()
training_duration = end_time - start_time
# Measure memory after training
memory_after = process.memory_info().rss / 1024 / 1024 # MB
memory_used = memory_after - memory_before
# Performance assertions
assert training_duration < 120, f"Training took too long: {training_duration:.2f}s"
assert memory_used < 500, f"Memory usage too high: {memory_used:.2f}MB"
assert result['status'] == 'completed'
# Quality assertions
metrics = result['metrics']
assert metrics['mape'] < 50, f"MAPE too high: {metrics['mape']:.2f}%"
print(f"Performance Results:")
print(f" Training Duration: {training_duration:.2f}s")
print(f" Memory Used: {memory_used:.2f}MB")
print(f" Data Points: {len(product_data)}")
print(f" MAPE: {metrics['mape']:.2f}%")
print(f" RMSE: {metrics['rmse']:.2f}")
@pytest.mark.asyncio
async def test_concurrent_training_performance(self, large_sales_dataset):
"""Test performance of concurrent training jobs"""
trainer = BakeryMLTrainer()
products = ["Pan Integral", "Croissant", "Magdalenas"]
async def train_product(product_name: str):
"""Train a single product"""
product_data = large_sales_dataset[large_sales_dataset['product'] == product_name].copy()
start_time = time.time()
result = await trainer.train_single_product(
tenant_id=f"concurrent_test_{product_name.replace(' ', '_').lower()}",
product_name=product_name,
sales_data=product_data,
config={"include_weather": True, "include_traffic": False}
)
end_time = time.time()
return {
'product': product_name,
'duration': end_time - start_time,
'status': result['status'],
'metrics': result.get('metrics', {})
}
# Run concurrent training
start_time = time.time()
tasks = [train_product(product) for product in products]
results = await asyncio.gather(*tasks)
total_time = time.time() - start_time
# Verify all trainings completed
for result in results:
assert result['status'] == 'completed'
assert result['duration'] < 120 # Individual training time
# Concurrent execution should be faster than sequential
sequential_time_estimate = sum(r['duration'] for r in results)
efficiency = sequential_time_estimate / total_time
assert efficiency > 1.5, f"Concurrency efficiency too low: {efficiency:.2f}x"
print(f"Concurrent Training Results:")
print(f" Total Time: {total_time:.2f}s")
print(f" Sequential Estimate: {sequential_time_estimate:.2f}s")
print(f" Efficiency: {efficiency:.2f}x")
for result in results:
    mape = result['metrics'].get('mape', float('nan'))  # avoid formatting 'N/A' with :.2f
    print(f"  {result['product']}: {result['duration']:.2f}s, MAPE: {mape:.2f}%")
@pytest.mark.asyncio
async def test_data_processing_scalability(self, large_sales_dataset):
"""Test data processing performance with increasing data sizes"""
data_processor = BakeryDataProcessor()
# Test with different data sizes
data_sizes = [1000, 5000, 10000, 20000, len(large_sales_dataset)]
performance_results = []
for size in data_sizes:
# Take a sample of the specified size
sample_data = large_sales_dataset.head(size).copy()
start_time = time.time()
# Process the data
processed_data = await data_processor.prepare_training_data(
sales_data=sample_data,
include_weather=True,
include_traffic=True,
tenant_id="scalability_test",
product_name="Pan Integral"
)
processing_time = time.time() - start_time
performance_results.append({
'data_size': size,
'processing_time': processing_time,
'processed_rows': len(processed_data),
'throughput': size / processing_time if processing_time > 0 else 0
})
# Verify linear or sub-linear scaling
for i in range(1, len(performance_results)):
prev_result = performance_results[i-1]
curr_result = performance_results[i]
size_ratio = curr_result['data_size'] / prev_result['data_size']
time_ratio = curr_result['processing_time'] / prev_result['processing_time']
# Processing time should scale better than linearly
assert time_ratio < size_ratio * 1.5, f"Poor scaling at size {curr_result['data_size']}"
print("Data Processing Scalability Results:")
for result in performance_results:
print(f" Size: {result['data_size']:,} rows, Time: {result['processing_time']:.2f}s, "
f"Throughput: {result['throughput']:.0f} rows/s")
@pytest.mark.asyncio
async def test_memory_usage_optimization(self, large_sales_dataset):
"""Test memory usage optimization during training"""
trainer = BakeryMLTrainer()
process = psutil.Process()
# Baseline memory
gc.collect() # Force garbage collection
baseline_memory = process.memory_info().rss / 1024 / 1024 # MB
memory_snapshots = [{'stage': 'baseline', 'memory_mb': baseline_memory}]
# Load data
product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()
current_memory = process.memory_info().rss / 1024 / 1024
memory_snapshots.append({'stage': 'data_loaded', 'memory_mb': current_memory})
# Train model
result = await trainer.train_single_product(
tenant_id="memory_test_tenant",
product_name="Pan Integral",
sales_data=product_data,
config={"include_weather": True, "include_traffic": True}
)
current_memory = process.memory_info().rss / 1024 / 1024
memory_snapshots.append({'stage': 'model_trained', 'memory_mb': current_memory})
# Cleanup
del product_data
del result
gc.collect()
final_memory = process.memory_info().rss / 1024 / 1024
memory_snapshots.append({'stage': 'cleanup', 'memory_mb': final_memory})
# Memory assertions
peak_memory = max(snapshot['memory_mb'] for snapshot in memory_snapshots)
memory_increase = peak_memory - baseline_memory
memory_after_cleanup = final_memory - baseline_memory
assert memory_increase < 800, f"Peak memory increase too high: {memory_increase:.2f}MB"
assert memory_after_cleanup < 100, f"Memory not properly cleaned up: {memory_after_cleanup:.2f}MB"
print("Memory Usage Analysis:")
for snapshot in memory_snapshots:
print(f" {snapshot['stage']}: {snapshot['memory_mb']:.2f}MB")
print(f" Peak increase: {memory_increase:.2f}MB")
print(f" After cleanup: {memory_after_cleanup:.2f}MB")
@pytest.mark.asyncio
async def test_training_service_throughput(self, large_sales_dataset):
"""Test training service throughput with multiple requests"""
training_service = TrainingService()
# Simulate multiple training requests
num_requests = 5
products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]
async def execute_training_request(request_id: int, product: str):
"""Execute a single training request"""
product_data = large_sales_dataset[large_sales_dataset['product'] == product].copy()
with patch.object(training_service, '_fetch_sales_data', return_value=product_data):
start_time = time.time()
result = await training_service.execute_training_job(
db=None, # Mock DB session
tenant_id=f"throughput_test_tenant_{request_id}",
job_id=f"job_{request_id}_{product.replace(' ', '_').lower()}",
request={
'products': [product],
'include_weather': True,
'include_traffic': False,
'config': {'seasonality_mode': 'additive'}
}
)
duration = time.time() - start_time
return {
'request_id': request_id,
'product': product,
'duration': duration,
'status': result.get('status', 'unknown'),
'models_trained': len(result.get('models_trained', []))
}
# Execute requests concurrently
start_time = time.time()
tasks = [
execute_training_request(i, products[i % len(products)])
for i in range(num_requests)
]
results = await asyncio.gather(*tasks)
total_time = time.time() - start_time
# Calculate throughput metrics
successful_requests = sum(1 for r in results if r['status'] == 'completed')
throughput = successful_requests / total_time # requests per second
# Performance assertions
assert successful_requests >= num_requests * 0.8, "Too many failed requests"
assert throughput >= 0.1, f"Throughput too low: {throughput:.3f} req/s"
assert total_time < 300, f"Total time too long: {total_time:.2f}s"
print(f"Training Service Throughput Results:")
print(f" Total Requests: {num_requests}")
print(f" Successful: {successful_requests}")
print(f" Total Time: {total_time:.2f}s")
print(f" Throughput: {throughput:.3f} req/s")
print(f" Average Request Time: {total_time/num_requests:.2f}s")
@pytest.mark.asyncio
async def test_large_dataset_edge_cases(self, large_sales_dataset):
"""Test handling of edge cases with large datasets"""
data_processor = BakeryDataProcessor()
# Test 1: Dataset with many missing values
corrupted_data = large_sales_dataset.copy()
# Introduce 30% missing values randomly
mask = np.random.random(len(corrupted_data)) < 0.3
corrupted_data.loc[mask, 'quantity'] = np.nan
start_time = time.time()
result = await data_processor.validate_data_quality(corrupted_data)
validation_time = time.time() - start_time
assert validation_time < 10, f"Validation too slow: {validation_time:.2f}s"
assert result['is_valid'] is False
assert 'high_missing_data' in result['issues']
# Test 2: Dataset with extreme outliers
outlier_data = large_sales_dataset.copy()
# Add extreme outliers (100x normal values)
outlier_indices = np.random.choice(len(outlier_data), size=int(len(outlier_data) * 0.01), replace=False)
outlier_data.loc[outlier_indices, 'quantity'] *= 100
start_time = time.time()
cleaned_data = await data_processor.clean_outliers(outlier_data)
cleaning_time = time.time() - start_time
assert cleaning_time < 15, f"Outlier cleaning too slow: {cleaning_time:.2f}s"
assert len(cleaned_data) > len(outlier_data) * 0.95 # Should retain most data
# Test 3: Very sparse data (many products with few sales)
sparse_data = large_sales_dataset.copy()
# Keep only 10% of data for each product randomly
sparse_data = sparse_data.groupby('product').apply(
lambda x: x.sample(n=max(1, int(len(x) * 0.1)))
).reset_index(drop=True)
start_time = time.time()
validation_result = await data_processor.validate_data_quality(sparse_data)
sparse_validation_time = time.time() - start_time
assert sparse_validation_time < 5, f"Sparse data validation too slow: {sparse_validation_time:.2f}s"
print("Edge Case Performance Results:")
print(f" Corrupted data validation: {validation_time:.2f}s")
print(f" Outlier cleaning: {cleaning_time:.2f}s")
print(f" Sparse data validation: {sparse_validation_time:.2f}s")
class TestTrainingServiceLoad:
"""Load testing for training service under stress"""
@pytest.mark.asyncio
async def test_sustained_load_training(self, large_sales_dataset):
"""Test training service under sustained load"""
trainer = BakeryMLTrainer()
# Define load test parameters
duration_minutes = 2 # Run for 2 minutes
requests_per_minute = 3
products = ["Pan Integral", "Croissant", "Magdalenas"]
async def sustained_training_worker(worker_id: int, duration: float):
"""Worker that continuously submits training requests"""
start_time = time.time()
completed_requests = 0
failed_requests = 0
while time.time() - start_time < duration:
try:
product = products[completed_requests % len(products)]
product_data = large_sales_dataset[
large_sales_dataset['product'] == product
].copy()
result = await trainer.train_single_product(
tenant_id=f"load_test_worker_{worker_id}",
product_name=product,
sales_data=product_data,
config={"include_weather": False, "include_traffic": False} # Minimal config for speed
)
if result['status'] == 'completed':
completed_requests += 1
else:
failed_requests += 1
except Exception as e:
failed_requests += 1
logging.error(f"Training request failed: {e}")
# Wait before next request
await asyncio.sleep(60 / requests_per_minute)
return {
'worker_id': worker_id,
'completed': completed_requests,
'failed': failed_requests,
'duration': time.time() - start_time
}
# Start multiple workers
num_workers = 2
duration_seconds = duration_minutes * 60
start_time = time.time()
tasks = [
sustained_training_worker(i, duration_seconds)
for i in range(num_workers)
]
results = await asyncio.gather(*tasks)
total_time = time.time() - start_time
# Analyze results
total_completed = sum(r['completed'] for r in results)
total_failed = sum(r['failed'] for r in results)
success_rate = total_completed / (total_completed + total_failed) if (total_completed + total_failed) > 0 else 0
# Performance assertions
assert success_rate >= 0.8, f"Success rate too low: {success_rate:.2%}"
assert total_completed >= duration_minutes * requests_per_minute * num_workers * 0.7, "Throughput too low"
print(f"Sustained Load Test Results:")
print(f" Duration: {total_time:.2f}s")
print(f" Workers: {num_workers}")
print(f" Completed Requests: {total_completed}")
print(f" Failed Requests: {total_failed}")
print(f" Success Rate: {success_rate:.2%}")
print(f" Average Throughput: {total_completed/total_time:.2f} req/s")
@pytest.mark.asyncio
async def test_resource_exhaustion_recovery(self, large_sales_dataset):
"""Test service recovery from resource exhaustion"""
trainer = BakeryMLTrainer()
# Simulate resource exhaustion by running many concurrent requests
num_concurrent = 10 # High concurrency to stress the system
async def resource_intensive_task(task_id: int):
"""Task designed to consume resources"""
try:
# Use all products to increase memory usage
all_products_data = large_sales_dataset.copy()
result = await trainer.train_tenant_models(
tenant_id=f"resource_test_{task_id}",
sales_data=all_products_data,
config={
"train_all_products": True,
"include_weather": True,
"include_traffic": True
}
)
return {'task_id': task_id, 'status': 'completed', 'error': None}
except Exception as e:
return {'task_id': task_id, 'status': 'failed', 'error': str(e)}
# Launch all tasks simultaneously
start_time = time.time()
tasks = [resource_intensive_task(i) for i in range(num_concurrent)]
results = await asyncio.gather(*tasks, return_exceptions=True)
duration = time.time() - start_time
# Analyze results
completed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'completed')
failed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'failed')
exceptions = sum(1 for r in results if isinstance(r, Exception))
# The system should handle some failures gracefully
# but should complete at least some requests
total_processed = completed + failed + exceptions
processing_rate = total_processed / num_concurrent
assert processing_rate >= 0.5, f"Too many requests not processed: {processing_rate:.2%}"
assert duration < 600, f"Recovery took too long: {duration:.2f}s" # 10 minutes max
print(f"Resource Exhaustion Test Results:")
print(f" Concurrent Requests: {num_concurrent}")
print(f" Completed: {completed}")
print(f" Failed: {failed}")
print(f" Exceptions: {exceptions}")
print(f" Duration: {duration:.2f}s")
print(f" Processing Rate: {processing_rate:.2%}")
# ================================================================
# BENCHMARK UTILITIES
# ================================================================
class PerformanceBenchmark:
"""Utility class for performance benchmarking"""
@staticmethod
def measure_execution_time(func):
"""Decorator to measure execution time"""
async def wrapper(*args, **kwargs):
start_time = time.time()
result = await func(*args, **kwargs)
execution_time = time.time() - start_time
if hasattr(result, 'update') and isinstance(result, dict):
result['execution_time'] = execution_time
return result
return wrapper
@staticmethod
def memory_profiler(func):
"""Decorator to profile memory usage"""
async def wrapper(*args, **kwargs):
process = psutil.Process()
# Memory before
gc.collect()
memory_before = process.memory_info().rss / 1024 / 1024
result = await func(*args, **kwargs)
# Memory after
memory_after = process.memory_info().rss / 1024 / 1024
memory_used = memory_after - memory_before
if hasattr(result, 'update') and isinstance(result, dict):
result['memory_used_mb'] = memory_used
return result
return wrapper
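# Intended usage (sketch): stack the decorators on any async helper that returns a dict,
# e.g.
#   @PerformanceBenchmark.measure_execution_time
#   @PerformanceBenchmark.memory_profiler
#   async def profiled_training(**kwargs):
#       ...
# so the returned dict also carries 'execution_time' and 'memory_used_mb'.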
# ================================================================
# STANDALONE EXECUTION
# ================================================================
if __name__ == "__main__":
"""
Run performance tests as standalone script
Usage: python test_performance.py
"""
import sys
import os
from unittest.mock import patch
# Add the training service root to Python path
training_service_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, training_service_root)
print("=" * 60)
print("TRAINING SERVICE PERFORMANCE TEST SUITE")
print("=" * 60)
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Run performance tests
pytest.main([
__file__,
"-v",
"--tb=short",
"-s", # Don't capture output
"--durations=10", # Show 10 slowest tests
"-m", "not slow", # Skip slow tests unless specifically requested
])
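# To include tests marked @pytest.mark.slow as well, drop the '-m "not slow"' filter
# above or invoke pytest with -m slow (assuming the "slow" marker is registered in the
# project's pytest configuration).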
print("\n" + "=" * 60)
print("PERFORMANCE TESTING COMPLETE")
print("=" * 60)


@@ -1,688 +0,0 @@
# services/training/tests/test_service.py
"""
Tests for training service business logic layer
"""
import pytest
from unittest.mock import AsyncMock, Mock, patch
from datetime import datetime, timedelta
import httpx
from app.services.training_service import TrainingService
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.models.training import ModelTrainingLog, TrainedModel
class TestTrainingService:
"""Test the training service business logic"""
@pytest.fixture
def training_service(self, mock_ml_trainer):
return TrainingService()
@pytest.mark.asyncio
async def test_create_training_job_success(
self,
training_service,
test_db_session
):
"""Test successful training job creation"""
job_id = "test-job-123"
tenant_id = "test-tenant"
config = {"include_weather": True, "include_traffic": True}
result = await training_service.create_training_job(
db=test_db_session,
tenant_id=tenant_id,
job_id=job_id,
config=config
)
assert isinstance(result, ModelTrainingLog)
assert result.job_id == job_id
assert result.tenant_id == tenant_id
assert result.status == "pending"
assert result.progress == 0
assert result.config == config
@pytest.mark.asyncio
async def test_create_single_product_job_success(
self,
training_service,
test_db_session
):
"""Test successful single product job creation"""
job_id = "test-product-job-123"
tenant_id = "test-tenant"
product_name = "Pan Integral"
config = {"include_weather": True}
result = await training_service.create_single_product_job(
db=test_db_session,
tenant_id=tenant_id,
product_name=product_name,
job_id=job_id,
config=config
)
assert isinstance(result, ModelTrainingLog)
assert result.job_id == job_id
assert result.tenant_id == tenant_id
assert result.config["single_product"] == product_name
assert f"Initializing training for {product_name}" in result.current_step
@pytest.mark.asyncio
async def test_get_job_status_existing(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test getting status of existing job"""
result = await training_service.get_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert result is not None
assert result.job_id == training_job_in_db.job_id
assert result.status == training_job_in_db.status
@pytest.mark.asyncio
async def test_get_job_status_nonexistent(
self,
training_service,
test_db_session
):
"""Test getting status of non-existent job"""
result = await training_service.get_job_status(
db=test_db_session,
job_id="nonexistent-job",
tenant_id="test-tenant"
)
assert result is None
@pytest.mark.asyncio
async def test_list_training_jobs(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test listing training jobs"""
result = await training_service.list_training_jobs(
db=test_db_session,
tenant_id=training_job_in_db.tenant_id,
limit=10
)
assert isinstance(result, list)
assert len(result) >= 1
assert result[0].job_id == training_job_in_db.job_id
@pytest.mark.asyncio
async def test_list_training_jobs_with_filter(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test listing training jobs with status filter"""
result = await training_service.list_training_jobs(
db=test_db_session,
tenant_id=training_job_in_db.tenant_id,
limit=10,
status_filter="pending"
)
assert isinstance(result, list)
for job in result:
assert job.status == "pending"
@pytest.mark.asyncio
async def test_cancel_training_job_success(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test successful job cancellation"""
result = await training_service.cancel_training_job(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert result is True
# Verify status was updated
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert updated_job.status == "cancelled"
@pytest.mark.asyncio
async def test_cancel_nonexistent_job(
self,
training_service,
test_db_session
):
"""Test cancelling non-existent job"""
result = await training_service.cancel_training_job(
db=test_db_session,
job_id="nonexistent-job",
tenant_id="test-tenant"
)
assert result is False
@pytest.mark.asyncio
async def test_validate_training_data_valid(
self,
training_service,
test_db_session,
mock_data_service
):
"""Test validation with valid data"""
config = {"min_data_points": 30}
result = await training_service.validate_training_data(
db=test_db_session,
tenant_id="test-tenant",
config=config
)
assert isinstance(result, dict)
assert "is_valid" in result
assert "issues" in result
assert "recommendations" in result
assert "estimated_time_minutes" in result
@pytest.mark.asyncio
async def test_validate_training_data_no_data(
self,
training_service,
test_db_session
):
"""Test validation with no data"""
config = {"min_data_points": 30}
with patch('app.services.training_service.TrainingService._fetch_sales_data', new_callable=AsyncMock, return_value=[]):
result = await training_service.validate_training_data(
db=test_db_session,
tenant_id="test-tenant",
config=config
)
assert result["is_valid"] is False
assert "No sales data found" in result["issues"][0]
@pytest.mark.asyncio
async def test_update_job_status(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test updating job status"""
await training_service._update_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
status="running",
progress=50,
current_step="Training models"
)
# Verify update
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert updated_job.status == "running"
assert updated_job.progress == 50
assert updated_job.current_step == "Training models"
@pytest.mark.asyncio
async def test_store_trained_models(
self,
training_service,
test_db_session
):
"""Test storing trained models"""
tenant_id = "test-tenant"
training_results = {
"training_results": {
"Pan Integral": {
"status": "success",
"model_info": {
"model_id": "test-model-123",
"model_path": "/test/models/test-model-123.pkl",
"type": "prophet",
"training_samples": 100,
"features": ["temperature", "humidity"],
"hyperparameters": {"seasonality_mode": "additive"},
"training_metrics": {"mae": 5.2, "rmse": 7.8},
"data_period": {
"start_date": "2024-01-01T00:00:00",
"end_date": "2024-01-31T00:00:00"
}
}
}
}
}
await training_service._store_trained_models(
db=test_db_session,
tenant_id=tenant_id,
training_results=training_results
)
# Verify model was stored
from sqlalchemy import select
result = await test_db_session.execute(
select(TrainedModel).where(
TrainedModel.tenant_id == tenant_id,
TrainedModel.product_name == "Pan Integral"
)
)
stored_model = result.scalar_one_or_none()
assert stored_model is not None
assert stored_model.model_id == "test-model-123"
assert stored_model.is_active is True
@pytest.mark.asyncio
async def test_get_training_logs(
self,
training_service,
test_db_session,
training_job_in_db
):
"""Test getting training logs"""
result = await training_service.get_training_logs(
db=test_db_session,
job_id=training_job_in_db.job_id,
tenant_id=training_job_in_db.tenant_id
)
assert isinstance(result, list)
assert len(result) > 0
# Check log content
log_text = " ".join(result)
assert training_job_in_db.job_id in log_text or "Job started" in log_text
class TestTrainingServiceDataFetching:
"""Test external data fetching functionality"""
@pytest.fixture
def training_service(self):
return TrainingService()
@pytest.mark.asyncio
async def test_fetch_sales_data_success(self, training_service):
"""Test successful sales data fetching"""
mock_request = Mock()
mock_request.start_date = None
mock_request.end_date = None
mock_response_data = {
"sales": [
{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
]
}
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
# client.get() is awaited by the service, so it must be mocked as an AsyncMock
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await training_service._fetch_sales_data(
tenant_id="test-tenant",
request=mock_request
)
assert result == mock_response_data["sales"]
@pytest.mark.asyncio
async def test_fetch_sales_data_error(self, training_service):
"""Test sales data fetching with API error"""
mock_request = Mock()
mock_request.start_date = None
mock_request.end_date = None
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 500
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await training_service._fetch_sales_data(
tenant_id="test-tenant",
request=mock_request
)
assert result == []
@pytest.mark.asyncio
async def test_fetch_weather_data_success(self, training_service):
"""Test successful weather data fetching"""
mock_request = Mock()
mock_request.start_date = None
mock_request.end_date = None
mock_response_data = {
"weather": [
{"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
]
}
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await training_service._fetch_weather_data(
tenant_id="test-tenant",
request=mock_request
)
assert result == mock_response_data["weather"]
@pytest.mark.asyncio
async def test_fetch_traffic_data_success(self, training_service):
"""Test successful traffic data fetching"""
mock_request = Mock()
mock_request.start_date = None
mock_request.end_date = None
mock_response_data = {
"traffic": [
{"date": "2024-01-01", "traffic_volume": 120}
]
}
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_response_data
mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
result = await training_service._fetch_traffic_data(
tenant_id="test-tenant",
request=mock_request
)
assert result == mock_response_data["traffic"]
@pytest.mark.asyncio
async def test_fetch_data_with_date_filters(self, training_service):
"""Test data fetching with date filters"""
from datetime import datetime
mock_request = Mock()
mock_request.start_date = datetime(2024, 1, 1)
mock_request.end_date = datetime(2024, 1, 31)
with patch('httpx.AsyncClient') as mock_client:
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"sales": []}
mock_get = AsyncMock(return_value=mock_response)
mock_client.return_value.__aenter__.return_value.get = mock_get
await training_service._fetch_sales_data(
tenant_id="test-tenant",
request=mock_request
)
# Verify dates were passed in params
call_args = mock_get.call_args
params = call_args[1]["params"]
assert "start_date" in params
assert "end_date" in params
assert params["start_date"] == "2024-01-01T00:00:00"
assert params["end_date"] == "2024-01-31T00:00:00"
class TestTrainingServiceExecution:
"""Test training execution workflow"""
@pytest.fixture
def training_service(self, mock_ml_trainer):
return TrainingService()
@pytest.mark.asyncio
async def test_execute_training_job_success(
self,
training_service,
test_db_session,
mock_messaging,
mock_data_service
):
"""Test successful training job execution"""
# Create job first
job_id = "test-execution-job"
training_log = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id=job_id,
config={"include_weather": True}
)
request = TrainingJobRequest(
include_weather=True,
include_traffic=True,
min_data_points=30
)
with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:
mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}]
mock_fetch_weather.return_value = []
mock_fetch_traffic.return_value = []
mock_store.return_value = None
await training_service.execute_training_job(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant",
request=request
)
# Verify job was completed
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant"
)
assert updated_job.status == "completed"
assert updated_job.progress == 100
@pytest.mark.asyncio
async def test_execute_training_job_failure(
self,
training_service,
test_db_session,
mock_messaging
):
"""Test training job execution with failure"""
# Create job first
job_id = "test-failure-job"
await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id=job_id,
config={}
)
request = TrainingJobRequest(min_data_points=30)
with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
mock_fetch.side_effect = Exception("Data service unavailable")
with pytest.raises(Exception):
await training_service.execute_training_job(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant",
request=request
)
# Verify job was marked as failed
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant"
)
assert updated_job.status == "failed"
assert "Data service unavailable" in updated_job.error_message
@pytest.mark.asyncio
async def test_execute_single_product_training_success(
self,
training_service,
test_db_session,
mock_messaging,
mock_data_service
):
"""Test successful single product training execution"""
job_id = "test-single-product-job"
product_name = "Pan Integral"
await training_service.create_single_product_job(
db=test_db_session,
tenant_id="test-tenant",
product_name=product_name,
job_id=job_id,
config={}
)
request = SingleProductTrainingRequest(
include_weather=True,
include_traffic=False
)
with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:
mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": product_name, "quantity": 45}]
mock_fetch_weather.return_value = []
mock_store.return_value = None
await training_service.execute_single_product_training(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant",
product_name=product_name,
request=request
)
# Verify job was completed
updated_job = await training_service.get_job_status(
db=test_db_session,
job_id=job_id,
tenant_id="test-tenant"
)
assert updated_job.status == "completed"
assert updated_job.progress == 100
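# ================================================================
# Hedged sketch (illustrative, not part of the original suite): the two full-pipeline
# tests above share the create / patch / execute / assert-final-status flow, so they
# could be folded into one parametrized test. This assumes the fixtures already used
# in this module (test_db_session, mock_messaging, mock_data_service, mock_ml_trainer)
# and the TrainingService / TrainingJobRequest imports at the top of the file.
# ================================================================
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "sales_failure, expected_status",
    [
        (None, "completed"),                                # healthy data fetch
        (Exception("Data service unavailable"), "failed"),  # upstream outage
    ],
)
async def test_execute_training_job_final_status(
    test_db_session, mock_messaging, mock_data_service, mock_ml_trainer,
    sales_failure, expected_status
):
    service = TrainingService()
    job_id = f"parametrized-{expected_status}-job"
    await service.create_training_job(
        db=test_db_session, tenant_id="test-tenant", job_id=job_id, config={}
    )
    request = TrainingJobRequest(min_data_points=30)
    with patch.object(TrainingService, "_fetch_sales_data", new_callable=AsyncMock) as mock_sales, \
         patch.object(TrainingService, "_fetch_weather_data", new_callable=AsyncMock, return_value=[]), \
         patch.object(TrainingService, "_fetch_traffic_data", new_callable=AsyncMock, return_value=[]), \
         patch.object(TrainingService, "_store_trained_models", new_callable=AsyncMock):
        if sales_failure is None:
            mock_sales.return_value = [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
            await service.execute_training_job(
                db=test_db_session, job_id=job_id, tenant_id="test-tenant", request=request
            )
        else:
            mock_sales.side_effect = sales_failure
            with pytest.raises(Exception):
                await service.execute_training_job(
                    db=test_db_session, job_id=job_id, tenant_id="test-tenant", request=request
                )
    updated_job = await service.get_job_status(
        db=test_db_session, job_id=job_id, tenant_id="test-tenant"
    )
    assert updated_job.status == expected_status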
class TestTrainingServiceEdgeCases:
"""Test edge cases and error conditions"""
@pytest.fixture
def training_service(self):
return TrainingService()
@pytest.mark.asyncio
async def test_database_connection_failure(self, training_service):
"""Test handling of database connection failures"""
# Use an AsyncMock session whose commit fails, so the error genuinely comes
# from the database layer instead of from passing a patch object as `db`
mock_db = AsyncMock()
mock_db.add = Mock()  # SQLAlchemy's add() is synchronous
mock_db.commit.side_effect = Exception("Database connection failed")
with pytest.raises(Exception):
await training_service.create_training_job(
db=mock_db,
tenant_id="test-tenant",
job_id="test-job",
config={}
)
@pytest.mark.asyncio
async def test_external_service_timeout(self, training_service):
"""Test handling of external service timeouts"""
mock_request = Mock()
mock_request.start_date = None
mock_request.end_date = None
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get = AsyncMock(side_effect=httpx.TimeoutException("Request timeout"))
result = await training_service._fetch_sales_data(
tenant_id="test-tenant",
request=mock_request
)
# Should return empty list on timeout
assert result == []
@pytest.mark.asyncio
async def test_concurrent_job_creation(self, training_service, test_db_session):
"""Test handling of concurrent job creation"""
# True concurrency would need one session per task (see the asyncio.gather
# sketch at the end of this module); for now, just verify that several jobs
# can be created back to back
job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]
jobs = []
for job_id in job_ids:
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id=job_id,
config={}
)
jobs.append(job)
assert len(jobs) == 3
for i, job in enumerate(jobs):
assert job.job_id == job_ids[i]
@pytest.mark.asyncio
async def test_malformed_config_handling(self, training_service, test_db_session):
"""Test handling of malformed configuration"""
malformed_config = {
"invalid_key": "invalid_value",
"nested": {"data": None}
}
# Should not raise exception, just store the config as-is
job = await training_service.create_training_job(
db=test_db_session,
tenant_id="test-tenant",
job_id="malformed-config-job",
config=malformed_config
)
assert job.config == malformed_config
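# ================================================================
# Hedged sketch (illustrative only): test_concurrent_job_creation above issues its
# create calls sequentially. True concurrency could be exercised with asyncio.gather,
# but a single AsyncSession is not safe to share between tasks, so each task needs its
# own session. The session_factory fixture named below is an assumption, not an
# existing fixture in conftest.py.
# ================================================================
import asyncio


@pytest.mark.asyncio
async def test_truly_concurrent_job_creation(session_factory):
    service = TrainingService()

    async def create_job(job_id):
        # Each concurrent task opens an independent database session
        async with session_factory() as db:
            return await service.create_training_job(
                db=db, tenant_id="test-tenant", job_id=job_id, config={}
            )

    job_ids = [f"gather-job-{i}" for i in range(3)]
    jobs = await asyncio.gather(*(create_job(job_id) for job_id in job_ids))

    assert sorted(job.job_id for job in jobs) == sorted(job_ids)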