Fix generating pytest for training service 2
This commit is contained in:
@@ -220,6 +220,19 @@ async def get_metrics():
|
||||
return app.state.metrics_collector.get_metrics()
|
||||
return {"status": "metrics not available"}
|
||||
|
||||
@app.get("/health/live")
|
||||
async def liveness_check():
|
||||
return {"status": "alive"}
|
||||
|
||||
@app.get("/health/ready")
|
||||
async def readiness_check():
|
||||
ready = getattr(app.state, 'ready', True)
|
||||
return {"status": "ready" if ready else "not ready"}
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"service": "training-service", "version": "1.0.0"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,673 +0,0 @@
|
||||
# ================================================================
|
||||
# services/training/tests/run_tests.py
|
||||
# ================================================================
|
||||
"""
|
||||
Main test runner script for Training Service
|
||||
Executes comprehensive test suite and generates reports
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import subprocess
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any
|
||||
import logging
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrainingTestRunner:
|
||||
"""Main test runner for training service"""
|
||||
|
||||
def __init__(self):
|
||||
self.test_dir = Path(__file__).parent
|
||||
self.results_dir = self.test_dir / "results"
|
||||
self.results_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Test configuration
|
||||
self.test_suites = {
|
||||
"unit": {
|
||||
"files": ["test_api.py", "test_ml.py", "test_service.py"],
|
||||
"description": "Unit tests for individual components",
|
||||
"timeout": 300 # 5 minutes
|
||||
},
|
||||
"integration": {
|
||||
"files": ["test_ml_pipeline_integration.py"],
|
||||
"description": "Integration tests for ML pipeline with external data",
|
||||
"timeout": 600 # 10 minutes
|
||||
},
|
||||
"performance": {
|
||||
"files": ["test_performance.py"],
|
||||
"description": "Performance and load testing",
|
||||
"timeout": 900 # 15 minutes
|
||||
},
|
||||
"end_to_end": {
|
||||
"files": ["test_end_to_end.py"],
|
||||
"description": "End-to-end workflow testing",
|
||||
"timeout": 800 # 13 minutes
|
||||
}
|
||||
}
|
||||
|
||||
self.test_results = {}
|
||||
|
||||
async def setup_test_environment(self):
|
||||
"""Setup test environment and dependencies"""
|
||||
logger.info("Setting up test environment...")
|
||||
|
||||
# Check if we're running in Docker
|
||||
if os.path.exists("/.dockerenv"):
|
||||
logger.info("Running in Docker environment")
|
||||
else:
|
||||
logger.info("Running in local environment")
|
||||
|
||||
# Verify required files exist
|
||||
required_files = [
|
||||
"conftest.py",
|
||||
"test_ml_pipeline_integration.py",
|
||||
"test_performance.py"
|
||||
]
|
||||
|
||||
for file in required_files:
|
||||
file_path = self.test_dir / file
|
||||
if not file_path.exists():
|
||||
logger.warning(f"Required test file missing: {file}")
|
||||
|
||||
# Create test data if needed
|
||||
await self.create_test_data()
|
||||
|
||||
# Verify external services (mock or real)
|
||||
await self.verify_external_services()
|
||||
|
||||
async def create_test_data(self):
|
||||
"""Create or verify test data exists"""
|
||||
logger.info("Creating/verifying test data...")
|
||||
|
||||
test_data_dir = self.test_dir / "fixtures" / "test_data"
|
||||
test_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create bakery sales sample if it doesn't exist
|
||||
sales_file = test_data_dir / "bakery_sales_sample.csv"
|
||||
if not sales_file.exists():
|
||||
logger.info("Creating sample sales data...")
|
||||
await self.generate_sample_sales_data(sales_file)
|
||||
|
||||
# Create weather data sample
|
||||
weather_file = test_data_dir / "madrid_weather_sample.json"
|
||||
if not weather_file.exists():
|
||||
logger.info("Creating sample weather data...")
|
||||
await self.generate_sample_weather_data(weather_file)
|
||||
|
||||
# Create traffic data sample
|
||||
traffic_file = test_data_dir / "madrid_traffic_sample.json"
|
||||
if not traffic_file.exists():
|
||||
logger.info("Creating sample traffic data...")
|
||||
await self.generate_sample_traffic_data(traffic_file)
|
||||
|
||||
async def generate_sample_sales_data(self, file_path: Path):
|
||||
"""Generate sample sales data for testing"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Generate 6 months of sample data
|
||||
start_date = datetime(2023, 6, 1)
|
||||
dates = [start_date + timedelta(days=i) for i in range(180)]
|
||||
|
||||
products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]
|
||||
|
||||
data = []
|
||||
for date in dates:
|
||||
for product in products:
|
||||
base_quantity = np.random.randint(10, 100)
|
||||
|
||||
# Weekend boost
|
||||
if date.weekday() >= 5:
|
||||
base_quantity *= 1.2
|
||||
|
||||
# Seasonal variation
|
||||
temp = 15 + 10 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
|
||||
|
||||
data.append({
|
||||
"date": date.strftime("%Y-%m-%d"),
|
||||
"product": product,
|
||||
"quantity": int(base_quantity),
|
||||
"revenue": round(base_quantity * np.random.uniform(2.5, 8.0), 2),
|
||||
"temperature": round(temp + np.random.normal(0, 3), 1),
|
||||
"precipitation": max(0, np.random.exponential(0.5)),
|
||||
"is_weekend": date.weekday() >= 5,
|
||||
"is_holiday": False
|
||||
})
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df.to_csv(file_path, index=False)
|
||||
logger.info(f"Created sample sales data: {len(df)} records")
|
||||
|
||||
async def generate_sample_weather_data(self, file_path: Path):
|
||||
"""Generate sample weather data"""
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import numpy as np
|
||||
|
||||
start_date = datetime(2023, 6, 1)
|
||||
weather_data = []
|
||||
|
||||
for i in range(180):
|
||||
date = start_date + timedelta(days=i)
|
||||
day_of_year = date.timetuple().tm_yday
|
||||
base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
|
||||
|
||||
weather_data.append({
|
||||
"date": date.isoformat(),
|
||||
"temperature": round(base_temp + np.random.normal(0, 5), 1),
|
||||
"precipitation": max(0, np.random.exponential(1.0)),
|
||||
"humidity": np.random.uniform(30, 80),
|
||||
"wind_speed": np.random.uniform(5, 25),
|
||||
"pressure": np.random.uniform(1000, 1025),
|
||||
"description": np.random.choice(["Soleado", "Nuboso", "Lluvioso"]),
|
||||
"source": "aemet_test"
|
||||
})
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(weather_data, f, indent=2)
|
||||
logger.info(f"Created sample weather data: {len(weather_data)} records")
|
||||
|
||||
async def generate_sample_traffic_data(self, file_path: Path):
|
||||
"""Generate sample traffic data"""
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import numpy as np
|
||||
|
||||
start_date = datetime(2023, 6, 1)
|
||||
traffic_data = []
|
||||
|
||||
for i in range(180):
|
||||
date = start_date + timedelta(days=i)
|
||||
|
||||
for hour in [8, 12, 18]: # Three measurements per day
|
||||
measurement_time = date.replace(hour=hour)
|
||||
|
||||
if hour in [8, 18]: # Rush hours
|
||||
volume = np.random.randint(800, 1500)
|
||||
congestion = "high"
|
||||
else: # Lunch time
|
||||
volume = np.random.randint(400, 800)
|
||||
congestion = "medium"
|
||||
|
||||
traffic_data.append({
|
||||
"date": measurement_time.isoformat(),
|
||||
"traffic_volume": volume,
|
||||
"occupation_percentage": np.random.randint(10, 90),
|
||||
"load_percentage": np.random.randint(20, 95),
|
||||
"average_speed": np.random.randint(15, 50),
|
||||
"congestion_level": congestion,
|
||||
"pedestrian_count": np.random.randint(50, 500),
|
||||
"measurement_point_id": "TEST_POINT_001",
|
||||
"measurement_point_name": "Plaza Mayor",
|
||||
"road_type": "URB",
|
||||
"source": "madrid_opendata_test"
|
||||
})
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(traffic_data, f, indent=2)
|
||||
logger.info(f"Created sample traffic data: {len(traffic_data)} records")
|
||||
|
||||
async def verify_external_services(self):
|
||||
"""Verify external services are available (mock or real)"""
|
||||
logger.info("Verifying external services...")
|
||||
|
||||
# Check if mock services are available
|
||||
mock_services = [
|
||||
("Mock AEMET", "http://localhost:8080/health"),
|
||||
("Mock Madrid OpenData", "http://localhost:8081/health"),
|
||||
("Mock Auth Service", "http://localhost:8082/health"),
|
||||
("Mock Data Service", "http://localhost:8083/health")
|
||||
]
|
||||
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
for service_name, url in mock_services:
|
||||
try:
|
||||
response = await client.get(url)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"{service_name} is available")
|
||||
else:
|
||||
logger.warning(f"{service_name} returned status {response.status_code}")
|
||||
except Exception as e:
|
||||
logger.warning(f"{service_name} is not available: {e}")
|
||||
except ImportError:
|
||||
logger.warning("httpx not available, skipping service checks")
|
||||
|
||||
def run_test_suite(self, suite_name: str) -> Dict[str, Any]:
|
||||
"""Run a specific test suite"""
|
||||
suite_config = self.test_suites[suite_name]
|
||||
logger.info(f"Running {suite_name} test suite: {suite_config['description']}")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Prepare pytest command
|
||||
pytest_args = [
|
||||
"python", "-m", "pytest",
|
||||
"-v",
|
||||
"--tb=short",
|
||||
"--capture=no",
|
||||
f"--junitxml={self.results_dir}/junit_{suite_name}.xml",
|
||||
f"--cov=app",
|
||||
f"--cov-report=html:{self.results_dir}/coverage_{suite_name}_html",
|
||||
f"--cov-report=xml:{self.results_dir}/coverage_{suite_name}.xml",
|
||||
"--cov-report=term-missing"
|
||||
]
|
||||
|
||||
# Add test files
|
||||
for test_file in suite_config["files"]:
|
||||
test_path = self.test_dir / test_file
|
||||
if test_path.exists():
|
||||
pytest_args.append(str(test_path))
|
||||
else:
|
||||
logger.warning(f"Test file not found: {test_file}")
|
||||
|
||||
# Run the tests
|
||||
try:
|
||||
result = subprocess.run(
|
||||
pytest_args,
|
||||
cwd=self.test_dir.parent, # Run from training service root
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=suite_config["timeout"]
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
return {
|
||||
"suite": suite_name,
|
||||
"status": "passed" if result.returncode == 0 else "failed",
|
||||
"return_code": result.returncode,
|
||||
"duration": duration,
|
||||
"stdout": result.stdout,
|
||||
"stderr": result.stderr,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
duration = time.time() - start_time
|
||||
logger.error(f"Test suite {suite_name} timed out after {duration:.2f}s")
|
||||
|
||||
return {
|
||||
"suite": suite_name,
|
||||
"status": "timeout",
|
||||
"return_code": -1,
|
||||
"duration": duration,
|
||||
"stdout": "",
|
||||
"stderr": f"Test suite timed out after {suite_config['timeout']}s",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
logger.error(f"Error running test suite {suite_name}: {e}")
|
||||
|
||||
return {
|
||||
"suite": suite_name,
|
||||
"status": "error",
|
||||
"return_code": -1,
|
||||
"duration": duration,
|
||||
"stdout": "",
|
||||
"stderr": str(e),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def generate_test_report(self):
|
||||
"""Generate comprehensive test report"""
|
||||
logger.info("Generating test report...")
|
||||
|
||||
# Calculate summary statistics
|
||||
total_suites = len(self.test_results)
|
||||
passed_suites = sum(1 for r in self.test_results.values() if r["status"] == "passed")
|
||||
failed_suites = sum(1 for r in self.test_results.values() if r["status"] == "failed")
|
||||
error_suites = sum(1 for r in self.test_results.values() if r["status"] == "error")
|
||||
timeout_suites = sum(1 for r in self.test_results.values() if r["status"] == "timeout")
|
||||
|
||||
total_duration = sum(r["duration"] for r in self.test_results.values())
|
||||
|
||||
# Create detailed report
|
||||
report = {
|
||||
"test_run_summary": {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_suites": total_suites,
|
||||
"passed_suites": passed_suites,
|
||||
"failed_suites": failed_suites,
|
||||
"error_suites": error_suites,
|
||||
"timeout_suites": timeout_suites,
|
||||
"success_rate": (passed_suites / total_suites * 100) if total_suites > 0 else 0,
|
||||
"total_duration_seconds": total_duration
|
||||
},
|
||||
"suite_results": self.test_results,
|
||||
"recommendations": self.generate_recommendations()
|
||||
}
|
||||
|
||||
# Save JSON report
|
||||
report_file = self.results_dir / "test_report.json"
|
||||
with open(report_file, 'w') as f:
|
||||
json.dump(report, f, indent=2)
|
||||
|
||||
# Generate HTML report
|
||||
self.generate_html_report(report)
|
||||
|
||||
# Print summary to console
|
||||
self.print_test_summary(report)
|
||||
|
||||
return report
|
||||
|
||||
def generate_recommendations(self) -> List[str]:
|
||||
"""Generate recommendations based on test results"""
|
||||
recommendations = []
|
||||
|
||||
failed_suites = [name for name, result in self.test_results.items() if result["status"] == "failed"]
|
||||
timeout_suites = [name for name, result in self.test_results.items() if result["status"] == "timeout"]
|
||||
|
||||
if failed_suites:
|
||||
recommendations.append(f"Failed test suites: {', '.join(failed_suites)}. Check logs for detailed error messages.")
|
||||
|
||||
if timeout_suites:
|
||||
recommendations.append(f"Timeout in suites: {', '.join(timeout_suites)}. Consider increasing timeout or optimizing performance.")
|
||||
|
||||
# Performance recommendations
|
||||
slow_suites = [
|
||||
name for name, result in self.test_results.items()
|
||||
if result["duration"] > 300 # 5 minutes
|
||||
]
|
||||
if slow_suites:
|
||||
recommendations.append(f"Slow test suites: {', '.join(slow_suites)}. Consider performance optimization.")
|
||||
|
||||
if not recommendations:
|
||||
recommendations.append("All tests passed successfully! Consider adding more edge case tests.")
|
||||
|
||||
return recommendations
|
||||
|
||||
def generate_html_report(self, report: Dict[str, Any]):
|
||||
"""Generate HTML test report"""
|
||||
html_template = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Training Service Test Report</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 40px; }
|
||||
.header { background-color: #f8f9fa; padding: 20px; border-radius: 5px; }
|
||||
.summary { display: flex; gap: 20px; margin: 20px 0; }
|
||||
.metric { background: white; border: 1px solid #dee2e6; padding: 15px; border-radius: 5px; text-align: center; }
|
||||
.metric-value { font-size: 24px; font-weight: bold; }
|
||||
.passed { color: #28a745; }
|
||||
.failed { color: #dc3545; }
|
||||
.timeout { color: #fd7e14; }
|
||||
.error { color: #6c757d; }
|
||||
.suite-result { margin: 20px 0; padding: 15px; border: 1px solid #dee2e6; border-radius: 5px; }
|
||||
.recommendations { background-color: #e7f3ff; padding: 15px; border-radius: 5px; margin: 20px 0; }
|
||||
pre { background-color: #f8f9fa; padding: 10px; border-radius: 3px; overflow-x: auto; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>Training Service Test Report</h1>
|
||||
<p>Generated: {timestamp}</p>
|
||||
</div>
|
||||
|
||||
<div class="summary">
|
||||
<div class="metric">
|
||||
<div class="metric-value">{total_suites}</div>
|
||||
<div>Total Suites</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value passed">{passed_suites}</div>
|
||||
<div>Passed</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value failed">{failed_suites}</div>
|
||||
<div>Failed</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value timeout">{timeout_suites}</div>
|
||||
<div>Timeout</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{success_rate:.1f}%</div>
|
||||
<div>Success Rate</div>
|
||||
</div>
|
||||
<div class="metric">
|
||||
<div class="metric-value">{duration:.1f}s</div>
|
||||
<div>Total Duration</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="recommendations">
|
||||
<h3>Recommendations</h3>
|
||||
<ul>
|
||||
{recommendations_html}
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<h2>Suite Results</h2>
|
||||
{suite_results_html}
|
||||
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
# Format recommendations
|
||||
recommendations_html = '\n'.join(
|
||||
f"<li>{rec}</li>" for rec in report["recommendations"]
|
||||
)
|
||||
|
||||
# Format suite results
|
||||
suite_results_html = ""
|
||||
for suite_name, result in report["suite_results"].items():
|
||||
status_class = result["status"]
|
||||
suite_results_html += f"""
|
||||
<div class="suite-result">
|
||||
<h3>{suite_name.title()} Tests <span class="{status_class}">({result["status"].upper()})</span></h3>
|
||||
<p><strong>Duration:</strong> {result["duration"]:.2f}s</p>
|
||||
<p><strong>Return Code:</strong> {result["return_code"]}</p>
|
||||
|
||||
{f'<h4>Output:</h4><pre>{result["stdout"][:1000]}{"..." if len(result["stdout"]) > 1000 else ""}</pre>' if result["stdout"] else ""}
|
||||
{f'<h4>Errors:</h4><pre>{result["stderr"][:1000]}{"..." if len(result["stderr"]) > 1000 else ""}</pre>' if result["stderr"] else ""}
|
||||
</div>
|
||||
"""
|
||||
|
||||
# Fill template
|
||||
html_content = html_template.format(
|
||||
timestamp=report["test_run_summary"]["timestamp"],
|
||||
total_suites=report["test_run_summary"]["total_suites"],
|
||||
passed_suites=report["test_run_summary"]["passed_suites"],
|
||||
failed_suites=report["test_run_summary"]["failed_suites"],
|
||||
timeout_suites=report["test_run_summary"]["timeout_suites"],
|
||||
success_rate=report["test_run_summary"]["success_rate"],
|
||||
duration=report["test_run_summary"]["total_duration_seconds"],
|
||||
recommendations_html=recommendations_html,
|
||||
suite_results_html=suite_results_html
|
||||
)
|
||||
|
||||
# Save HTML report
|
||||
html_file = self.results_dir / "test_report.html"
|
||||
with open(html_file, 'w') as f:
|
||||
f.write(html_content)
|
||||
|
||||
logger.info(f"HTML report saved to: {html_file}")
|
||||
|
||||
def print_test_summary(self, report: Dict[str, Any]):
|
||||
"""Print test summary to console"""
|
||||
summary = report["test_run_summary"]
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("TRAINING SERVICE TEST RESULTS SUMMARY")
|
||||
print("=" * 80)
|
||||
print(f"Timestamp: {summary['timestamp']}")
|
||||
print(f"Total Suites: {summary['total_suites']}")
|
||||
print(f"Passed: {summary['passed_suites']}")
|
||||
print(f"Failed: {summary['failed_suites']}")
|
||||
print(f"Errors: {summary['error_suites']}")
|
||||
print(f"Timeouts: {summary['timeout_suites']}")
|
||||
print(f"Success Rate: {summary['success_rate']:.1f}%")
|
||||
print(f"Total Duration: {summary['total_duration_seconds']:.2f}s")
|
||||
|
||||
print("\nSUITE DETAILS:")
|
||||
print("-" * 50)
|
||||
for suite_name, result in report["suite_results"].items():
|
||||
status_icon = "✅" if result["status"] == "passed" else "❌"
|
||||
print(f"{status_icon} {suite_name.ljust(15)}: {result['status'].upper().ljust(10)} ({result['duration']:.2f}s)")
|
||||
|
||||
print("\nRECOMMENDATIONS:")
|
||||
print("-" * 50)
|
||||
for i, rec in enumerate(report["recommendations"], 1):
|
||||
print(f"{i}. {rec}")
|
||||
|
||||
print("\nFILES GENERATED:")
|
||||
print("-" * 50)
|
||||
print(f"📄 JSON Report: {self.results_dir}/test_report.json")
|
||||
print(f"🌐 HTML Report: {self.results_dir}/test_report.html")
|
||||
print(f"📊 Coverage Reports: {self.results_dir}/coverage_*_html/")
|
||||
print(f"📋 JUnit XML: {self.results_dir}/junit_*.xml")
|
||||
print("=" * 80)
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all test suites"""
|
||||
logger.info("Starting comprehensive test run...")
|
||||
|
||||
# Setup environment
|
||||
await self.setup_test_environment()
|
||||
|
||||
# Run each test suite
|
||||
for suite_name in self.test_suites.keys():
|
||||
logger.info(f"Starting {suite_name} test suite...")
|
||||
result = self.run_test_suite(suite_name)
|
||||
self.test_results[suite_name] = result
|
||||
|
||||
if result["status"] == "passed":
|
||||
logger.info(f"✅ {suite_name} tests PASSED ({result['duration']:.2f}s)")
|
||||
elif result["status"] == "failed":
|
||||
logger.error(f"❌ {suite_name} tests FAILED ({result['duration']:.2f}s)")
|
||||
elif result["status"] == "timeout":
|
||||
logger.error(f"⏰ {suite_name} tests TIMED OUT ({result['duration']:.2f}s)")
|
||||
else:
|
||||
logger.error(f"💥 {suite_name} tests ERROR ({result['duration']:.2f}s)")
|
||||
|
||||
# Generate final report
|
||||
report = self.generate_test_report()
|
||||
|
||||
return report
|
||||
|
||||
def run_specific_suite(self, suite_name: str):
|
||||
"""Run a specific test suite"""
|
||||
if suite_name not in self.test_suites:
|
||||
logger.error(f"Unknown test suite: {suite_name}")
|
||||
logger.info(f"Available suites: {', '.join(self.test_suites.keys())}")
|
||||
return None
|
||||
|
||||
logger.info(f"Running {suite_name} test suite only...")
|
||||
result = self.run_test_suite(suite_name)
|
||||
self.test_results[suite_name] = result
|
||||
|
||||
# Generate report for single suite
|
||||
report = self.generate_test_report()
|
||||
return report
|
||||
|
||||
|
||||
# ================================================================
|
||||
# MAIN EXECUTION
|
||||
# ================================================================
|
||||
|
||||
async def main():
|
||||
"""Main execution function"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Training Service Test Runner")
|
||||
parser.add_argument(
|
||||
"--suite",
|
||||
choices=list(TrainingTestRunner().test_suites.keys()) + ["all"],
|
||||
default="all",
|
||||
help="Test suite to run (default: all)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quick",
|
||||
action="store_true",
|
||||
help="Run quick tests only (skip performance tests)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging level
|
||||
if args.verbose:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
# Create test runner
|
||||
runner = TrainingTestRunner()
|
||||
|
||||
# Modify test suites for quick run
|
||||
if args.quick:
|
||||
# Skip performance tests in quick mode
|
||||
if "performance" in runner.test_suites:
|
||||
del runner.test_suites["performance"]
|
||||
logger.info("Quick mode: Skipping performance tests")
|
||||
|
||||
try:
|
||||
if args.suite == "all":
|
||||
report = await runner.run_all_tests()
|
||||
else:
|
||||
report = runner.run_specific_suite(args.suite)
|
||||
|
||||
# Exit with appropriate code
|
||||
if report and report["test_run_summary"]["failed_suites"] == 0 and report["test_run_summary"]["error_suites"] == 0:
|
||||
logger.info("All tests completed successfully!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("Some tests failed!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Test run interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
logger.error(f"Test run failed with error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Handle both direct execution and pytest discovery
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["--suite", "-h", "--help"]:
|
||||
# Running as main script with arguments
|
||||
asyncio.run(main())
|
||||
else:
|
||||
# Running as pytest discovery or direct execution without args
|
||||
print("Training Service Test Runner")
|
||||
print("=" * 50)
|
||||
print("Usage:")
|
||||
print(" python run_tests.py --suite all # Run all test suites")
|
||||
print(" python run_tests.py --suite unit # Run unit tests only")
|
||||
print(" python run_tests.py --suite integration # Run integration tests only")
|
||||
print(" python run_tests.py --suite performance # Run performance tests only")
|
||||
print(" python run_tests.py --quick # Run quick tests (skip performance)")
|
||||
print(" python run_tests.py -v # Verbose output")
|
||||
print()
|
||||
print("Available test suites:")
|
||||
runner = TrainingTestRunner()
|
||||
for suite_name, config in runner.test_suites.items():
|
||||
print(f" {suite_name.ljust(15)}: {config['description']}")
|
||||
print()
|
||||
|
||||
# If no arguments provided, run all tests
|
||||
if len(sys.argv) == 1:
|
||||
print("No arguments provided. Running all tests...")
|
||||
asyncio.run(TrainingTestRunner().run_all_tests())
|
||||
@@ -1,687 +0,0 @@
|
||||
# services/training/tests/test_api.py
|
||||
"""
|
||||
Tests for training service API endpoints
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from fastapi import status
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
|
||||
class TestTrainingAPI:
|
||||
"""Test training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, test_client: AsyncClient):
|
||||
"""Test health check endpoint"""
|
||||
response = await test_client.get("/health")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "status" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is ready"""
|
||||
# Mock app state as ready
|
||||
from app.main import app # Add import at top
|
||||
with patch.object(app.state, 'ready', True, create=True):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_check_not_ready(self, test_client: AsyncClient):
|
||||
"""Test readiness check when service is not ready"""
|
||||
with patch('app.main.app.state.ready', False):
|
||||
response = await test_client.get("/health/ready")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "not_ready"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_healthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when service is healthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=True)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["status"] == "alive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
|
||||
"""Test liveness check when database is unhealthy"""
|
||||
with patch('app.core.database.get_db_health', return_value=AsyncMock(return_value=False)):
|
||||
response = await test_client.get("/health/live")
|
||||
|
||||
assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
data = response.json()
|
||||
assert data["status"] == "unhealthy"
|
||||
assert data["reason"] == "database_unavailable"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metrics_endpoint(self, test_client: AsyncClient):
|
||||
"""Test metrics endpoint"""
|
||||
response = await test_client.get("/metrics")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
expected_metrics = [
|
||||
"training_jobs_active",
|
||||
"training_jobs_completed",
|
||||
"training_jobs_failed",
|
||||
"models_trained_total",
|
||||
"uptime_seconds"
|
||||
]
|
||||
|
||||
for metric in expected_metrics:
|
||||
assert metric in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_root_endpoint(self, test_client: AsyncClient):
|
||||
"""Test root endpoint"""
|
||||
response = await test_client.get("/")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["service"] == "training-service"
|
||||
assert data["version"] == "1.0.0"
|
||||
assert "description" in data
|
||||
|
||||
|
||||
class TestTrainingJobsAPI:
|
||||
"""Test training jobs API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test starting a training job successfully"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert "estimated_duration_minutes" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_training_job_validation_error(self, test_client: AsyncClient):
|
||||
"""Test starting training job with validation error"""
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode", # Invalid value
|
||||
"min_data_points": 5 # Too low
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_existing_job(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting status of existing training job"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["job_id"] == job_id
|
||||
assert data["status"] == "pending"
|
||||
assert "progress" in data
|
||||
assert "started_at" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test getting status of non-existent training job"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs/nonexistent-job/status")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) >= 1
|
||||
|
||||
# Check first job structure
|
||||
job = data[0]
|
||||
assert "job_id" in job
|
||||
assert "status" in job
|
||||
assert "started_at" in job
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_training_jobs_with_status_filter(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test listing training jobs with status filter"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get("/training/jobs?status=pending")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
# All jobs should have status "pending"
|
||||
for job in data:
|
||||
assert job["status"] == "pending"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_training_job_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test cancelling a training job successfully"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "message" in data
|
||||
assert "cancelled" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
|
||||
"""Test cancelling a non-existent training job"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs/nonexistent-job/cancel")
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_training_logs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test getting training logs"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/logs")
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert "logs" in data
|
||||
assert isinstance(data["logs"], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_training_data_valid(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test validating valid training data"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/validate", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "is_valid" in data
|
||||
assert "issues" in data
|
||||
assert "recommendations" in data
|
||||
assert "estimated_training_time" in data
|
||||
|
||||
|
||||
class TestSingleProductTrainingAPI:
|
||||
"""Test single product training API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_success(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training a single product successfully"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "started"
|
||||
assert data["tenant_id"] == "test-tenant"
|
||||
assert f"training started for {product_name}" in data["message"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_validation_error(self, test_client: AsyncClient):
|
||||
"""Test single product training with validation error"""
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"seasonality_mode": "invalid_mode" # Invalid value
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_train_single_product_special_characters(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_ml_trainer,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training product with special characters in name"""
|
||||
product_name = "Pan Francés" # With accent
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
|
||||
class TestModelsAPI:
|
||||
"""Test models API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_models(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test listing trained models"""
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get("/models")
|
||||
|
||||
# This endpoint might not exist yet, so we expect either 200 or 404
|
||||
assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]
|
||||
|
||||
if response.status_code == status.HTTP_200_OK:
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_model_details(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
trained_model_in_db
|
||||
):
|
||||
"""Test getting model details"""
|
||||
model_id = trained_model_in_db.model_id
|
||||
|
||||
with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/models/{model_id}")
|
||||
|
||||
# This endpoint might not exist yet
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
status.HTTP_501_NOT_IMPLEMENTED
|
||||
]
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in API endpoints"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_error_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of database errors"""
|
||||
with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
|
||||
mock_create.side_effect = Exception("Database connection failed")
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_tenant_id(self, test_client: AsyncClient):
|
||||
"""Test handling when tenant ID is missing"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Don't mock get_current_tenant_id to simulate missing auth
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should fail due to missing authentication
|
||||
assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_job_id_format(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid job ID format"""
|
||||
invalid_job_id = "invalid-job-id-with-special-chars@#$"
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
|
||||
|
||||
# Should handle gracefully
|
||||
assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when messaging fails"""
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
|
||||
patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still succeed even if messaging fails
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_payload(self, test_client: AsyncClient):
|
||||
"""Test handling of invalid JSON payload"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="invalid json {{{",
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_content_type(self, test_client: AsyncClient):
|
||||
"""Test handling of unsupported content type"""
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/jobs",
|
||||
content="some text data",
|
||||
headers={"Content-Type": "text/plain"}
|
||||
)
|
||||
|
||||
assert response.status_code in [
|
||||
status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
]
|
||||
|
||||
|
||||
class TestAuthenticationIntegration:
|
||||
"""Test authentication integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_endpoints_require_auth(self, test_client: AsyncClient):
|
||||
"""Test that endpoints require authentication in production"""
|
||||
# This test would be more meaningful in a production environment
|
||||
# where authentication is actually enforced
|
||||
|
||||
endpoints_to_test = [
|
||||
("POST", "/training/jobs"),
|
||||
("GET", "/training/jobs"),
|
||||
("POST", "/training/products/Pan Integral"),
|
||||
("POST", "/training/validate")
|
||||
]
|
||||
|
||||
for method, endpoint in endpoints_to_test:
|
||||
if method == "POST":
|
||||
response = await test_client.post(endpoint, json={})
|
||||
else:
|
||||
response = await test_client.get(endpoint)
|
||||
|
||||
# In test environment with mocked auth, should work
|
||||
# In production, would require valid authentication
|
||||
assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tenant_isolation_in_api(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test tenant isolation at API level"""
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Try to access job with different tenant
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# Should not find job for different tenant
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class TestAPIValidation:
|
||||
"""Test API validation and input handling"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_training_request_validation(self, test_client: AsyncClient):
|
||||
"""Test comprehensive training request validation"""
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": False,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive",
|
||||
"daily_seasonality": True,
|
||||
"weekly_seasonality": True,
|
||||
"yearly_seasonality": True
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=valid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid seasonality mode
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["seasonality_mode"] = "invalid_mode"
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
# Test invalid min_data_points
|
||||
invalid_request = valid_request.copy()
|
||||
invalid_request["min_data_points"] = 5 # Too low
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_product_request_validation(self, test_client: AsyncClient):
|
||||
"""Test single product training request validation"""
|
||||
|
||||
product_name = "Pan Integral"
|
||||
|
||||
# Test valid request
|
||||
valid_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"seasonality_mode": "multiplicative"
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test empty product name
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
"/training/products/",
|
||||
json=valid_request
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_parameter_validation(self, test_client: AsyncClient):
|
||||
"""Test query parameter validation"""
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
# Test valid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=5")
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Test invalid limit parameter
|
||||
response = await test_client.get("/training/jobs?limit=invalid")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
# Test negative limit
|
||||
response = await test_client.get("/training/jobs?limit=-1")
|
||||
assert response.status_code in [
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
status.HTTP_400_BAD_REQUEST
|
||||
]
|
||||
|
||||
|
||||
class TestAPIPerformance:
|
||||
"""Test API performance characteristics"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests(self, test_client: AsyncClient):
|
||||
"""Test handling of concurrent requests"""
|
||||
import asyncio
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
|
||||
task = test_client.get("/health")
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All requests should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_payload_handling(self, test_client: AsyncClient):
|
||||
"""Test handling of large request payloads"""
|
||||
|
||||
# Create large request payload
|
||||
large_request = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=large_request)
|
||||
|
||||
# Should handle large payload gracefully
|
||||
assert response.status_code in [
|
||||
status.HTTP_200_OK,
|
||||
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_successive_requests(self, test_client: AsyncClient):
|
||||
"""Test rapid successive requests to same endpoint"""
|
||||
|
||||
# Make rapid requests
|
||||
responses = []
|
||||
for _ in range(20):
|
||||
response = await test_client.get("/health")
|
||||
responses.append(response)
|
||||
|
||||
# All should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@@ -1,311 +0,0 @@
|
||||
# ================================================================
|
||||
# services/training/tests/test_end_to_end.py
|
||||
# ================================================================
|
||||
"""
|
||||
End-to-End Testing for Training Service
|
||||
Tests complete workflows from API to ML pipeline to results
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import httpx
|
||||
import pandas as pd
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from unittest.mock import patch, AsyncMock
|
||||
import uuid
|
||||
|
||||
from app.main import app
|
||||
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
|
||||
|
||||
|
||||
class TestTrainingServiceEndToEnd:
|
||||
"""End-to-end tests for complete training workflows"""
|
||||
|
||||
@pytest.fixture
|
||||
async def test_client(self):
|
||||
"""Create test client for the training service"""
|
||||
from httpx import AsyncClient
|
||||
async with AsyncClient(app=app, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
@pytest.fixture
|
||||
def real_bakery_data(self):
|
||||
"""Use the actual bakery sales data from the uploaded CSV"""
|
||||
# This fixture would load the real bakery_sales_2023_2024.csv data
|
||||
# For testing, we'll simulate the structure based on the document description
|
||||
|
||||
# Generate realistic data matching the CSV structure
|
||||
start_date = datetime(2023, 1, 1)
|
||||
dates = [start_date + timedelta(days=i) for i in range(365)]
|
||||
|
||||
products = [
|
||||
"Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
|
||||
"Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras"
|
||||
]
|
||||
|
||||
data = []
|
||||
for date in dates:
|
||||
for product in products:
|
||||
# Realistic sales patterns for Madrid bakery
|
||||
base_quantity = {
|
||||
"Pan Integral": 80, "Pan Blanco": 120, "Croissant": 45,
|
||||
"Magdalenas": 30, "Empanadas": 25, "Tarta Chocolate": 15,
|
||||
"Roscon Reyes": 8, "Palmeras": 12
|
||||
}.get(product, 20)
|
||||
|
||||
# Seasonal variations
|
||||
if date.month == 12 and product == "Roscon Reyes":
|
||||
base_quantity *= 5 # Christmas specialty
|
||||
elif date.month in [6, 7, 8]: # Summer
|
||||
base_quantity *= 0.85
|
||||
elif date.month in [11, 12, 1]: # Winter
|
||||
base_quantity *= 1.15
|
||||
|
||||
# Weekly patterns
|
||||
if date.weekday() >= 5: # Weekends
|
||||
base_quantity *= 1.3
|
||||
elif date.weekday() == 0: # Monday slower
|
||||
base_quantity *= 0.8
|
||||
|
||||
# Weather influence
|
||||
temp = 15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
|
||||
if temp > 30: # Very hot days
|
||||
if product in ["Pan Integral", "Pan Blanco"]:
|
||||
base_quantity *= 0.7
|
||||
elif temp < 5: # Cold days
|
||||
base_quantity *= 1.1
|
||||
|
||||
# Add realistic noise
|
||||
import numpy as np
|
||||
quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.15)))
|
||||
|
||||
# Calculate revenue (realistic Spanish bakery prices)
|
||||
price_per_unit = {
|
||||
"Pan Integral": 2.80, "Pan Blanco": 2.50, "Croissant": 1.50,
|
||||
"Magdalenas": 1.20, "Empanadas": 3.50, "Tarta Chocolate": 18.00,
|
||||
"Roscon Reyes": 25.00, "Palmeras": 1.80
|
||||
}.get(product, 2.00)
|
||||
|
||||
revenue = round(quantity * price_per_unit, 2)
|
||||
|
||||
data.append({
|
||||
"date": date.strftime("%Y-%m-%d"),
|
||||
"product": product,
|
||||
"quantity": quantity,
|
||||
"revenue": revenue,
|
||||
"temperature": round(temp + np.random.normal(0, 3), 1),
|
||||
"precipitation": max(0, np.random.exponential(0.8)),
|
||||
"is_weekend": date.weekday() >= 5,
|
||||
"is_holiday": self._is_spanish_holiday(date)
|
||||
})
|
||||
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def _is_spanish_holiday(self, date: datetime) -> bool:
|
||||
"""Check if date is a Spanish holiday"""
|
||||
spanish_holidays = [
|
||||
(1, 1), # Año Nuevo
|
||||
(1, 6), # Reyes Magos
|
||||
(5, 1), # Día del Trabajo
|
||||
(8, 15), # Asunción de la Virgen
|
||||
(10, 12), # Fiesta Nacional de España
|
||||
(11, 1), # Todos los Santos
|
||||
(12, 6), # Día de la Constitución
|
||||
(12, 8), # Inmaculada Concepción
|
||||
(12, 25), # Navidad
|
||||
]
|
||||
return (date.month, date.day) in spanish_holidays
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_external_apis(self):
|
||||
"""Mock external APIs (AEMET and Madrid OpenData)"""
|
||||
with patch('app.external.aemet.AEMETClient') as mock_aemet, \
|
||||
patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_madrid:
|
||||
|
||||
# Mock AEMET weather data
|
||||
mock_aemet_instance = AsyncMock()
|
||||
mock_aemet.return_value = mock_aemet_instance
|
||||
|
||||
# Generate realistic Madrid weather data
|
||||
weather_data = []
|
||||
for i in range(365):
|
||||
date = datetime(2023, 1, 1) + timedelta(days=i)
|
||||
day_of_year = date.timetuple().tm_yday
|
||||
# Madrid climate: hot summers, mild winters
|
||||
base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
|
||||
|
||||
weather_data.append({
|
||||
"date": date,
|
||||
"temperature": round(base_temp + np.random.normal(0, 4), 1),
|
||||
"precipitation": max(0, np.random.exponential(1.2)),
|
||||
"humidity": np.random.uniform(25, 75),
|
||||
"wind_speed": np.random.uniform(3, 20),
|
||||
"pressure": np.random.uniform(995, 1025),
|
||||
"description": np.random.choice([
|
||||
"Soleado", "Parcialmente nublado", "Nublado",
|
||||
"Lluvia ligera", "Despejado"
|
||||
]),
|
||||
"source": "aemet"
|
||||
})
|
||||
|
||||
mock_aemet_instance.get_historical_weather.return_value = weather_data
|
||||
mock_aemet_instance.get_current_weather.return_value = weather_data[-1]
|
||||
|
||||
# Mock Madrid traffic data
|
||||
mock_madrid_instance = AsyncMock()
|
||||
mock_madrid.return_value = mock_madrid_instance
|
||||
|
||||
traffic_data = []
|
||||
for i in range(365):
|
||||
date = datetime(2023, 1, 1) + timedelta(days=i)
|
||||
|
||||
# Multiple measurements per day
|
||||
for hour in range(6, 22, 2): # Every 2 hours from 6 AM to 10 PM
|
||||
measurement_time = date.replace(hour=hour)
|
||||
|
||||
# Realistic Madrid traffic patterns
|
||||
if hour in [7, 8, 9, 18, 19, 20]: # Rush hours
|
||||
volume = np.random.randint(1200, 2000)
|
||||
congestion = "high"
|
||||
speed = np.random.randint(10, 25)
|
||||
elif hour in [12, 13, 14]: # Lunch time
|
||||
volume = np.random.randint(800, 1200)
|
||||
congestion = "medium"
|
||||
speed = np.random.randint(20, 35)
|
||||
else: # Off-peak
|
||||
volume = np.random.randint(300, 800)
|
||||
congestion = "low"
|
||||
speed = np.random.randint(30, 50)
|
||||
|
||||
traffic_data.append({
|
||||
"date": measurement_time,
|
||||
"traffic_volume": volume,
|
||||
"occupation_percentage": np.random.randint(15, 85),
|
||||
"load_percentage": np.random.randint(25, 90),
|
||||
"average_speed": speed,
|
||||
"congestion_level": congestion,
|
||||
"pedestrian_count": np.random.randint(100, 800),
|
||||
"measurement_point_id": "MADRID_CENTER_001",
|
||||
"measurement_point_name": "Puerta del Sol",
|
||||
"road_type": "URB",
|
||||
"source": "madrid_opendata"
|
||||
})
|
||||
|
||||
mock_madrid_instance.get_historical_traffic.return_value = traffic_data
|
||||
mock_madrid_instance.get_current_traffic.return_value = traffic_data[-1]
|
||||
|
||||
yield {
|
||||
'aemet': mock_aemet_instance,
|
||||
'madrid': mock_madrid_instance
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_training_workflow_api(
|
||||
self,
|
||||
test_client,
|
||||
real_bakery_data,
|
||||
mock_external_apis
|
||||
):
|
||||
"""Test complete training workflow through API endpoints"""
|
||||
|
||||
# Step 1: Check service health
|
||||
health_response = await test_client.get("/health")
|
||||
assert health_response.status_code == 200
|
||||
health_data = health_response.json()
|
||||
assert health_data["status"] == "healthy"
|
||||
|
||||
# Step 2: Validate training data quality
|
||||
with patch('app.services.training_service.TrainingService._fetch_sales_data',
|
||||
return_value=real_bakery_data):
|
||||
|
||||
validation_response = await test_client.post(
|
||||
"/training/validate",
|
||||
json={
|
||||
"tenant_id": "test_bakery_001",
|
||||
"include_weather": True,
|
||||
"include_traffic": True
|
||||
}
|
||||
)
|
||||
|
||||
assert validation_response.status_code == 200
|
||||
validation_data = validation_response.json()
|
||||
assert validation_data["is_valid"] is True
|
||||
assert validation_data["data_points"] > 1000 # Sufficient data
|
||||
assert validation_data["missing_percentage"] < 10
|
||||
|
||||
# Step 3: Start training job for multiple products
|
||||
training_request = {
|
||||
"products": ["Pan Integral", "Croissant", "Magdalenas"],
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"config": {
|
||||
"seasonality_mode": "additive",
|
||||
"changepoint_prior_scale": 0.05,
|
||||
"seasonality_prior_scale": 10.0,
|
||||
"validation_enabled": True
|
||||
}
|
||||
}
|
||||
|
||||
with patch('app.services.training_service.TrainingService._fetch_sales_data',
|
||||
return_value=real_bakery_data):
|
||||
|
||||
start_response = await test_client.post(
|
||||
"/training/jobs",
|
||||
json=training_request,
|
||||
headers={"X-Tenant-ID": "test_bakery_001"}
|
||||
)
|
||||
|
||||
assert start_response.status_code == 201
|
||||
job_data = start_response.json()
|
||||
job_id = job_data["job_id"]
|
||||
assert job_data["status"] == "pending"
|
||||
|
||||
# Step 4: Monitor job progress
|
||||
max_wait_time = 300 # 5 minutes
|
||||
start_time = time.time()
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
status_response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
assert status_response.status_code == 200
|
||||
|
||||
status_data = status_response.json()
|
||||
|
||||
if status_data["status"] == "completed":
|
||||
# Training completed successfully
|
||||
assert "models_trained" in status_data
|
||||
assert len(status_data["models_trained"]) == 3 # Three products
|
||||
|
||||
# Check model quality
|
||||
for model_info in status_data["models_trained"]:
|
||||
assert "product_name" in model_info
|
||||
assert "model_id" in model_info
|
||||
assert "metrics" in model_info
|
||||
|
||||
metrics = model_info["metrics"]
|
||||
assert "mape" in metrics
|
||||
assert "rmse" in metrics
|
||||
assert "mae" in metrics
|
||||
|
||||
# Quality thresholds for bakery data
|
||||
assert metrics["mape"] < 50, f"MAPE too high for {model_info['product_name']}: {metrics['mape']}"
|
||||
assert metrics["rmse"] > 0
|
||||
|
||||
break
|
||||
elif status_data["status"] == "failed":
|
||||
pytest.fail(f"Training job failed: {status_data.get('error_message', 'Unknown error')}")
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(10)
|
||||
else:
|
||||
pytest.fail(f"Training job did not complete within {max_wait_time} seconds")
|
||||
|
||||
# Step 5: Get detailed job logs
|
||||
logs_response = await test_client.get(f"/training/jobs/{job_id}/logs")
|
||||
assert logs_response.status_code == 200
|
||||
logs_data = logs_response.json()
|
||||
assert "logs" in logs_data
|
||||
assert len(logs_data["logs"]) > 0
|
||||
@@ -1,848 +0,0 @@
# services/training/tests/test_integration.py
|
||||
"""
|
||||
Integration tests for training service
|
||||
Tests complete workflows and service interactions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
from httpx import AsyncClient
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.main import app
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
|
||||
class TestTrainingWorkflowIntegration:
|
||||
"""Test complete training workflows end-to-end"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_training_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
test_db_session,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test complete training workflow from API to completion"""
|
||||
|
||||
# Step 1: Start training job
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
job_data = response.json()
|
||||
job_id = job_data["job_id"]
|
||||
|
||||
# Step 2: Check initial status
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
status_data = response.json()
|
||||
assert status_data["status"] in ["pending", "started"]
|
||||
|
||||
# Step 3: Simulate background task completion
|
||||
# In real scenario, this would be handled by background tasks
|
||||
await asyncio.sleep(0.1) # Allow background task to start
|
||||
|
||||
# Step 4: Check completion status
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# The job should exist in database even if not completed yet
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_product_training_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test single product training complete workflow"""
|
||||
|
||||
product_name = "Pan Integral"
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": False,
|
||||
"seasonality_mode": "additive"
|
||||
}
|
||||
|
||||
# Start single product training
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(
|
||||
f"/training/products/{product_name}",
|
||||
json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
job_data = response.json()
|
||||
job_id = job_data["job_id"]
|
||||
assert f"training started for {product_name}" in job_data["message"].lower()
|
||||
|
||||
# Check job status
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
status_data = response.json()
|
||||
assert status_data["job_id"] == job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_training_validation_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test training data validation workflow"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Validate training data
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/validate", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
validation_data = response.json()
|
||||
|
||||
assert "is_valid" in validation_data
|
||||
assert "issues" in validation_data
|
||||
assert "recommendations" in validation_data
|
||||
assert "estimated_training_time" in validation_data
|
||||
|
||||
# If validation passes, start actual training
|
||||
if validation_data["is_valid"]:
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_job_cancellation_workflow(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test job cancellation workflow"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Check initial status
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
initial_status = response.json()
|
||||
assert initial_status["status"] == "pending"
|
||||
|
||||
# Cancel the job
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post(f"/training/jobs/{job_id}/cancel")
|
||||
|
||||
assert response.status_code == 200
|
||||
cancel_response = response.json()
|
||||
assert "cancelled" in cancel_response["message"].lower()
|
||||
|
||||
# Verify cancellation
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
final_status = response.json()
|
||||
assert final_status["status"] == "cancelled"
|
||||
|
||||
|
||||
class TestServiceInteractionIntegration:
|
||||
"""Test interactions between training service and external services"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_service_integration(self, training_service, mock_data_service):
|
||||
"""Test integration with data service"""
|
||||
from app.schemas.training import TrainingJobRequest
|
||||
|
||||
request = TrainingJobRequest(
|
||||
include_weather=True,
|
||||
include_traffic=True,
|
||||
min_data_points=30
|
||||
)
|
||||
|
||||
# Test sales data fetching
|
||||
sales_data = await training_service._fetch_sales_data("test-tenant", request)
|
||||
assert isinstance(sales_data, list)
|
||||
|
||||
# Test weather data fetching
|
||||
weather_data = await training_service._fetch_weather_data("test-tenant", request)
|
||||
assert isinstance(weather_data, list)
|
||||
|
||||
# Test traffic data fetching
|
||||
traffic_data = await training_service._fetch_traffic_data("test-tenant", request)
|
||||
assert isinstance(traffic_data, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_integration(self, mock_messaging):
|
||||
"""Test integration with messaging system"""
|
||||
from app.services.messaging import (
|
||||
publish_job_started,
|
||||
publish_job_completed,
|
||||
publish_model_trained
|
||||
)
|
||||
|
||||
# Test various message types
|
||||
result1 = await publish_job_started("job-123", "tenant-123", {})
|
||||
result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"})
|
||||
result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0})
|
||||
|
||||
assert result1 is True
|
||||
assert result2 is True
|
||||
assert result3 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_integration(self, test_db_session, training_service):
|
||||
"""Test database operations integration"""
|
||||
|
||||
# Create a training job
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="integration-test-job",
|
||||
config={"test": True}
|
||||
)
|
||||
|
||||
assert job.job_id == "integration-test-job"
|
||||
|
||||
# Update job status
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Processing data"
|
||||
)
|
||||
|
||||
# Retrieve updated job
|
||||
updated_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert updated_job.status == "running"
|
||||
assert updated_job.progress == 50
|
||||
|
||||
|
||||
class TestErrorHandlingIntegration:
|
||||
"""Test error handling across service boundaries"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_service_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test handling when data service is unavailable"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock data service failure
|
||||
with patch('httpx.AsyncClient') as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still create job but might fail during execution
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messaging_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when messaging fails"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock messaging failure
|
||||
with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Should still succeed even if messaging fails
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_training_failure_handling(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service
|
||||
):
|
||||
"""Test handling when ML training fails"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Mock ML training failure
|
||||
with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=request_data)
|
||||
|
||||
# Job should be created successfully
|
||||
assert response.status_code == 200
|
||||
|
||||
# Background task would handle the failure
|
||||
|
||||
|
||||
class TestPerformanceIntegration:
|
||||
"""Test performance characteristics of integrated workflows"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_training_jobs(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
mock_messaging,
|
||||
mock_data_service,
|
||||
mock_ml_trainer
|
||||
):
|
||||
"""Test handling multiple concurrent training jobs"""
|
||||
|
||||
request_data = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 30
|
||||
}
|
||||
|
||||
# Start multiple jobs concurrently
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
with patch('app.api.training.get_current_tenant_id', return_value=f"tenant-{i}"):
|
||||
task = test_client.post("/training/jobs", json=request_data)
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All jobs should be created successfully
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_dataset_handling(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test handling of large datasets"""
|
||||
|
||||
# Simulate large dataset
|
||||
large_config = {
|
||||
"include_weather": True,
|
||||
"include_traffic": True,
|
||||
"min_data_points": 1000, # Large minimum
|
||||
"products": [f"Product-{i}" for i in range(100)] # Many products
|
||||
}
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="large-dataset-job",
|
||||
config=large_config
|
||||
)
|
||||
|
||||
assert job.config == large_config
|
||||
assert job.job_id == "large-dataset-job"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_status_checks(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test rapid successive status checks"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Make many rapid status requests
|
||||
tasks = []
|
||||
for _ in range(20):
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
task = test_client.get(f"/training/jobs/{job_id}/status")
|
||||
tasks.append(task)
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
|
||||
# All requests should succeed
|
||||
for response in responses:
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestSecurityIntegration:
|
||||
"""Test security aspects of service integration"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tenant_isolation(
|
||||
self,
|
||||
test_client: AsyncClient,
|
||||
training_job_in_db,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test that tenants cannot access each other's jobs"""
|
||||
|
||||
job_id = training_job_in_db.job_id
|
||||
|
||||
# Try to access job with different tenant ID
|
||||
with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{job_id}/status")
|
||||
|
||||
# Should not find the job (belongs to different tenant)
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_validation_integration(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test input validation across API boundaries"""
|
||||
|
||||
# Test invalid seasonality mode
|
||||
invalid_request = {
|
||||
"seasonality_mode": "invalid_mode",
|
||||
"min_data_points": -5 # Invalid negative value
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=invalid_request)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_protection(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test protection against SQL injection attempts"""
|
||||
|
||||
# Try SQL injection in job ID
|
||||
malicious_job_id = "job'; DROP TABLE model_training_logs; --"
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
|
||||
|
||||
# Should return 404, not cause database error
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestRecoveryIntegration:
|
||||
"""Test recovery and resilience scenarios"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_restart_recovery(
|
||||
self,
|
||||
test_db_session,
|
||||
training_service,
|
||||
training_job_in_db
|
||||
):
|
||||
"""Test service recovery after restart"""
|
||||
|
||||
# Simulate service restart by creating new service instance
|
||||
new_training_service = training_service.__class__()
|
||||
|
||||
# Should be able to access existing jobs
|
||||
existing_job = await new_training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=training_job_in_db.job_id,
|
||||
tenant_id=training_job_in_db.tenant_id
|
||||
)
|
||||
|
||||
assert existing_job is not None
|
||||
assert existing_job.job_id == training_job_in_db.job_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_partial_failure_recovery(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test recovery from partial failures"""
|
||||
|
||||
# Create job that might fail partway through
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="partial-failure-job",
|
||||
config={"simulate_failure": True}
|
||||
)
|
||||
|
||||
# Simulate partial progress
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=50,
|
||||
current_step="Halfway through training"
|
||||
)
|
||||
|
||||
# Simulate failure
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="failed",
|
||||
progress=50,
|
||||
current_step="Training failed",
|
||||
error_message="Simulated failure"
|
||||
)
|
||||
|
||||
# Verify failure was recorded
|
||||
failed_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert failed_job.status == "failed"
|
||||
assert failed_job.error_message == "Simulated failure"
|
||||
assert failed_job.progress == 50
|
||||
|
||||
|
||||
class TestComplianceIntegration:
|
||||
"""Test compliance and audit requirements"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_trail_creation(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test that audit trail is properly created"""
|
||||
|
||||
# Create and update job
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="audit-test-job",
|
||||
config={"audit_test": True}
|
||||
)
|
||||
|
||||
# Multiple status updates
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=25,
|
||||
current_step="Started processing"
|
||||
)
|
||||
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running",
|
||||
progress=75,
|
||||
current_step="Almost complete"
|
||||
)
|
||||
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Completed successfully"
|
||||
)
|
||||
|
||||
# Verify audit trail
|
||||
logs = await training_service.get_training_logs(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert logs is not None
|
||||
assert len(logs) > 0
|
||||
|
||||
# Check final status
|
||||
final_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert final_job.status == "completed"
|
||||
assert final_job.progress == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_retention_compliance(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test data retention and cleanup compliance"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Create old job (simulate old data)
|
||||
old_job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="old-job",
|
||||
config={"created_long_ago": True}
|
||||
)
|
||||
|
||||
# Manually set old timestamp
|
||||
from sqlalchemy import update
|
||||
from app.models.training import ModelTrainingLog
|
||||
|
||||
old_timestamp = datetime.now() - timedelta(days=400)
|
||||
await test_db_session.execute(
|
||||
update(ModelTrainingLog)
|
||||
.where(ModelTrainingLog.job_id == old_job.job_id)
|
||||
.values(start_time=old_timestamp, created_at=old_timestamp)
|
||||
)
|
||||
await test_db_session.commit()
|
||||
|
||||
# Verify old job exists
|
||||
retrieved_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=old_job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert retrieved_job is not None
|
||||
# In a real implementation, there would be cleanup procedures
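# A minimal sketch of such a cleanup procedure (assumed here, not provided by
# the service): purge ModelTrainingLog rows older than a retention window.
# The 365-day default and the helper name are assumptions for illustration.
async def purge_expired_training_logs(db, retention_days: int = 365) -> int:
    """Delete training logs whose created_at falls outside the retention window."""
    from datetime import datetime, timedelta
    from sqlalchemy import delete
    from app.models.training import ModelTrainingLog

    cutoff = datetime.now() - timedelta(days=retention_days)
    result = await db.execute(
        delete(ModelTrainingLog).where(ModelTrainingLog.created_at < cutoff)
    )
    await db.commit()
    return result.rowcount  # number of purged rows
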
@pytest.mark.asyncio
|
||||
async def test_gdpr_compliance_features(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test GDPR compliance features"""
|
||||
|
||||
# Create job with tenant data
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="gdpr-test-tenant",
|
||||
job_id="gdpr-test-job",
|
||||
config={"gdpr_test": True}
|
||||
)
|
||||
|
||||
# Verify job is associated with tenant
|
||||
assert job.tenant_id == "gdpr-test-tenant"
|
||||
|
||||
# Test data access (right to access)
|
||||
tenant_jobs = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id="gdpr-test-tenant"
|
||||
)
|
||||
|
||||
assert len(tenant_jobs) >= 1
|
||||
assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestLongRunningIntegration:
|
||||
"""Test long-running integration scenarios (marked as slow)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_training_simulation(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session,
|
||||
mock_messaging
|
||||
):
|
||||
"""Test extended training process simulation"""
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="long-running-job",
|
||||
config={"extended_test": True}
|
||||
)
|
||||
|
||||
# Simulate progress over time
|
||||
progress_steps = [
|
||||
(10, "Initializing"),
|
||||
(25, "Loading data"),
|
||||
(50, "Training models"),
|
||||
(75, "Validating results"),
|
||||
(90, "Storing models"),
|
||||
(100, "Completed")
|
||||
]
|
||||
|
||||
for progress, step in progress_steps:
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="running" if progress < 100 else "completed",
|
||||
progress=progress,
|
||||
current_step=step
|
||||
)
|
||||
|
||||
# Small delay to simulate real progression
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Verify final state
|
||||
final_job = await training_service.get_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
tenant_id="test-tenant"
|
||||
)
|
||||
|
||||
assert final_job.status == "completed"
|
||||
assert final_job.progress == 100
|
||||
assert final_job.current_step == "Completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_usage_stability(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test memory usage stability over many operations"""
|
||||
|
||||
# Create many jobs to test memory stability
|
||||
for i in range(50):
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id=f"tenant-{i % 5}", # 5 different tenants
|
||||
job_id=f"memory-test-job-{i}",
|
||||
config={"iteration": i}
|
||||
)
|
||||
|
||||
# Update status
|
||||
await training_service._update_job_status(
|
||||
db=test_db_session,
|
||||
job_id=job.job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Completed"
|
||||
)
|
||||
|
||||
# List jobs for each tenant
|
||||
for tenant_i in range(5):
|
||||
tenant_id = f"tenant-{tenant_i}"
|
||||
jobs = await training_service.list_training_jobs(
|
||||
db=test_db_session,
|
||||
tenant_id=tenant_id,
|
||||
limit=20
|
||||
)
|
||||
|
||||
# Should have 10 jobs per tenant (50 total / 5 tenants)
|
||||
assert len(jobs) == 10
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""Test backward compatibility with existing systems"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_config_handling(
|
||||
self,
|
||||
training_service,
|
||||
test_db_session
|
||||
):
|
||||
"""Test handling of legacy configuration formats"""
|
||||
|
||||
# Test with old-style configuration
|
||||
legacy_config = {
|
||||
"weather_enabled": True, # Old key
|
||||
"traffic_enabled": True, # Old key
|
||||
"minimum_samples": 30, # Old key
|
||||
"prophet_config": { # Old nested structure
|
||||
"seasonality": "additive"
|
||||
}
|
||||
}
|
||||
|
||||
job = await training_service.create_training_job(
|
||||
db=test_db_session,
|
||||
tenant_id="test-tenant",
|
||||
job_id="legacy-config-job",
|
||||
config=legacy_config
|
||||
)
|
||||
|
||||
assert job.config == legacy_config
|
||||
assert job.job_id == "legacy-config-job"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_version_compatibility(
|
||||
self,
|
||||
test_client: AsyncClient
|
||||
):
|
||||
"""Test API version compatibility"""
|
||||
|
||||
# Test with minimal request (old API style)
|
||||
minimal_request = {
|
||||
"include_weather": True
|
||||
}
|
||||
|
||||
with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
|
||||
response = await test_client.post("/training/jobs", json=minimal_request)
|
||||
|
||||
# Should work with defaults for missing fields
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "job_id" in data


# Utility functions for integration tests
async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
"""Wait for a condition to become true"""
import time
start_time = time.time()

while time.time() - start_time < timeout:
if await condition_func():
return True
await asyncio.sleep(interval)

return False


def assert_job_progression(job_updates):
"""Assert that job updates show proper progression"""
assert len(job_updates) > 0

# Check progress is non-decreasing
for i in range(1, len(job_updates)):
assert job_updates[i]["progress"] >= job_updates[i-1]["progress"]

# Check final status
final_update = job_updates[-1]
assert final_update["status"] in ["completed", "failed", "cancelled"]


def assert_valid_job_structure(job_data):
"""Assert job data has valid structure"""
required_fields = ["job_id", "status", "tenant_id"]
for field in required_fields:
assert field in job_data

assert isinstance(job_data["progress"], int)
assert 0 <= job_data["progress"] <= 100
assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]
@@ -1,467 +0,0 @@
# services/training/tests/test_messaging.py
|
||||
"""
|
||||
Tests for training service messaging functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
import json
|
||||
|
||||
from app.services import messaging
|
||||
|
||||
|
||||
class TestTrainingMessaging:
|
||||
"""Test training service messaging functions"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_publisher(self):
|
||||
"""Mock the RabbitMQ publisher"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
mock_pub.connect = AsyncMock(return_value=True)
|
||||
mock_pub.disconnect = AsyncMock(return_value=None)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_success(self, mock_publisher):
|
||||
"""Test successful messaging setup"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_failure(self, mock_publisher):
|
||||
"""Test messaging setup failure"""
|
||||
mock_publisher.connect.return_value = False
|
||||
|
||||
await messaging.setup_messaging()
|
||||
|
||||
mock_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_messaging(self, mock_publisher):
|
||||
"""Test messaging cleanup"""
|
||||
await messaging.cleanup_messaging()
|
||||
|
||||
mock_publisher.disconnect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_started(self, mock_publisher):
|
||||
"""Test publishing job started event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
config = {"include_weather": True}
|
||||
|
||||
result = await messaging.publish_job_started(job_id, tenant_id, config)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
# Check call arguments
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["exchange_name"] == "training.events"
|
||||
assert call_args[1]["routing_key"] == "training.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["service_name"] == "training-service"
|
||||
assert event_data["data"]["job_id"] == job_id
|
||||
assert event_data["data"]["tenant_id"] == tenant_id
|
||||
assert event_data["data"]["config"] == config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_progress(self, mock_publisher):
|
||||
"""Test publishing job progress event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
progress = 50
|
||||
step = "Training models"
|
||||
|
||||
result = await messaging.publish_job_progress(job_id, tenant_id, progress, step)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.progress"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["progress"] == progress
|
||||
assert event_data["data"]["current_step"] == step
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_completed(self, mock_publisher):
|
||||
"""Test publishing job completed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
results = {
|
||||
"products_trained": 3,
|
||||
"summary": {"success_rate": 100.0}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_completed(job_id, tenant_id, results)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["results"] == results
|
||||
assert event_data["data"]["models_trained"] == 3
|
||||
assert event_data["data"]["success_rate"] == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_failed(self, mock_publisher):
|
||||
"""Test publishing job failed event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
error = "Data service unavailable"
|
||||
|
||||
result = await messaging.publish_job_failed(job_id, tenant_id, error)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.failed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["error"] == error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_job_cancelled(self, mock_publisher):
|
||||
"""Test publishing job cancelled event"""
|
||||
job_id = "test-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
|
||||
result = await messaging.publish_job_cancelled(job_id, tenant_id)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.cancelled"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_started(self, mock_publisher):
|
||||
"""Test publishing product training started event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
|
||||
result = await messaging.publish_product_training_started(job_id, tenant_id, product_name)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.started"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_product_training_completed(self, mock_publisher):
|
||||
"""Test publishing product training completed event"""
|
||||
job_id = "test-product-job-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_id = "test-model-123"
|
||||
|
||||
result = await messaging.publish_product_training_completed(
|
||||
job_id, tenant_id, product_name, model_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.product.completed"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_id"] == model_id
|
||||
assert event_data["data"]["product_name"] == product_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_trained(self, mock_publisher):
|
||||
"""Test publishing model trained event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
metrics = {"mae": 5.2, "rmse": 7.8, "mape": 12.5}
|
||||
|
||||
result = await messaging.publish_model_trained(model_id, tenant_id, product_name, metrics)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.trained"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["training_metrics"] == metrics
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_updated(self, mock_publisher):
|
||||
"""Test publishing model updated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
version = 2
|
||||
|
||||
result = await messaging.publish_model_updated(model_id, tenant_id, product_name, version)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.updated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["version"] == version
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_validated(self, mock_publisher):
|
||||
"""Test publishing model validated event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
validation_results = {"is_valid": True, "accuracy": 0.95}
|
||||
|
||||
result = await messaging.publish_model_validated(
|
||||
model_id, tenant_id, product_name, validation_results
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.validated"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["validation_results"] == validation_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_model_saved(self, mock_publisher):
|
||||
"""Test publishing model saved event"""
|
||||
model_id = "test-model-123"
|
||||
tenant_id = "test-tenant"
|
||||
product_name = "Pan Integral"
|
||||
model_path = "/models/test-model-123.pkl"
|
||||
|
||||
result = await messaging.publish_model_saved(model_id, tenant_id, product_name, model_path)
|
||||
|
||||
assert result is True
|
||||
mock_publisher.publish_event.assert_called_once()
|
||||
|
||||
call_args = mock_publisher.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == "training.model.saved"
|
||||
|
||||
event_data = call_args[1]["event_data"]
|
||||
assert event_data["data"]["model_path"] == model_path
|
||||
|
||||
|
||||
class TestMessagingErrorHandling:
|
||||
"""Test error handling in messaging"""
|
||||
|
||||
@pytest.fixture
|
||||
def failing_publisher(self):
|
||||
"""Mock publisher that fails"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=False)
|
||||
mock_pub.connect = AsyncMock(return_value=False)
|
||||
yield mock_pub
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_event_failure(self, failing_publisher):
|
||||
"""Test handling of publish event failure"""
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
failing_publisher.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_messaging_connection_failure(self, failing_publisher):
|
||||
"""Test setup with connection failure"""
|
||||
await messaging.setup_messaging()
|
||||
|
||||
failing_publisher.connect.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_with_exception(self):
|
||||
"""Test publishing with exception"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event.side_effect = Exception("Connection lost")
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestMessagingIntegration:
|
||||
"""Test messaging integration with shared components"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_structure_consistency(self):
|
||||
"""Test that events follow consistent structure"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test different event types
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_job_completed("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "Pan", {})
|
||||
|
||||
# Verify all calls have consistent structure
|
||||
assert mock_pub.publish_event.call_count == 3
|
||||
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
event_data = call[1]["event_data"]
|
||||
|
||||
# All events should have these fields
|
||||
assert "service_name" in event_data
|
||||
assert "event_type" in event_data
|
||||
assert "data" in event_data
|
||||
assert event_data["service_name"] == "training-service"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shared_event_classes_usage(self):
|
||||
"""Test that shared event classes are used properly"""
|
||||
with patch('shared.messaging.events.TrainingStartedEvent') as mock_event_class:
|
||||
mock_event = Mock()
|
||||
mock_event.to_dict.return_value = {
|
||||
"service_name": "training-service",
|
||||
"event_type": "training.started",
|
||||
"data": {"job_id": "test-job"}
|
||||
}
|
||||
mock_event_class.return_value = mock_event
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
await messaging.publish_job_started("test-job", "test-tenant", {})
|
||||
|
||||
# Verify shared event class was used
|
||||
mock_event_class.assert_called_once()
|
||||
mock_event.to_dict.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routing_key_consistency(self):
|
||||
"""Test that routing keys follow consistent patterns"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test various event types
|
||||
events_and_keys = [
|
||||
(messaging.publish_job_started, "training.started"),
|
||||
(messaging.publish_job_progress, "training.progress"),
|
||||
(messaging.publish_job_completed, "training.completed"),
|
||||
(messaging.publish_job_failed, "training.failed"),
|
||||
(messaging.publish_job_cancelled, "training.cancelled"),
|
||||
(messaging.publish_product_training_started, "training.product.started"),
|
||||
(messaging.publish_product_training_completed, "training.product.completed"),
|
||||
(messaging.publish_model_trained, "training.model.trained"),
|
||||
(messaging.publish_model_updated, "training.model.updated"),
|
||||
(messaging.publish_model_validated, "training.model.validated"),
|
||||
(messaging.publish_model_saved, "training.model.saved")
|
||||
]
|
||||
|
||||
for event_func, expected_key in events_and_keys:
|
||||
mock_pub.reset_mock()
|
||||
|
||||
# Call event function with appropriate parameters
|
||||
if "progress" in expected_key:
|
||||
await event_func("job-123", "tenant-123", 50, "step")
|
||||
elif "model" in expected_key and "trained" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "updated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", 1)
|
||||
elif "model" in expected_key and "validated" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", {})
|
||||
elif "model" in expected_key and "saved" in expected_key:
|
||||
await event_func("model-123", "tenant-123", "product", "/path")
|
||||
elif "product" in expected_key and "completed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product", "model-123")
|
||||
elif "product" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "product")
|
||||
elif "failed" in expected_key:
|
||||
await event_func("job-123", "tenant-123", "error")
|
||||
elif "cancelled" in expected_key:
|
||||
await event_func("job-123", "tenant-123")
|
||||
else:
|
||||
await event_func("job-123", "tenant-123", {})
|
||||
|
||||
# Verify routing key
|
||||
call_args = mock_pub.publish_event.call_args
|
||||
assert call_args[1]["routing_key"] == expected_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_consistency(self):
|
||||
"""Test that all events use the same exchange"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Test multiple events
|
||||
await messaging.publish_job_started("job-123", "tenant-123", {})
|
||||
await messaging.publish_model_trained("model-123", "tenant-123", "product", {})
|
||||
await messaging.publish_product_training_started("job-123", "tenant-123", "product")
|
||||
|
||||
# Verify all use same exchange
|
||||
for call in mock_pub.publish_event.call_args_list:
|
||||
assert call[1]["exchange_name"] == "training.events"
|
||||
|
||||
|
||||
class TestMessagingPerformance:
|
||||
"""Test messaging performance and reliability"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_publishing(self):
|
||||
"""Test concurrent event publishing"""
|
||||
import asyncio
|
||||
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create multiple concurrent publishing tasks
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
task = messaging.publish_job_progress(f"job-{i}", "tenant-123", i * 10, f"step-{i}")
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks concurrently
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify all succeeded
|
||||
assert all(results)
|
||||
assert mock_pub.publish_event.call_count == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_event_data(self):
|
||||
"""Test publishing events with large data payloads"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Create large config data
|
||||
large_config = {
|
||||
"products": [f"Product-{i}" for i in range(1000)],
|
||||
"features": [f"feature-{i}" for i in range(100)],
|
||||
"hyperparameters": {f"param-{i}": i for i in range(50)}
|
||||
}
|
||||
|
||||
result = await messaging.publish_job_started("job-123", "tenant-123", large_config)
|
||||
|
||||
assert result is True
|
||||
mock_pub.publish_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_sequential_publishing(self):
|
||||
"""Test rapid sequential event publishing"""
|
||||
with patch('app.services.messaging.training_publisher') as mock_pub:
|
||||
mock_pub.publish_event = AsyncMock(return_value=True)
|
||||
|
||||
# Publish many events in sequence
|
||||
for i in range(100):
|
||||
await messaging.publish_job_progress("job-123", "tenant-123", i, f"step-{i}")
|
||||
|
||||
assert mock_pub.publish_event.call_count == 100
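

# A hedged sketch of an alternative layout for the routing-key checks in
# TestMessagingIntegration.test_routing_key_consistency: pytest.mark.parametrize
# turns each event type into its own case, replacing the manual if/elif dispatch.
# The argument tuples are assumptions mirroring the call signatures exercised
# earlier in this module; only a representative subset is listed.
ROUTING_KEY_CASES = [
    (messaging.publish_job_started, ("job-1", "tenant-1", {}), "training.started"),
    (messaging.publish_job_progress, ("job-1", "tenant-1", 50, "step"), "training.progress"),
    (messaging.publish_job_completed, ("job-1", "tenant-1", {}), "training.completed"),
    (messaging.publish_job_failed, ("job-1", "tenant-1", "error"), "training.failed"),
    (messaging.publish_job_cancelled, ("job-1", "tenant-1"), "training.cancelled"),
    (messaging.publish_model_trained, ("model-1", "tenant-1", "Pan Integral", {}), "training.model.trained"),
]


@pytest.mark.asyncio
@pytest.mark.parametrize("event_func, args, expected_key", ROUTING_KEY_CASES)
async def test_routing_key_parametrized(event_func, args, expected_key):
    """Each publish helper should route to its expected key (sketch)."""
    with patch('app.services.messaging.training_publisher') as mock_pub:
        mock_pub.publish_event = AsyncMock(return_value=True)

        result = await event_func(*args)

        assert result is True
        call_args = mock_pub.publish_event.call_args
        assert call_args[1]["routing_key"] == expected_key
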
@@ -1,630 +0,0 @@
# ================================================================
|
||||
# services/training/tests/test_performance.py
|
||||
# ================================================================
|
||||
"""
|
||||
Performance and Load Testing for Training Service
|
||||
Tests training performance with real-world data volumes
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import psutil
|
||||
import gc
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.ml.data_processor import BakeryDataProcessor
|
||||
from app.services.training_service import TrainingService
|
||||
|
||||
|
||||
class TestTrainingPerformance:
|
||||
"""Performance tests for training service components"""
|
||||
|
||||
@pytest.fixture
|
||||
def large_sales_dataset(self):
|
||||
"""Generate large dataset for performance testing (2 years of data)"""
|
||||
start_date = datetime(2022, 1, 1)
|
||||
end_date = datetime(2024, 1, 1)
|
||||
|
||||
date_range = pd.date_range(start=start_date, end=end_date, freq='D')
|
||||
products = [
|
||||
"Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
|
||||
"Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras",
|
||||
"Donuts", "Berlinas", "Napolitanas", "Ensaimadas"
|
||||
]
|
||||
|
||||
data = []
|
||||
for date in date_range:
|
||||
for product in products:
|
||||
# Realistic sales simulation
|
||||
base_quantity = np.random.randint(5, 150)
|
||||
|
||||
# Seasonal patterns
|
||||
if date.month in [12, 1]: # Winter/Holiday season
|
||||
base_quantity *= 1.4
|
||||
elif date.month in [6, 7, 8]: # Summer
|
||||
base_quantity *= 0.8
|
||||
|
||||
# Weekly patterns
|
||||
if date.weekday() >= 5: # Weekends
|
||||
base_quantity *= 1.2
|
||||
elif date.weekday() == 0: # Monday
|
||||
                    base_quantity *= 0.7

                # Add noise
                quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.1)))

                data.append({
                    "date": date.strftime("%Y-%m-%d"),
                    "product": product,
                    "quantity": quantity,
                    "revenue": round(quantity * np.random.uniform(1.5, 8.0), 2),
                    "temperature": round(15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi) + np.random.normal(0, 3), 1),
                    "precipitation": max(0, np.random.exponential(0.8)),
                    "is_weekend": date.weekday() >= 5,
                    "is_holiday": self._is_spanish_holiday(date)
                })

        return pd.DataFrame(data)

    def _is_spanish_holiday(self, date: datetime) -> bool:
        """Check if date is a Spanish holiday"""
        holidays = [
            (1, 1),    # New Year
            (1, 6),    # Epiphany
            (5, 1),    # Labor Day
            (8, 15),   # Assumption
            (10, 12),  # National Day
            (11, 1),   # All Saints
            (12, 6),   # Constitution Day
            (12, 8),   # Immaculate Conception
            (12, 25),  # Christmas
        ]
        return (date.month, date.day) in holidays

    @pytest.mark.asyncio
    async def test_single_product_training_performance(self, large_sales_dataset):
        """Test performance of single product training with large dataset"""

        trainer = BakeryMLTrainer()
        product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()

        # Measure memory before training
        process = psutil.Process()
        memory_before = process.memory_info().rss / 1024 / 1024  # MB

        start_time = time.time()

        result = await trainer.train_single_product(
            tenant_id="perf_test_tenant",
            product_name="Pan Integral",
            sales_data=product_data,
            config={
                "include_weather": True,
                "include_traffic": False,  # Skip traffic for performance
                "seasonality_mode": "additive"
            }
        )

        end_time = time.time()
        training_duration = end_time - start_time

        # Measure memory after training
        memory_after = process.memory_info().rss / 1024 / 1024  # MB
        memory_used = memory_after - memory_before

        # Performance assertions
        assert training_duration < 120, f"Training took too long: {training_duration:.2f}s"
        assert memory_used < 500, f"Memory usage too high: {memory_used:.2f}MB"
        assert result['status'] == 'completed'

        # Quality assertions
        metrics = result['metrics']
        assert metrics['mape'] < 50, f"MAPE too high: {metrics['mape']:.2f}%"

        print(f"Performance Results:")
        print(f"  Training Duration: {training_duration:.2f}s")
        print(f"  Memory Used: {memory_used:.2f}MB")
        print(f"  Data Points: {len(product_data)}")
        print(f"  MAPE: {metrics['mape']:.2f}%")
        print(f"  RMSE: {metrics['rmse']:.2f}")

    @pytest.mark.asyncio
    async def test_concurrent_training_performance(self, large_sales_dataset):
        """Test performance of concurrent training jobs"""

        trainer = BakeryMLTrainer()
        products = ["Pan Integral", "Croissant", "Magdalenas"]

        async def train_product(product_name: str):
            """Train a single product"""
            product_data = large_sales_dataset[large_sales_dataset['product'] == product_name].copy()

            start_time = time.time()
            result = await trainer.train_single_product(
                tenant_id=f"concurrent_test_{product_name.replace(' ', '_').lower()}",
                product_name=product_name,
                sales_data=product_data,
                config={"include_weather": True, "include_traffic": False}
            )
            end_time = time.time()

            return {
                'product': product_name,
                'duration': end_time - start_time,
                'status': result['status'],
                'metrics': result.get('metrics', {})
            }

        # Run concurrent training
        start_time = time.time()
        tasks = [train_product(product) for product in products]
        results = await asyncio.gather(*tasks)
        total_time = time.time() - start_time

        # Verify all trainings completed
        for result in results:
            assert result['status'] == 'completed'
            assert result['duration'] < 120  # Individual training time

        # Concurrent execution should be faster than sequential
        sequential_time_estimate = sum(r['duration'] for r in results)
        efficiency = sequential_time_estimate / total_time

        assert efficiency > 1.5, f"Concurrency efficiency too low: {efficiency:.2f}x"

        print(f"Concurrent Training Results:")
        print(f"  Total Time: {total_time:.2f}s")
        print(f"  Sequential Estimate: {sequential_time_estimate:.2f}s")
        print(f"  Efficiency: {efficiency:.2f}x")

        for result in results:
            mape = result['metrics'].get('mape')
            mape_text = f"{mape:.2f}%" if mape is not None else "N/A"
            print(f"  {result['product']}: {result['duration']:.2f}s, MAPE: {mape_text}")

    @pytest.mark.asyncio
    async def test_data_processing_scalability(self, large_sales_dataset):
        """Test data processing performance with increasing data sizes"""

        data_processor = BakeryDataProcessor()

        # Test with different data sizes
        data_sizes = [1000, 5000, 10000, 20000, len(large_sales_dataset)]
        performance_results = []

        for size in data_sizes:
            # Take a sample of the specified size
            sample_data = large_sales_dataset.head(size).copy()

            start_time = time.time()

            # Process the data
            processed_data = await data_processor.prepare_training_data(
                sales_data=sample_data,
                include_weather=True,
                include_traffic=True,
                tenant_id="scalability_test",
                product_name="Pan Integral"
            )

            processing_time = time.time() - start_time

            performance_results.append({
                'data_size': size,
                'processing_time': processing_time,
                'processed_rows': len(processed_data),
                'throughput': size / processing_time if processing_time > 0 else 0
            })

        # Verify linear or sub-linear scaling
        for i in range(1, len(performance_results)):
            prev_result = performance_results[i-1]
            curr_result = performance_results[i]

            size_ratio = curr_result['data_size'] / prev_result['data_size']
            time_ratio = curr_result['processing_time'] / prev_result['processing_time']

            # Processing time should scale better than linearly
            assert time_ratio < size_ratio * 1.5, f"Poor scaling at size {curr_result['data_size']}"

        print("Data Processing Scalability Results:")
        for result in performance_results:
            print(f"  Size: {result['data_size']:,} rows, Time: {result['processing_time']:.2f}s, "
                  f"Throughput: {result['throughput']:.0f} rows/s")

    @pytest.mark.asyncio
    async def test_memory_usage_optimization(self, large_sales_dataset):
        """Test memory usage optimization during training"""

        trainer = BakeryMLTrainer()
        process = psutil.Process()

        # Baseline memory
        gc.collect()  # Force garbage collection
        baseline_memory = process.memory_info().rss / 1024 / 1024  # MB

        memory_snapshots = [{'stage': 'baseline', 'memory_mb': baseline_memory}]

        # Load data
        product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()
        current_memory = process.memory_info().rss / 1024 / 1024
        memory_snapshots.append({'stage': 'data_loaded', 'memory_mb': current_memory})

        # Train model
        result = await trainer.train_single_product(
            tenant_id="memory_test_tenant",
            product_name="Pan Integral",
            sales_data=product_data,
            config={"include_weather": True, "include_traffic": True}
        )

        current_memory = process.memory_info().rss / 1024 / 1024
        memory_snapshots.append({'stage': 'model_trained', 'memory_mb': current_memory})

        # Cleanup
        del product_data
        del result
        gc.collect()

        final_memory = process.memory_info().rss / 1024 / 1024
        memory_snapshots.append({'stage': 'cleanup', 'memory_mb': final_memory})

        # Memory assertions
        peak_memory = max(snapshot['memory_mb'] for snapshot in memory_snapshots)
        memory_increase = peak_memory - baseline_memory
        memory_after_cleanup = final_memory - baseline_memory

        assert memory_increase < 800, f"Peak memory increase too high: {memory_increase:.2f}MB"
        assert memory_after_cleanup < 100, f"Memory not properly cleaned up: {memory_after_cleanup:.2f}MB"

        print("Memory Usage Analysis:")
        for snapshot in memory_snapshots:
            print(f"  {snapshot['stage']}: {snapshot['memory_mb']:.2f}MB")
        print(f"  Peak increase: {memory_increase:.2f}MB")
        print(f"  After cleanup: {memory_after_cleanup:.2f}MB")

    @pytest.mark.asyncio
    async def test_training_service_throughput(self, large_sales_dataset):
        """Test training service throughput with multiple requests"""

        training_service = TrainingService()

        # Simulate multiple training requests
        num_requests = 5
        products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]

        async def execute_training_request(request_id: int, product: str):
            """Execute a single training request"""
            product_data = large_sales_dataset[large_sales_dataset['product'] == product].copy()

            with patch.object(training_service, '_fetch_sales_data', return_value=product_data):
                start_time = time.time()

                result = await training_service.execute_training_job(
                    db=None,  # Mock DB session
                    tenant_id=f"throughput_test_tenant_{request_id}",
                    job_id=f"job_{request_id}_{product.replace(' ', '_').lower()}",
                    request={
                        'products': [product],
                        'include_weather': True,
                        'include_traffic': False,
                        'config': {'seasonality_mode': 'additive'}
                    }
                )

                duration = time.time() - start_time
                return {
                    'request_id': request_id,
                    'product': product,
                    'duration': duration,
                    'status': result.get('status', 'unknown'),
                    'models_trained': len(result.get('models_trained', []))
                }

        # Execute requests concurrently
        start_time = time.time()
        tasks = [
            execute_training_request(i, products[i % len(products)])
            for i in range(num_requests)
        ]
        results = await asyncio.gather(*tasks)
        total_time = time.time() - start_time

        # Calculate throughput metrics
        successful_requests = sum(1 for r in results if r['status'] == 'completed')
        throughput = successful_requests / total_time  # requests per second

        # Performance assertions
        assert successful_requests >= num_requests * 0.8, "Too many failed requests"
        assert throughput >= 0.1, f"Throughput too low: {throughput:.3f} req/s"
        assert total_time < 300, f"Total time too long: {total_time:.2f}s"

        print(f"Training Service Throughput Results:")
        print(f"  Total Requests: {num_requests}")
        print(f"  Successful: {successful_requests}")
        print(f"  Total Time: {total_time:.2f}s")
        print(f"  Throughput: {throughput:.3f} req/s")
        print(f"  Average Request Time: {total_time/num_requests:.2f}s")

    @pytest.mark.asyncio
    async def test_large_dataset_edge_cases(self, large_sales_dataset):
        """Test handling of edge cases with large datasets"""

        data_processor = BakeryDataProcessor()

        # Test 1: Dataset with many missing values
        corrupted_data = large_sales_dataset.copy()
        # Introduce 30% missing values randomly
        mask = np.random.random(len(corrupted_data)) < 0.3
        corrupted_data.loc[mask, 'quantity'] = np.nan

        start_time = time.time()
        result = await data_processor.validate_data_quality(corrupted_data)
        validation_time = time.time() - start_time

        assert validation_time < 10, f"Validation too slow: {validation_time:.2f}s"
        assert result['is_valid'] is False
        assert 'high_missing_data' in result['issues']

        # Test 2: Dataset with extreme outliers
        outlier_data = large_sales_dataset.copy()
        # Add extreme outliers (100x normal values)
        outlier_indices = np.random.choice(len(outlier_data), size=int(len(outlier_data) * 0.01), replace=False)
        outlier_data.loc[outlier_indices, 'quantity'] *= 100

        start_time = time.time()
        cleaned_data = await data_processor.clean_outliers(outlier_data)
        cleaning_time = time.time() - start_time

        assert cleaning_time < 15, f"Outlier cleaning too slow: {cleaning_time:.2f}s"
        assert len(cleaned_data) > len(outlier_data) * 0.95  # Should retain most data

        # Test 3: Very sparse data (many products with few sales)
        sparse_data = large_sales_dataset.copy()
        # Keep only 10% of data for each product randomly
        sparse_data = sparse_data.groupby('product').apply(
            lambda x: x.sample(n=max(1, int(len(x) * 0.1)))
        ).reset_index(drop=True)

        start_time = time.time()
        validation_result = await data_processor.validate_data_quality(sparse_data)
        sparse_validation_time = time.time() - start_time

        assert sparse_validation_time < 5, f"Sparse data validation too slow: {sparse_validation_time:.2f}s"

        print("Edge Case Performance Results:")
        print(f"  Corrupted data validation: {validation_time:.2f}s")
        print(f"  Outlier cleaning: {cleaning_time:.2f}s")
        print(f"  Sparse data validation: {sparse_validation_time:.2f}s")


class TestTrainingServiceLoad:
    """Load testing for training service under stress"""

    @pytest.mark.asyncio
    async def test_sustained_load_training(self, large_sales_dataset):
        """Test training service under sustained load"""

        trainer = BakeryMLTrainer()

        # Define load test parameters
        duration_minutes = 2  # Run for 2 minutes
        requests_per_minute = 3

        products = ["Pan Integral", "Croissant", "Magdalenas"]

        async def sustained_training_worker(worker_id: int, duration: float):
            """Worker that continuously submits training requests"""
            start_time = time.time()
            completed_requests = 0
            failed_requests = 0

            while time.time() - start_time < duration:
                try:
                    product = products[completed_requests % len(products)]
                    product_data = large_sales_dataset[
                        large_sales_dataset['product'] == product
                    ].copy()

                    result = await trainer.train_single_product(
                        tenant_id=f"load_test_worker_{worker_id}",
                        product_name=product,
                        sales_data=product_data,
                        config={"include_weather": False, "include_traffic": False}  # Minimal config for speed
                    )

                    if result['status'] == 'completed':
                        completed_requests += 1
                    else:
                        failed_requests += 1

                except Exception as e:
                    failed_requests += 1
                    logging.error(f"Training request failed: {e}")

                # Wait before next request
                await asyncio.sleep(60 / requests_per_minute)

            return {
                'worker_id': worker_id,
                'completed': completed_requests,
                'failed': failed_requests,
                'duration': time.time() - start_time
            }

        # Start multiple workers
        num_workers = 2
        duration_seconds = duration_minutes * 60

        start_time = time.time()
        tasks = [
            sustained_training_worker(i, duration_seconds)
            for i in range(num_workers)
        ]
        results = await asyncio.gather(*tasks)
        total_time = time.time() - start_time

        # Analyze results
        total_completed = sum(r['completed'] for r in results)
        total_failed = sum(r['failed'] for r in results)
        success_rate = total_completed / (total_completed + total_failed) if (total_completed + total_failed) > 0 else 0

        # Performance assertions
        assert success_rate >= 0.8, f"Success rate too low: {success_rate:.2%}"
        assert total_completed >= duration_minutes * requests_per_minute * num_workers * 0.7, "Throughput too low"

        print(f"Sustained Load Test Results:")
        print(f"  Duration: {total_time:.2f}s")
        print(f"  Workers: {num_workers}")
        print(f"  Completed Requests: {total_completed}")
        print(f"  Failed Requests: {total_failed}")
        print(f"  Success Rate: {success_rate:.2%}")
        print(f"  Average Throughput: {total_completed/total_time:.2f} req/s")

    @pytest.mark.asyncio
    async def test_resource_exhaustion_recovery(self, large_sales_dataset):
        """Test service recovery from resource exhaustion"""

        trainer = BakeryMLTrainer()

        # Simulate resource exhaustion by running many concurrent requests
        num_concurrent = 10  # High concurrency to stress the system

        async def resource_intensive_task(task_id: int):
            """Task designed to consume resources"""
            try:
                # Use all products to increase memory usage
                all_products_data = large_sales_dataset.copy()

                result = await trainer.train_tenant_models(
                    tenant_id=f"resource_test_{task_id}",
                    sales_data=all_products_data,
                    config={
                        "train_all_products": True,
                        "include_weather": True,
                        "include_traffic": True
                    }
                )

                return {'task_id': task_id, 'status': 'completed', 'error': None}

            except Exception as e:
                return {'task_id': task_id, 'status': 'failed', 'error': str(e)}

        # Launch all tasks simultaneously
        start_time = time.time()
        tasks = [resource_intensive_task(i) for i in range(num_concurrent)]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        duration = time.time() - start_time

        # Analyze results
        completed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'completed')
        failed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'failed')
        exceptions = sum(1 for r in results if isinstance(r, Exception))

        # The system should handle some failures gracefully
        # but should complete at least some requests
        total_processed = completed + failed + exceptions
        processing_rate = total_processed / num_concurrent

        assert processing_rate >= 0.5, f"Too many requests not processed: {processing_rate:.2%}"
        assert duration < 600, f"Recovery took too long: {duration:.2f}s"  # 10 minutes max

        print(f"Resource Exhaustion Test Results:")
        print(f"  Concurrent Requests: {num_concurrent}")
        print(f"  Completed: {completed}")
        print(f"  Failed: {failed}")
        print(f"  Exceptions: {exceptions}")
        print(f"  Duration: {duration:.2f}s")
        print(f"  Processing Rate: {processing_rate:.2%}")


# ================================================================
# BENCHMARK UTILITIES
# ================================================================

class PerformanceBenchmark:
    """Utility class for performance benchmarking"""

    @staticmethod
    def measure_execution_time(func):
        """Decorator to measure execution time"""
        async def wrapper(*args, **kwargs):
            start_time = time.time()
            result = await func(*args, **kwargs)
            execution_time = time.time() - start_time

            if hasattr(result, 'update') and isinstance(result, dict):
                result['execution_time'] = execution_time

            return result
        return wrapper

    @staticmethod
    def memory_profiler(func):
        """Decorator to profile memory usage"""
        async def wrapper(*args, **kwargs):
            process = psutil.Process()

            # Memory before
            gc.collect()
            memory_before = process.memory_info().rss / 1024 / 1024

            result = await func(*args, **kwargs)

            # Memory after
            memory_after = process.memory_info().rss / 1024 / 1024
            memory_used = memory_after - memory_before

            if hasattr(result, 'update') and isinstance(result, dict):
                result['memory_used_mb'] = memory_used

            return result
        return wrapper


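# Illustrative usage of the benchmark decorators (a sketch only; the decorated
# coroutine below and its arguments are assumed for the example and are not
# part of the original suite). Both decorators attach their measurements only
# when the wrapped coroutine returns a dict:
#
#     @PerformanceBenchmark.measure_execution_time
#     @PerformanceBenchmark.memory_profiler
#     async def run_benchmark(trainer, product_data):
#         return await trainer.train_single_product(
#             tenant_id="benchmark_tenant",
#             product_name="Pan Integral",
#             sales_data=product_data,
#             config={"include_weather": False, "include_traffic": False},
#         )
#
#     # The returned dict then carries 'execution_time' and 'memory_used_mb'.
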
# ================================================================
# STANDALONE EXECUTION
# ================================================================

if __name__ == "__main__":
    """
    Run performance tests as standalone script
    Usage: python test_performance.py
    """
    import sys
    import os
    from unittest.mock import patch

    # Add the training service root to Python path
    training_service_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.insert(0, training_service_root)

    print("=" * 60)
    print("TRAINING SERVICE PERFORMANCE TEST SUITE")
    print("=" * 60)

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Run performance tests
    pytest.main([
        __file__,
        "-v",
        "--tb=short",
        "-s",  # Don't capture output
        "--durations=10",  # Show 10 slowest tests
        "-m", "not slow",  # Skip slow tests unless specifically requested
    ])

    print("\n" + "=" * 60)
    print("PERFORMANCE TESTING COMPLETE")
    print("=" * 60)
@@ -1,688 +0,0 @@
# services/training/tests/test_service.py
"""
Tests for training service business logic layer
"""

import pytest
from unittest.mock import AsyncMock, Mock, patch
from datetime import datetime, timedelta
import httpx

from app.services.training_service import TrainingService
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.models.training import ModelTrainingLog, TrainedModel


class TestTrainingService:
    """Test the training service business logic"""

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_create_training_job_success(
        self,
        training_service,
        test_db_session
    ):
        """Test successful training job creation"""
        job_id = "test-job-123"
        tenant_id = "test-tenant"
        config = {"include_weather": True, "include_traffic": True}

        result = await training_service.create_training_job(
            db=test_db_session,
            tenant_id=tenant_id,
            job_id=job_id,
            config=config
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.status == "pending"
        assert result.progress == 0
        assert result.config == config

    @pytest.mark.asyncio
    async def test_create_single_product_job_success(
        self,
        training_service,
        test_db_session
    ):
        """Test successful single product job creation"""
        job_id = "test-product-job-123"
        tenant_id = "test-tenant"
        product_name = "Pan Integral"
        config = {"include_weather": True}

        result = await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=config
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.config["single_product"] == product_name
        assert f"Initializing training for {product_name}" in result.current_step

    @pytest.mark.asyncio
    async def test_get_job_status_existing(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test getting status of existing job"""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert result is not None
        assert result.job_id == training_job_in_db.job_id
        assert result.status == training_job_in_db.status

    @pytest.mark.asyncio
    async def test_get_job_status_nonexistent(
        self,
        training_service,
        test_db_session
    ):
        """Test getting status of non-existent job"""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant"
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test listing training jobs"""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10
        )

        assert isinstance(result, list)
        assert len(result) >= 1
        assert result[0].job_id == training_job_in_db.job_id

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_filter(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test listing training jobs with status filter"""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
            status_filter="pending"
        )

        assert isinstance(result, list)
        for job in result:
            assert job.status == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test successful job cancellation"""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert result is True

        # Verify status was updated
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )
        assert updated_job.status == "cancelled"

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(
        self,
        training_service,
        test_db_session
    ):
        """Test cancelling non-existent job"""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant"
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self,
        training_service,
        test_db_session,
        mock_data_service
    ):
        """Test validation with valid data"""
        config = {"min_data_points": 30}

        result = await training_service.validate_training_data(
            db=test_db_session,
            tenant_id="test-tenant",
            config=config
        )

        assert isinstance(result, dict)
        assert "is_valid" in result
        assert "issues" in result
        assert "recommendations" in result
        assert "estimated_time_minutes" in result

    @pytest.mark.asyncio
    async def test_validate_training_data_no_data(
        self,
        training_service,
        test_db_session
    ):
        """Test validation with no data"""
        config = {"min_data_points": 30}

        with patch('app.services.training_service.TrainingService._fetch_sales_data', new=AsyncMock(return_value=[])):
            result = await training_service.validate_training_data(
                db=test_db_session,
                tenant_id="test-tenant",
                config=config
            )

            assert result["is_valid"] is False
            assert "No sales data found" in result["issues"][0]

    @pytest.mark.asyncio
    async def test_update_job_status(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test updating job status"""
        await training_service._update_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            status="running",
            progress=50,
            current_step="Training models"
        )

        # Verify update
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert updated_job.status == "running"
        assert updated_job.progress == 50
        assert updated_job.current_step == "Training models"

    @pytest.mark.asyncio
    async def test_store_trained_models(
        self,
        training_service,
        test_db_session
    ):
        """Test storing trained models"""
        tenant_id = "test-tenant"
        training_results = {
            "training_results": {
                "Pan Integral": {
                    "status": "success",
                    "model_info": {
                        "model_id": "test-model-123",
                        "model_path": "/test/models/test-model-123.pkl",
                        "type": "prophet",
                        "training_samples": 100,
                        "features": ["temperature", "humidity"],
                        "hyperparameters": {"seasonality_mode": "additive"},
                        "training_metrics": {"mae": 5.2, "rmse": 7.8},
                        "data_period": {
                            "start_date": "2024-01-01T00:00:00",
                            "end_date": "2024-01-31T00:00:00"
                        }
                    }
                }
            }
        }

        await training_service._store_trained_models(
            db=test_db_session,
            tenant_id=tenant_id,
            training_results=training_results
        )

        # Verify model was stored
        from sqlalchemy import select
        result = await test_db_session.execute(
            select(TrainedModel).where(
                TrainedModel.tenant_id == tenant_id,
                TrainedModel.product_name == "Pan Integral"
            )
        )

        stored_model = result.scalar_one_or_none()
        assert stored_model is not None
        assert stored_model.model_id == "test-model-123"
        assert stored_model.is_active is True

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Test getting training logs"""
        result = await training_service.get_training_logs(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert isinstance(result, list)
        assert len(result) > 0

        # Check log content
        log_text = " ".join(result)
        assert training_job_in_db.job_id in log_text or "Job started" in log_text


class TestTrainingServiceDataFetching:
    """Test external data fetching functionality"""

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_fetch_sales_data_success(self, training_service):
        """Test successful sales data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "sales": [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            # client.get is awaited by the service, so it must be an AsyncMock
            mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            assert result == mock_response_data["sales"]

    @pytest.mark.asyncio
    async def test_fetch_sales_data_error(self, training_service):
        """Test sales data fetching with API error"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 500

            mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            assert result == []

    @pytest.mark.asyncio
    async def test_fetch_weather_data_success(self, training_service):
        """Test successful weather data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "weather": [
                {"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)

            result = await training_service._fetch_weather_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            assert result == mock_response_data["weather"]

    @pytest.mark.asyncio
    async def test_fetch_traffic_data_success(self, training_service):
        """Test successful traffic data fetching"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        mock_response_data = {
            "traffic": [
                {"date": "2024-01-01", "traffic_volume": 120}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)

            result = await training_service._fetch_traffic_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            assert result == mock_response_data["traffic"]

    @pytest.mark.asyncio
    async def test_fetch_data_with_date_filters(self, training_service):
        """Test data fetching with date filters"""
        from datetime import datetime

        mock_request = Mock()
        mock_request.start_date = datetime(2024, 1, 1)
        mock_request.end_date = datetime(2024, 1, 31)

        with patch('httpx.AsyncClient') as mock_client:
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = {"sales": []}

            mock_get = AsyncMock(return_value=mock_response)
            mock_client.return_value.__aenter__.return_value.get = mock_get

            await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            # Verify dates were passed in params
            call_args = mock_get.call_args
            params = call_args[1]["params"]
            assert "start_date" in params
            assert "end_date" in params
            assert params["start_date"] == "2024-01-01T00:00:00"
            assert params["end_date"] == "2024-01-31T00:00:00"


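# Note on the httpx mocking pattern above: driving
# mock_client.return_value.__aenter__.return_value.get assumes the service
# performs its fetches roughly like the sketch below (illustrative only; the
# real _fetch_sales_data body, URL, and parameter names are not shown in this
# file and are assumed here):
#
#     async def _fetch_sales_data(self, tenant_id, request):
#         params = {"tenant_id": tenant_id}
#         async with httpx.AsyncClient() as client:
#             response = await client.get(SALES_DATA_URL, params=params)
#         if response.status_code != 200:
#             return []
#         return response.json().get("sales", [])
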
class TestTrainingServiceExecution:
    """Test training execution workflow"""

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_execute_training_job_success(
        self,
        training_service,
        test_db_session,
        mock_messaging,
        mock_data_service
    ):
        """Test successful training job execution"""
        # Create job first
        job_id = "test-execution-job"
        training_log = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={"include_weather": True}
        )

        request = TrainingJobRequest(
            include_weather=True,
            include_traffic=True,
            min_data_points=30
        )

        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
             patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
             patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
             patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:

            mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}]
            mock_fetch_weather.return_value = []
            mock_fetch_traffic.return_value = []
            mock_store.return_value = None

            await training_service.execute_training_job(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                request=request
            )

        # Verify job was completed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100

    @pytest.mark.asyncio
    async def test_execute_training_job_failure(
        self,
        training_service,
        test_db_session,
        mock_messaging
    ):
        """Test training job execution with failure"""
        # Create job first
        job_id = "test-failure-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={}
        )

        request = TrainingJobRequest(min_data_points=30)

        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
            mock_fetch.side_effect = Exception("Data service unavailable")

            with pytest.raises(Exception):
                await training_service.execute_training_job(
                    db=test_db_session,
                    job_id=job_id,
                    tenant_id="test-tenant",
                    request=request
                )

        # Verify job was marked as failed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )

        assert updated_job.status == "failed"
        assert "Data service unavailable" in updated_job.error_message

    @pytest.mark.asyncio
    async def test_execute_single_product_training_success(
        self,
        training_service,
        test_db_session,
        mock_messaging,
        mock_data_service
    ):
        """Test successful single product training execution"""
        job_id = "test-single-product-job"
        product_name = "Pan Integral"

        await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id="test-tenant",
            product_name=product_name,
            job_id=job_id,
            config={}
        )

        request = SingleProductTrainingRequest(
            include_weather=True,
            include_traffic=False
        )

        with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
             patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
             patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:

            mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": product_name, "quantity": 45}]
            mock_fetch_weather.return_value = []
            mock_store.return_value = None

            await training_service.execute_single_product_training(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                product_name=product_name,
                request=request
            )

        # Verify job was completed
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100


class TestTrainingServiceEdgeCases:
    """Test edge cases and error conditions"""

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_database_connection_failure(self, training_service):
        """Test handling of database connection failures"""
        with patch('sqlalchemy.ext.asyncio.AsyncSession') as mock_session:
            mock_session.side_effect = Exception("Database connection failed")

            with pytest.raises(Exception):
                await training_service.create_training_job(
                    db=mock_session,
                    tenant_id="test-tenant",
                    job_id="test-job",
                    config={}
                )

    @pytest.mark.asyncio
    async def test_external_service_timeout(self, training_service):
        """Test handling of external service timeouts"""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_client.return_value.__aenter__.return_value.get.side_effect = httpx.TimeoutException("Request timeout")

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

            # Should return empty list on timeout
            assert result == []

    @pytest.mark.asyncio
    async def test_concurrent_job_creation(self, training_service, test_db_session):
        """Test handling of concurrent job creation"""
        # This test would need more sophisticated setup for true concurrency testing
        # For now, just test that multiple jobs can be created

        job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]

        jobs = []
        for job_id in job_ids:
            job = await training_service.create_training_job(
                db=test_db_session,
                tenant_id="test-tenant",
                job_id=job_id,
                config={}
            )
            jobs.append(job)

        assert len(jobs) == 3
        for i, job in enumerate(jobs):
            assert job.job_id == job_ids[i]

    @pytest.mark.asyncio
    async def test_malformed_config_handling(self, training_service, test_db_session):
        """Test handling of malformed configuration"""
        malformed_config = {
            "invalid_key": "invalid_value",
            "nested": {"data": None}
        }

        # Should not raise exception, just store the config as-is
        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="malformed-config-job",
            config=malformed_config
        )

        assert job.config == malformed_config