diff --git a/services/training/app/main.py b/services/training/app/main.py
index d83f526d..04812f3f 100644
--- a/services/training/app/main.py
+++ b/services/training/app/main.py
@@ -220,6 +220,19 @@ async def get_metrics():
return app.state.metrics_collector.get_metrics()
return {"status": "metrics not available"}
+@app.get("/health/live")
+async def liveness_check():
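+    # Liveness probe: always reports alive while the process can serve requests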
+ return {"status": "alive"}
+
+@app.get("/health/ready")
+async def readiness_check():
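+    # Readiness probe: reads app.state.ready and defaults to ready when the flag is unset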
+ ready = getattr(app.state, 'ready', True)
+ return {"status": "ready" if ready else "not ready"}
+
+@app.get("/")
+async def root():
+ return {"service": "training-service", "version": "1.0.0"}
+
if __name__ == "__main__":
uvicorn.run(
"app.main:app",
diff --git a/services/training/tests/conftest.py b/services/training/tests/conftest.py
deleted file mode 100644
index 8a0087de..00000000
--- a/services/training/tests/conftest.py
+++ /dev/null
@@ -1,1632 +0,0 @@
-# ================================================================
-# services/training/tests/conftest.py
-# ================================================================
-"""
-Test configuration and fixtures for Training Service
-Provides shared fixtures, mock data, and test utilities
-"""
-
-import pytest
-import pytest_asyncio
-import asyncio
-import pandas as pd
-import numpy as np
-import tempfile
-import os
-import json
-from datetime import datetime, timedelta
-from unittest.mock import Mock, AsyncMock, patch
-from typing import Dict, List, Any, Generator
-from pathlib import Path
-import logging
-from app.models.training import ModelTrainingLog, TrainedModel
-
-# Configure pytest-asyncio
-pytestmark = pytest.mark.asyncio
-
-# Suppress Prophet logging during tests
-logging.getLogger('prophet').setLevel(logging.WARNING)
-logging.getLogger('cmdstanpy').setLevel(logging.WARNING)
-
-
-# ================================================================
-# INTEGRATION TEST FIXTURES
-# ================================================================
-
-@pytest_asyncio.fixture
-async def integration_test_setup(
- mock_external_services,
- sample_bakery_sales_data,
- temp_model_storage
-):
- """Complete setup for integration tests"""
-
- # Patch model storage path
- with patch('app.core.config.settings.MODEL_STORAGE_PATH', str(temp_model_storage)):
-
- # Patch data fetching to use sample data
- with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
- mock_fetch.return_value = sample_bakery_sales_data
-
- yield {
- 'external_services': mock_external_services,
- 'sales_data': sample_bakery_sales_data,
- 'model_storage': temp_model_storage,
- 'mock_fetch': mock_fetch
- }
-
-
-@pytest.fixture
-def mock_messaging():
- """Mock messaging system for testing"""
- with patch('app.services.messaging.publish_job_started') as mock_started, \
- patch('app.services.messaging.publish_job_completed') as mock_completed, \
- patch('app.services.messaging.publish_job_failed') as mock_failed, \
- patch('app.services.messaging.publish_model_trained') as mock_model:
-
- yield {
- 'publish_job_started': mock_started,
- 'publish_job_completed': mock_completed,
- 'publish_job_failed': mock_failed,
- 'publish_model_trained': mock_model
- }
-
-
-# ================================================================
-# API TEST FIXTURES
-# ================================================================
-
-@pytest.fixture
-def test_app():
-    """Test FastAPI application instance"""
-    from app.main import app
-    return app
-
-@pytest.fixture
-def test_client(test_app):
-    """Create sync test client for API testing"""
-    from fastapi.testclient import TestClient
-
-    with TestClient(test_app, base_url="http://test") as client:
-        yield client
-
-@pytest.fixture
-def auth_headers():
- """Mock authentication headers"""
- return {
- "Authorization": "Bearer test_token_123",
- "X-Tenant-ID": "test_tenant_123"
- }
-
-
-# ================================================================
-# ERROR SIMULATION FIXTURES
-# ================================================================
-
-@pytest.fixture
-def failing_external_services():
- """Mock external services that fail for error testing"""
- with patch('app.external.aemet.AEMETClient') as mock_aemet, \
- patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_madrid:
-
- # Configure to raise exceptions
- mock_aemet_instance = AsyncMock()
- mock_aemet.return_value = mock_aemet_instance
- mock_aemet_instance.get_historical_weather.side_effect = Exception("AEMET API Error")
-
- mock_madrid_instance = AsyncMock()
- mock_madrid.return_value = mock_madrid_instance
- mock_madrid_instance.get_historical_traffic.side_effect = Exception("Madrid API Error")
-
- yield {
- 'aemet': mock_aemet_instance,
- 'madrid': mock_madrid_instance
- }
-
-
-@pytest.fixture
-def corrupted_sales_data(sample_bakery_sales_data):
- """Sales data with various quality issues for testing"""
- corrupted_data = sample_bakery_sales_data.copy()
-
- # Introduce missing values (20% of quantity data)
- missing_mask = np.random.random(len(corrupted_data)) < 0.2
- corrupted_data.loc[missing_mask, 'quantity'] = np.nan
-
- # Introduce extreme outliers (1% of data)
- outlier_mask = np.random.random(len(corrupted_data)) < 0.01
- corrupted_data.loc[outlier_mask, 'quantity'] *= 100
-
- # Introduce inconsistent dates (0.5% of data)
- future_mask = np.random.random(len(corrupted_data)) < 0.005
- corrupted_data.loc[future_mask, 'date'] = "2025-12-31"
-
- # Introduce negative values (0.2% of data)
- negative_mask = np.random.random(len(corrupted_data)) < 0.002
- corrupted_data.loc[negative_mask, 'quantity'] = -10
-
- return corrupted_data
-
-
-# ================================================================
-# VALIDATION TEST FIXTURES
-# ================================================================
-
-@pytest.fixture
-def insufficient_sales_data():
- """Sales data with insufficient volume for training"""
- # Only 10 days of data
- start_date = datetime(2023, 1, 1)
- dates = [start_date + timedelta(days=i) for i in range(10)]
-
- data = []
- for date in dates:
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": "Pan Integral",
- "quantity": np.random.randint(10, 50),
- "revenue": round(np.random.uniform(20, 100), 2),
- "temperature": round(np.random.uniform(10, 25), 1),
- "precipitation": 0.0,
- "is_weekend": date.weekday() >= 5,
- "is_holiday": False
- })
-
- return pd.DataFrame(data)
-
-
-@pytest.fixture
-def seasonal_product_data():
- """Data for seasonal product (Roscon Reyes) testing"""
- start_date = datetime(2023, 1, 1)
- dates = [start_date + timedelta(days=i) for i in range(365)]
-
- data = []
- for date in dates:
- # Roscon Reyes has strong seasonal pattern (Christmas specialty)
- base_qty = 5 # Very low base
-
- if date.month == 12: # December - high sales
- base_qty = 20 + (date.day - 1) * 2 # Increasing through December
- elif date.month == 1 and date.day <= 6: # Until Epiphany
- base_qty = 50
-
- # Add some noise
- quantity = max(1, int(base_qty + np.random.normal(0, base_qty * 0.2)))
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": "Roscon Reyes",
- "quantity": quantity,
- "revenue": round(quantity * 25.0, 2), # Expensive specialty item
- "temperature": round(15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi), 1),
- "precipitation": max(0, np.random.exponential(0.5)),
- "is_weekend": date.weekday() >= 5,
- "is_holiday": _is_spanish_holiday(date)
- })
-
- return pd.DataFrame(data)
-
-
-# ================================================================
-# CLEANUP FIXTURES
-# ================================================================
-
-@pytest.fixture(autouse=True)
-def cleanup_after_test():
- """Automatic cleanup after each test"""
- yield
-
- # Clean up any test files
- import tempfile
- import shutil
-
- # Clear any temporary model files
- temp_dirs = [d for d in os.listdir(tempfile.gettempdir()) if d.startswith('test_models_')]
- for temp_dir in temp_dirs:
- try:
- shutil.rmtree(os.path.join(tempfile.gettempdir(), temp_dir))
-        except OSError:
- pass
-
-
-# ================================================================
-# TEST DATA VALIDATION UTILITIES
-# ================================================================
-
-class TestDataValidator:
- """Utility class for validating test data quality"""
-
- @staticmethod
- def validate_sales_data(df: pd.DataFrame) -> Dict[str, Any]:
- """Validate sales data structure and quality"""
- required_columns = ['date', 'product', 'quantity', 'revenue']
- missing_columns = [col for col in required_columns if col not in df.columns]
-
- if missing_columns:
- return {'valid': False, 'error': f'Missing columns: {missing_columns}'}
-
- # Check data types
- try:
- pd.to_datetime(df['date'])
-        except Exception:
- return {'valid': False, 'error': 'Invalid date format'}
-
- if not pd.api.types.is_numeric_dtype(df['quantity']):
- return {'valid': False, 'error': 'Quantity must be numeric'}
-
- if not pd.api.types.is_numeric_dtype(df['revenue']):
- return {'valid': False, 'error': 'Revenue must be numeric'}
-
- # Check for negative values
- if (df['quantity'] < 0).any():
- return {'valid': False, 'error': 'Negative quantities found'}
-
- if (df['revenue'] < 0).any():
- return {'valid': False, 'error': 'Negative revenue found'}
-
- return {'valid': True, 'rows': len(df), 'products': df['product'].nunique()}
-
-
-@pytest.fixture
-def test_data_validator():
- """Test data validator utility"""
- return TestDataValidator()
-
-
-# ================================================================
-# LOGGING CONFIGURATION FOR TESTS
-# ================================================================
-
-@pytest.fixture(autouse=True)
-def configure_test_logging():
- """Configure logging for tests"""
- import logging
-
- # Reduce log level for external libraries during tests
- logging.getLogger('prophet').setLevel(logging.WARNING)
- logging.getLogger('cmdstanpy').setLevel(logging.ERROR)
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
- logging.getLogger('urllib3').setLevel(logging.WARNING)
-
- # Configure our app logging for tests
- logger = logging.getLogger('app')
- logger.setLevel(logging.INFO)
-
- yield
-
- # Reset logging after tests
- logging.getLogger().handlers.clear()
-
-
-# ================================================================
-# ENVIRONMENT SETUP
-# ================================================================
-
-@pytest.fixture(scope="session", autouse=True)
-def setup_test_environment():
- """Setup test environment variables"""
- os.environ.update({
- 'ENVIRONMENT': 'test',
- 'LOG_LEVEL': 'INFO',
- 'MODEL_STORAGE_PATH': '/tmp/test_models',
- 'MAX_TRAINING_TIME_MINUTES': '5',
- 'MIN_TRAINING_DATA_DAYS': '7',
- 'PROPHET_SEASONALITY_MODE': 'additive',
- 'ENABLE_SYNTHETIC_DATA': 'true',
- 'SKIP_EXTERNAL_API_CALLS': 'true'
- })
-
- yield
-
-    # Clean up test environment variables
- test_vars = [
- 'ENVIRONMENT', 'LOG_LEVEL', 'MODEL_STORAGE_PATH',
- 'MAX_TRAINING_TIME_MINUTES', 'MIN_TRAINING_DATA_DAYS',
- 'PROPHET_SEASONALITY_MODE', 'ENABLE_SYNTHETIC_DATA',
- 'SKIP_EXTERNAL_API_CALLS'
- ]
-
- for var in test_vars:
-        os.environ.pop(var, None)
-
-@pytest.fixture(scope="session")
-def event_loop():
- """Create an instance of the default event loop for the test session."""
- loop = asyncio.new_event_loop()
- yield loop
- loop.close()
-
-
-def pytest_configure(config):
- """Configure pytest with custom markers"""
- config.addinivalue_line(
- "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
- )
- config.addinivalue_line(
- "markers", "integration: marks tests as integration tests"
- )
- config.addinivalue_line(
- "markers", "unit: marks tests as unit tests"
- )
- config.addinivalue_line(
- "markers", "performance: marks tests as performance tests"
- )
- config.addinivalue_line(
- "markers", "external: marks tests that require external services"
- )
-
-
-def pytest_collection_modifyitems(config, items):
- """Modify test collection to add markers automatically"""
- for item in items:
- # Mark performance tests
- if "performance" in item.nodeid:
- item.add_marker(pytest.mark.performance)
- item.add_marker(pytest.mark.slow)
-
- # Mark integration tests
- if "integration" in item.nodeid:
- item.add_marker(pytest.mark.integration)
-
- # Mark end-to-end tests
- if "end_to_end" in item.nodeid:
- item.add_marker(pytest.mark.integration)
- item.add_marker(pytest.mark.external)
-
- # Mark unit tests (default for others)
- if not any(marker.name in ["integration", "performance"] for marker in item.iter_markers()):
- item.add_marker(pytest.mark.unit)
-
-
-# ================================================================
-# TEST DATABASE FIXTURES
-# ================================================================
-
-@pytest_asyncio.fixture
-async def test_db_session():
- """Create async test database session"""
- from app.core.database import database_manager
-
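-    # Yield a session from the application's async session factory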
- async with database_manager.async_session_local() as session:
- yield session
-
-@pytest_asyncio.fixture
-async def training_job_in_db(test_db_session):
-    """Create a training job in database for testing"""
-    job = ModelTrainingLog(
-        job_id="test-job-123",
-        tenant_id="test-tenant",
-        status="running",
-        progress=50,
-        current_step="Training models",
-        start_time=datetime.now(),
-        config={"include_weather": True},
-        created_at=datetime.now(),
-        updated_at=datetime.now()
-    )
-    test_db_session.add(job)
-    await test_db_session.commit()
-    await test_db_session.refresh(job)
-    return job
-
-@pytest_asyncio.fixture
-async def trained_model_in_db(test_db_session):
-    """Create a trained model in database for testing"""
-    model = TrainedModel(
-        model_id="test-model-123",
-        tenant_id="test-tenant",
-        product_name="Pan Integral",
-        model_type="prophet",
-        model_path="/tmp/test_model.pkl",
-        version=1,
-        training_samples=100,
-        features=["temperature", "humidity"],
-        hyperparameters={"seasonality_mode": "additive"},
-        training_metrics={"mae": 2.5, "mse": 8.3},
-        is_active=True,
-        created_at=datetime.now()
-    )
-    test_db_session.add(model)
-    await test_db_session.commit()
-    await test_db_session.refresh(model)
-    return model
-
-# ================================================================
-# SAMPLE DATA FIXTURES
-# ================================================================
-
-@pytest.fixture
-def sample_bakery_sales_data():
- """Generate comprehensive bakery sales data for testing"""
- # Generate 1 year of data
- start_date = datetime(2023, 1, 1)
- dates = [start_date + timedelta(days=i) for i in range(365)]
-
- # Spanish bakery products with realistic patterns
- products = [
- "Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
- "Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras",
- "Donuts", "Berlinas", "Napolitanas", "Ensaimadas"
- ]
-
- # Product-specific configurations
- product_config = {
- "Pan Integral": {"base": 80, "price": 2.80, "weekend_boost": 1.1, "seasonal": False},
- "Pan Blanco": {"base": 120, "price": 2.50, "weekend_boost": 1.2, "seasonal": False},
- "Croissant": {"base": 45, "price": 1.50, "weekend_boost": 1.4, "seasonal": False},
- "Magdalenas": {"base": 30, "price": 1.20, "weekend_boost": 1.1, "seasonal": False},
- "Empanadas": {"base": 25, "price": 3.50, "weekend_boost": 0.9, "seasonal": False},
- "Tarta Chocolate": {"base": 15, "price": 18.00, "weekend_boost": 1.6, "seasonal": False},
- "Roscon Reyes": {"base": 8, "price": 25.00, "weekend_boost": 1.0, "seasonal": True},
- "Palmeras": {"base": 12, "price": 1.80, "weekend_boost": 1.2, "seasonal": False},
- "Donuts": {"base": 20, "price": 1.40, "weekend_boost": 1.3, "seasonal": False},
- "Berlinas": {"base": 18, "price": 1.60, "weekend_boost": 1.2, "seasonal": False},
- "Napolitanas": {"base": 22, "price": 1.70, "weekend_boost": 1.1, "seasonal": False},
- "Ensaimadas": {"base": 15, "price": 2.20, "weekend_boost": 1.0, "seasonal": False}
- }
-
- data = []
-
- for date in dates:
- # Calculate date-specific factors
- day_of_year = date.timetuple().tm_yday
- is_weekend = date.weekday() >= 5
- is_holiday = _is_spanish_holiday(date)
-
- # Madrid weather simulation
- temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi) + np.random.normal(0, 3)
- precip = max(0, np.random.exponential(0.8))
-
- for product in products:
- config = product_config[product]
-
- # Base quantity
- base_qty = config["base"]
-
- # Apply weekend boost
- if is_weekend:
- base_qty *= config["weekend_boost"]
-
- # Apply holiday boost
- if is_holiday:
- base_qty *= 1.3
-
- # Seasonal products (like Roscon Reyes for Christmas)
- if config["seasonal"] and product == "Roscon Reyes":
- if date.month == 12:
- # Exponential increase through December
- base_qty *= (1 + (date.day - 1) / 5)
- elif date.month == 1 and date.day <= 6:
- # High demand until Epiphany (Jan 6)
- base_qty *= 3
- else:
- # Very low demand rest of year
- base_qty *= 0.1
-
- # Weather effects
- if temp > 30: # Very hot days
- if product in ["Pan Integral", "Pan Blanco"]:
- base_qty *= 0.7 # Less bread
- elif product in ["Donuts", "Berlinas"]:
- base_qty *= 0.8 # Less fried items
- elif temp < 5: # Cold days
- base_qty *= 1.15 # More baked goods
-
- # Add realistic noise and ensure minimum of 1
- quantity = max(1, int(base_qty + np.random.normal(0, base_qty * 0.12)))
- revenue = round(quantity * config["price"], 2)
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": product,
- "quantity": quantity,
- "revenue": revenue,
- "temperature": round(temp, 1),
- "precipitation": round(precip, 2),
- "is_weekend": is_weekend,
- "is_holiday": is_holiday
- })
-
- return pd.DataFrame(data)
-
-
-@pytest.fixture
-def sample_weather_data():
- """Generate realistic Madrid weather data"""
- start_date = datetime(2023, 1, 1)
- weather_data = []
-
- for i in range(365):
- date = start_date + timedelta(days=i)
- day_of_year = date.timetuple().tm_yday
-
- # Madrid climate simulation
- base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
-
- # Seasonal humidity patterns
- base_humidity = 50 + 20 * np.sin((day_of_year / 365) * 2 * np.pi + np.pi)
-
- weather_data.append({
- "date": date,
- "temperature": round(base_temp + np.random.normal(0, 4), 1),
- "precipitation": max(0, np.random.exponential(1.2)),
- "humidity": np.random.uniform(25, 75),
- "wind_speed": np.random.uniform(3, 20),
- "pressure": np.random.uniform(995, 1025),
- "description": np.random.choice([
- "Soleado", "Parcialmente nublado", "Nublado",
- "Lluvia ligera", "Despejado", "Variable"
- ]),
- "source": "aemet_test"
- })
-
- return weather_data
-
-
-@pytest.fixture
-def sample_traffic_data():
- """Generate realistic Madrid traffic data"""
- start_date = datetime(2023, 1, 1)
- traffic_data = []
-
- for i in range(365):
- date = start_date + timedelta(days=i)
-
- # Generate multiple measurements per day
- for hour in range(6, 22, 2): # Every 2 hours from 6 AM to 10 PM
- measurement_time = date.replace(hour=hour)
-
- # Madrid traffic patterns
- if hour in [7, 8, 9, 18, 19, 20]: # Rush hours
- volume = np.random.randint(1200, 2000)
- congestion = "high"
- speed = np.random.randint(10, 25)
- occupation = np.random.randint(60, 90)
- elif hour in [12, 13, 14]: # Lunch time
- volume = np.random.randint(800, 1200)
- congestion = "medium"
- speed = np.random.randint(20, 35)
- occupation = np.random.randint(40, 70)
- else: # Off-peak
- volume = np.random.randint(300, 800)
- congestion = "low"
- speed = np.random.randint(30, 50)
- occupation = np.random.randint(15, 50)
-
- # Weekend adjustment
- if date.weekday() >= 5:
- volume = int(volume * 0.8) # Less traffic on weekends
- speed = min(50, int(speed * 1.2)) # Faster speeds
-
- traffic_data.append({
- "date": measurement_time,
- "traffic_volume": volume,
- "occupation_percentage": occupation,
- "load_percentage": min(95, occupation + np.random.randint(5, 15)),
- "average_speed": speed,
- "congestion_level": congestion,
- "pedestrian_count": np.random.randint(100, 800),
- "measurement_point_id": "MADRID_TEST_001",
- "measurement_point_name": "Plaza Mayor",
- "road_type": "URB",
- "source": "madrid_opendata_test"
- })
-
- return traffic_data
-
-
-# ================================================================
-# MOCK SERVICES FIXTURES
-# ================================================================
-
-@pytest_asyncio.fixture
-async def mock_aemet_client(sample_weather_data):
- """Mock AEMET weather API client"""
- with patch('app.external.aemet.AEMETClient') as mock_class:
- mock_instance = AsyncMock()
- mock_class.return_value = mock_instance
-
- # Configure mock responses
- mock_instance.get_historical_weather.return_value = sample_weather_data
- mock_instance.get_current_weather.return_value = sample_weather_data[-1]
- mock_instance.get_weather_forecast.return_value = sample_weather_data[-7:]
-
- yield mock_instance
-
-
-@pytest_asyncio.fixture
-async def mock_madrid_client(sample_traffic_data):
- """Mock Madrid OpenData API client"""
- with patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_class:
- mock_instance = AsyncMock()
- mock_class.return_value = mock_instance
-
- # Configure mock responses
- mock_instance.get_historical_traffic.return_value = sample_traffic_data
- mock_instance.get_current_traffic.return_value = sample_traffic_data[-1]
-
- yield mock_instance
-
-
-@pytest_asyncio.fixture
-async def mock_external_services(mock_aemet_client, mock_madrid_client):
- """Combined mock for all external services"""
- return {
- 'aemet': mock_aemet_client,
- 'madrid': mock_madrid_client
- }
-
-
-# ================================================================
-# ML COMPONENT FIXTURES
-# ================================================================
-
-@pytest.fixture
-def mock_ml_trainer():
- """Mock ML trainer for testing"""
- with patch('app.ml.trainer.BakeryMLTrainer') as mock_class:
- mock_instance = AsyncMock()
- mock_class.return_value = mock_instance
-
- # Configure successful training responses
- mock_instance.train_single_product.return_value = {
- "status": "completed",
- "model_id": "test_model_123",
- "metrics": {
- "mape": 25.5,
- "rmse": 12.3,
- "mae": 8.7,
- "r2_score": 0.85
- },
- "training_duration": 45.2,
- "data_points_used": 365
- }
-
- mock_instance.train_tenant_models.return_value = [
- {
- "product_name": "Pan Integral",
- "model_id": "model_pan_integral_123",
- "metrics": {"mape": 22.1, "rmse": 10.5, "mae": 7.8},
- "training_completed": True
- },
- {
- "product_name": "Croissant",
- "model_id": "model_croissant_456",
- "metrics": {"mape": 28.3, "rmse": 8.9, "mae": 6.2},
- "training_completed": True
- }
- ]
-
- yield mock_instance
-
-
-@pytest.fixture
-def mock_data_processor():
- """Mock data processor for testing"""
- with patch('app.ml.data_processor.BakeryDataProcessor') as mock_class:
- mock_instance = AsyncMock()
- mock_class.return_value = mock_instance
-
- # Configure mock responses
- mock_instance.validate_data_quality.return_value = {
- "is_valid": True,
- "data_points": 1000,
- "missing_percentage": 2.5,
- "issues": []
- }
-
- mock_instance.prepare_training_data.return_value = pd.DataFrame({
- "ds": pd.date_range("2023-01-01", periods=365),
- "y": np.random.randint(10, 100, 365),
- "temperature": np.random.uniform(0, 35, 365),
- "traffic_volume": np.random.randint(100, 2000, 365)
- })
-
- yield mock_instance
-
-@pytest.fixture
-def mock_data_service():
- """Mock data service for testing"""
- from unittest.mock import Mock, AsyncMock
-
- mock_service = Mock()
- mock_service.get_sales_data = AsyncMock(return_value=[
- {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45},
- {"date": "2024-01-02", "product_name": "Pan Integral", "quantity": 38}
- ])
- mock_service.get_weather_data = AsyncMock(return_value=[
- {"date": "2024-01-01", "temperature": 20.5, "humidity": 65}
- ])
- mock_service.get_traffic_data = AsyncMock(return_value=[
- {"date": "2024-01-01", "traffic_index": 0.7}
- ])
-
- return mock_service
-
-@pytest.fixture
-def mock_prophet_manager():
- """Mock Prophet manager for testing"""
- with patch('app.ml.prophet_manager.BakeryProphetManager') as mock_class:
- mock_instance = AsyncMock()
- mock_class.return_value = mock_instance
-
- # Configure mock responses
- mock_instance.train_model.return_value = {
- "model": Mock(), # Mock Prophet model
- "metrics": {
- "mape": 23.7,
- "rmse": 11.2,
- "mae": 8.1
- },
- "cross_validation": {
- "cv_mape_mean": 25.1,
- "cv_mape_std": 3.2
- }
- }
-
- mock_instance.generate_predictions.return_value = pd.DataFrame({
- "ds": pd.date_range("2024-01-01", periods=30),
- "yhat": np.random.uniform(20, 80, 30),
- "yhat_lower": np.random.uniform(10, 60, 30),
- "yhat_upper": np.random.uniform(30, 100, 30)
- })
-
- yield mock_instance
-
-
-# ================================================================
-# UTILITY FIXTURES
-# ================================================================
-
-@pytest.fixture
-def temp_model_storage():
- """Temporary directory for model storage during tests"""
- with tempfile.TemporaryDirectory() as temp_dir:
- yield Path(temp_dir)
-
-
-@pytest.fixture
-def test_config():
- """Test configuration settings"""
- return {
- "MODEL_STORAGE_PATH": "/tmp/test_models",
- "MAX_TRAINING_TIME_MINUTES": 5,
- "MIN_TRAINING_DATA_DAYS": 7,
- "PROPHET_SEASONALITY_MODE": "additive",
- "INCLUDE_SPANISH_HOLIDAYS": True,
- "ENABLE_SYNTHETIC_DATA": True
- }
-
-
-@pytest.fixture
-def sample_training_request():
- """Sample training request for API tests"""
- return {
- "products": ["Pan Integral", "Croissant"],
- "include_weather": True,
- "include_traffic": True,
- "config": {
- "seasonality_mode": "additive",
- "changepoint_prior_scale": 0.05,
- "seasonality_prior_scale": 10.0,
- "validation_enabled": True
- }
- }
-
-
-@pytest.fixture
-def sample_single_product_request():
- """Sample single product training request"""
- return {
- "product_name": "Pan Integral",
- "include_weather": True,
- "include_traffic": False,
- "config": {
- "seasonality_mode": "multiplicative",
- "include_holidays": True,
- "holiday_prior_scale": 15.0
- }
- }
-
-
-# ================================================================
-# HELPER FUNCTIONS
-# ================================================================
-
-def _is_spanish_holiday(date: datetime) -> bool:
- """Check if date is a Spanish holiday"""
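-    # Fixed-date national holidays only; movable feasts such as Easter are not covered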
- spanish_holidays = [
- (1, 1), # Año Nuevo
- (1, 6), # Reyes Magos
- (5, 1), # Día del Trabajo
- (8, 15), # Asunción de la Virgen
- (10, 12), # Fiesta Nacional de España
- (11, 1), # Todos los Santos
- (12, 6), # Día de la Constitución
- (12, 8), # Inmaculada Concepción
- (12, 25), # Navidad
- ]
- return (date.month, date.day) in spanish_holidays
-
-
-@pytest.fixture
-def spanish_holidays_2023():
- """List of Spanish holidays for 2023"""
- holidays = []
- for month, day in [
- (1, 1), (1, 6), (5, 1), (8, 15), (10, 12),
- (11, 1), (12, 6), (12, 8), (12, 25)
- ]:
- holidays.append(datetime(2023, month, day))
- return holidays
-
-
-# ================================================================
-# PERFORMANCE TESTING FIXTURES
-# ================================================================
-
-@pytest.fixture
-def large_dataset_for_performance():
- """Generate large dataset for performance testing"""
- # Generate 2 years of data with 15 products
- start_date = datetime(2022, 1, 1)
- end_date = datetime(2024, 1, 1)
- date_range = pd.date_range(start=start_date, end=end_date, freq='D')
-
- products = [
- "Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
- "Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras",
- "Donuts", "Berlinas", "Napolitanas", "Ensaimadas",
- "Baguette", "Pan de Molde", "Bizcocho"
- ]
-
- data = []
- for date in date_range:
- for product in products:
- # Realistic sales with patterns
- base_quantity = np.random.randint(5, 150)
-
- # Seasonal patterns
- if date.month in [12, 1]: # Winter/Holiday season
- base_quantity *= 1.4
- elif date.month in [6, 7, 8]: # Summer
- base_quantity *= 0.8
-
- # Weekly patterns
- if date.weekday() >= 5: # Weekends
- base_quantity *= 1.2
- elif date.weekday() == 0: # Monday
- base_quantity *= 0.7
-
- # Add noise
- quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.1)))
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": product,
- "quantity": quantity,
- "revenue": round(quantity * np.random.uniform(1.5, 8.0), 2),
- "temperature": round(15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi) + np.random.normal(0, 3), 1),
- "precipitation": max(0, np.random.exponential(0.8)),
- "is_weekend": date.weekday() >= 5,
- "is_holiday": _is_spanish_holiday(date)
- })
-
- return pd.DataFrame(data)
-
-
-@pytest.fixture
-def memory_monitor():
- """Memory monitoring utility for performance tests"""
- import psutil
- import gc
-
- class MemoryMonitor:
- def __init__(self):
- self.process = psutil.Process()
- self.snapshots = []
-
- def snapshot(self, label: str):
- gc.collect() # Force garbage collection
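-            # Resident set size converted from bytes to MiB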
- memory_mb = self.process.memory_info().rss / 1024 / 1024
- self.snapshots.append({
- 'label': label,
- 'memory_mb': memory_mb,
- 'timestamp': datetime.now()
- })
- return memory_mb
-
- def get_peak_usage(self):
- if not self.snapshots:
- return 0
- return max(s['memory_mb'] for s in self.snapshots)
-
- def get_usage_increase(self):
- if len(self.snapshots) < 2:
- return 0
- return self.snapshots[-1]['memory_mb'] - self.snapshots[0]['memory_mb']
-
- def report(self):
- print("\n=== Memory Usage Report ===")
- for snapshot in self.snapshots:
- print(f"{snapshot['label']}: {snapshot['memory_mb']:.2f} MB")
- print(f"Peak Usage: {self.get_peak_usage():.2f} MB")
- print(f"Total Increase: {self.get_usage_increase():.2f} MB")
-
- return MemoryMonitor()
-
-
-@pytest.fixture
-def timing_monitor():
- """Timing monitoring utility for performance tests"""
- import time
-
- class TimingMonitor:
- def __init__(self):
- self.timings = []
- self.start_time = None
-
- def start(self, label: str):
- self.start_time = time.time()
- self.current_label = label
-
- def stop(self):
- if self.start_time is None:
- return 0
-
- duration = time.time() - self.start_time
- self.timings.append({
- 'label': self.current_label,
- 'duration': duration
- })
- self.start_time = None
- return duration
-
- def get_total_time(self):
- return sum(t['duration'] for t in self.timings)
-
- def report(self):
- print("\n=== Timing Report ===")
- for timing in self.timings:
- print(f"{timing['label']}: {timing['duration']:.2f}s")
- print(f"Total Time: {self.get_total_time():.2f}s")
-
- return TimingMonitor()
-
-
-# ================================================================
-# ADDITIONAL FIXTURES FOR COMPREHENSIVE TESTING
-# ================================================================
-
-@pytest.fixture
-def mock_job_scheduler():
- """Mock job scheduler for testing"""
- with patch('app.services.job_scheduler.JobScheduler') as mock_scheduler:
- mock_instance = Mock()
- mock_scheduler.return_value = mock_instance
-
- mock_instance.schedule_job.return_value = "scheduled_job_123"
- mock_instance.cancel_job.return_value = True
- mock_instance.get_job_status.return_value = "running"
-
- yield mock_instance
-
-
-@pytest.fixture
-def sample_model_metadata():
- """Sample model metadata for testing"""
- return {
- "model_id": "test_model_123",
- "tenant_id": "test_tenant",
- "product_name": "Pan Integral",
- "model_type": "prophet",
- "training_date": datetime.now().isoformat(),
- "data_points_used": 365,
- "features_used": ["temperature", "is_weekend", "is_holiday"],
- "metrics": {
- "mape": 23.5,
- "rmse": 12.3,
- "mae": 8.7,
- "r2_score": 0.85
- },
- "hyperparameters": {
- "seasonality_mode": "additive",
- "changepoint_prior_scale": 0.05,
- "seasonality_prior_scale": 10.0
- },
- "version": "1.0",
- "status": "active"
- }
-
-
-@pytest.fixture
-def training_progress_states():
- """Different training progress states for testing"""
- return [
- {"status": "pending", "progress": 0, "current_step": "Initializing training job"},
- {"status": "running", "progress": 10, "current_step": "Fetching sales data"},
- {"status": "running", "progress": 25, "current_step": "Processing weather data"},
- {"status": "running", "progress": 40, "current_step": "Processing traffic data"},
- {"status": "running", "progress": 55, "current_step": "Engineering features"},
- {"status": "running", "progress": 70, "current_step": "Training Pan Integral model"},
- {"status": "running", "progress": 85, "current_step": "Validating model performance"},
- {"status": "running", "progress": 95, "current_step": "Saving model artifacts"},
- {"status": "completed", "progress": 100, "current_step": "Training completed successfully"}
- ]
-
-
-@pytest.fixture
-def error_scenarios():
- """Different error scenarios for testing"""
- return {
- "insufficient_data": {
- "error_type": "DataError",
- "error_message": "Insufficient training data: only 15 days available, minimum 30 required",
- "error_code": "INSUFFICIENT_DATA"
- },
- "external_api_failure": {
- "error_type": "ExternalAPIError",
- "error_message": "Failed to fetch weather data from AEMET API",
- "error_code": "WEATHER_API_ERROR"
- },
- "model_training_failure": {
- "error_type": "ModelTrainingError",
- "error_message": "Prophet model training failed: unable to fit data",
- "error_code": "MODEL_TRAINING_FAILED"
- },
- "data_quality_error": {
- "error_type": "DataQualityError",
- "error_message": "Data quality issues detected: 45% missing values in quantity column",
- "error_code": "DATA_QUALITY_POOR"
- }
- }
-
-
-@pytest.fixture
-def performance_benchmarks():
- """Performance benchmarks for testing"""
- return {
- "single_product_training": {
- "max_duration_seconds": 120,
- "max_memory_mb": 500,
- "min_accuracy_mape": 50
- },
- "multi_product_training": {
- "max_duration_seconds": 300,
- "max_memory_mb": 1000,
- "min_accuracy_mape": 55
- },
- "data_processing": {
- "max_throughput_rows_per_second": 1000,
- "max_memory_per_1k_rows_mb": 10
- },
- "concurrent_jobs": {
- "max_concurrent_jobs": 5,
- "max_queue_time_seconds": 30
- }
- }
-
-
-@pytest.fixture
-def mock_model_storage():
- """Mock model storage system for testing"""
- storage = {}
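-    # In-memory store keyed by model_id, standing in for real model artifact storage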
-
- class MockModelStorage:
- def save_model(self, model_id: str, model_data: Any, metadata: Dict[str, Any]):
- storage[model_id] = {
- "model_data": model_data,
- "metadata": metadata,
- "saved_at": datetime.now()
- }
- return f"/models/{model_id}.pkl"
-
- def load_model(self, model_id: str):
- if model_id in storage:
- return storage[model_id]["model_data"]
- raise FileNotFoundError(f"Model {model_id} not found")
-
- def get_metadata(self, model_id: str):
- if model_id in storage:
- return storage[model_id]["metadata"]
- raise FileNotFoundError(f"Model {model_id} not found")
-
- def delete_model(self, model_id: str):
- if model_id in storage:
- del storage[model_id]
- return True
- return False
-
- def list_models(self, tenant_id: str = None):
- models = []
- for model_id, data in storage.items():
- if tenant_id is None or data["metadata"].get("tenant_id") == tenant_id:
- models.append({
- "model_id": model_id,
- "metadata": data["metadata"],
- "saved_at": data["saved_at"]
- })
- return models
-
- return MockModelStorage()
-
-
-@pytest.fixture
-def real_world_scenarios():
- """Real-world bakery scenarios for testing"""
- return {
- "holiday_rush": {
- "description": "Christmas season with high demand for seasonal products",
- "date_range": ("2023-12-15", "2023-12-31"),
- "expected_patterns": {
- "Roscon Reyes": {"multiplier": 5.0, "trend": "increasing"},
- "Pan Integral": {"multiplier": 1.3, "trend": "stable"},
- "Tarta Chocolate": {"multiplier": 2.0, "trend": "increasing"}
- }
- },
- "summer_slowdown": {
- "description": "Summer period with generally lower sales",
- "date_range": ("2023-07-01", "2023-08-31"),
- "expected_patterns": {
- "Pan Integral": {"multiplier": 0.8, "trend": "decreasing"},
- "Croissant": {"multiplier": 0.9, "trend": "stable"},
- "Cold_drinks": {"multiplier": 1.5, "trend": "increasing"}
- }
- },
- "weekend_patterns": {
- "description": "Weekend shopping patterns",
- "expected_patterns": {
- "weekend_boost": 1.2,
- "peak_hours": ["10:00", "11:00", "18:00", "19:00"],
- "popular_products": ["Croissant", "Palmeras", "Tarta Chocolate"]
- }
- },
- "weather_impact": {
- "description": "Weather impact on sales",
- "scenarios": {
- "rainy_day": {"bread_sales": 1.1, "pastry_sales": 0.9},
- "hot_day": {"bread_sales": 0.8, "cold_items": 1.3},
- "cold_day": {"bread_sales": 1.2, "hot_items": 1.4}
- }
- }
- }
-
-
-@pytest.fixture
-def data_quality_test_cases():
- """Various data quality test cases"""
- return {
- "missing_values": {
- "quantity_missing_5pct": 0.05,
- "quantity_missing_20pct": 0.20,
- "quantity_missing_50pct": 0.50,
- "revenue_missing_10pct": 0.10
- },
- "outliers": {
- "extreme_high": 100, # 100x normal values
- "extreme_low": 0.01, # Near-zero values
- "negative_values": -1,
- "outlier_percentage": 0.01
- },
- "inconsistencies": {
- "future_dates": ["2025-12-31", "2026-01-01"],
- "invalid_dates": ["2023-13-01", "2023-02-30"],
- "mismatched_revenue": True, # Revenue doesn't match quantity * price
- "duplicate_records": True
- },
- "insufficient_data": {
- "too_few_days": 10,
- "too_few_products": 1,
- "sporadic_data": 0.3 # Only 30% of expected data points
- }
- }
-
-
-@pytest.fixture
-def api_test_scenarios():
- """API testing scenarios"""
- return {
- "authentication": {
- "valid_token": "Bearer valid_test_token_123",
- "invalid_token": "Bearer invalid_token",
- "expired_token": "Bearer expired_token_456",
- "missing_token": None
- },
- "request_validation": {
- "valid_request": {
- "products": ["Pan Integral"],
- "include_weather": True,
- "include_traffic": True,
- "config": {"seasonality_mode": "additive"}
- },
- "invalid_products": {
- "products": [], # Empty products list
- "include_weather": True
- },
- "invalid_config": {
- "products": ["Pan Integral"],
- "config": {"seasonality_mode": "invalid_mode"}
- },
- "missing_required_fields": {
- "include_weather": True # Missing products
- }
- },
- "rate_limiting": {
- "max_requests_per_minute": 60,
- "burst_requests": 100
- }
- }
-
-
-@pytest.fixture
-def integration_test_dependencies():
- """Dependencies for integration testing"""
-
- class IntegrationDependencies:
- def __init__(self):
- self.external_services = {}
- self.databases = {}
- self.message_queues = {}
- self.storage_systems = {}
-
- def register_external_service(self, name: str, mock_instance):
- self.external_services[name] = mock_instance
-
- def register_database(self, name: str, mock_session):
- self.databases[name] = mock_session
-
- def register_message_queue(self, name: str, mock_queue):
- self.message_queues[name] = mock_queue
-
- def register_storage(self, name: str, mock_storage):
- self.storage_systems[name] = mock_storage
-
- def get_service(self, name: str):
- return self.external_services.get(name)
-
- def get_database(self, name: str):
- return self.databases.get(name)
-
- def are_all_services_healthy(self):
- # Mock health check for all registered services
- return len(self.external_services) > 0
-
- return IntegrationDependencies()
-
-
-@pytest.fixture
-def load_test_configuration():
- """Configuration for load testing"""
- return {
- "concurrent_users": {
- "light_load": 5,
- "medium_load": 15,
- "heavy_load": 30,
- "stress_load": 50
- },
- "test_duration": {
- "quick_test": 60, # 1 minute
- "standard_test": 300, # 5 minutes
- "extended_test": 900 # 15 minutes
- },
- "request_patterns": {
- "constant_rate": "steady",
- "ramp_up": "increasing",
- "spike": "burst",
- "random": "variable"
- },
- "success_criteria": {
- "min_success_rate": 0.95,
- "max_response_time": 30.0, # seconds
- "max_error_rate": 0.05
- }
- }
-
-
-@pytest.fixture
-def mock_notification_system():
- """Mock notification system for testing"""
- notifications_sent = []
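-    # Shared list captured by the closure; tests can inspect it via get_notifications()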
-
- class MockNotificationSystem:
- def send_training_started(self, tenant_id: str, job_id: str, products: List[str]):
- notification = {
- "type": "training_started",
- "tenant_id": tenant_id,
- "job_id": job_id,
- "products": products,
- "timestamp": datetime.now()
- }
- notifications_sent.append(notification)
- return notification
-
- def send_training_completed(self, tenant_id: str, job_id: str, results: Dict[str, Any]):
- notification = {
- "type": "training_completed",
- "tenant_id": tenant_id,
- "job_id": job_id,
- "results": results,
- "timestamp": datetime.now()
- }
- notifications_sent.append(notification)
- return notification
-
- def send_training_failed(self, tenant_id: str, job_id: str, error: str):
- notification = {
- "type": "training_failed",
- "tenant_id": tenant_id,
- "job_id": job_id,
- "error": error,
- "timestamp": datetime.now()
- }
- notifications_sent.append(notification)
- return notification
-
- def get_notifications(self, tenant_id: str = None):
- if tenant_id:
- return [n for n in notifications_sent if n["tenant_id"] == tenant_id]
- return notifications_sent
-
- def clear_notifications(self):
- notifications_sent.clear()
-
- return MockNotificationSystem()
-
-
-@pytest.fixture
-def test_metrics_collector():
- """Test metrics collector for monitoring test performance"""
-    import time
-
- class TestMetricsCollector:
- def __init__(self):
- self.start_times = {}
- self.counters = {}
- self.gauges = {}
- self.histograms = {}
-
- def start_timer(self, metric_name: str):
- self.start_times[metric_name] = time.time()
-
- def end_timer(self, metric_name: str):
- if metric_name in self.start_times:
- duration = time.time() - self.start_times[metric_name]
- if metric_name not in self.histograms:
- self.histograms[metric_name] = []
- self.histograms[metric_name].append(duration)
- del self.start_times[metric_name]
- return duration
- return 0
-
- def increment_counter(self, counter_name: str, value: int = 1):
- self.counters[counter_name] = self.counters.get(counter_name, 0) + value
-
- def set_gauge(self, gauge_name: str, value: float):
- self.gauges[gauge_name] = value
-
- def get_counter(self, counter_name: str):
- return self.counters.get(counter_name, 0)
-
- def get_gauge(self, gauge_name: str):
- return self.gauges.get(gauge_name, 0)
-
- def get_histogram_stats(self, histogram_name: str):
- if histogram_name not in self.histograms:
- return {}
-
- values = self.histograms[histogram_name]
- return {
- "count": len(values),
- "min": min(values) if values else 0,
- "max": max(values) if values else 0,
- "avg": sum(values) / len(values) if values else 0,
- "p50": sorted(values)[len(values)//2] if values else 0,
- "p95": sorted(values)[int(len(values)*0.95)] if values else 0,
- "p99": sorted(values)[int(len(values)*0.99)] if values else 0
- }
-
- def get_all_metrics(self):
- return {
- "counters": self.counters,
- "gauges": self.gauges,
- "histograms": {name: self.get_histogram_stats(name) for name in self.histograms}
- }
-
- def reset(self):
- self.start_times.clear()
- self.counters.clear()
- self.gauges.clear()
- self.histograms.clear()
-
- return TestMetricsCollector()
-
-
-# ================================================================
-# PYTEST PLUGINS AND HOOKS
-# ================================================================
-
-def pytest_runtest_setup(item):
- """Setup before each test"""
- # Add any pre-test setup logic here
- pass
-
-
-def pytest_runtest_teardown(item, nextitem):
- """Teardown after each test"""
- # Add any post-test cleanup logic here
- import gc
- gc.collect() # Force garbage collection after each test
-
-
-def pytest_sessionstart(session):
- """Called after the Session object has been created"""
- print("\n" + "="*80)
- print("TRAINING SERVICE TEST SESSION STARTING")
- print("="*80)
-
-
-def pytest_sessionfinish(session, exitstatus):
- """Called after whole test run finished"""
- print("\n" + "="*80)
- print("TRAINING SERVICE TEST SESSION FINISHED")
- print(f"Exit Status: {exitstatus}")
- print("="*80)
-
-
-# ================================================================
-# FINAL CONFIGURATION
-# ================================================================
-
-# Ignore numpy floating-point warnings during tests
-np.seterr(all='ignore')
-
-# Configure pandas display options for testing
-pd.set_option('display.max_columns', None)
-pd.set_option('display.width', None)
-pd.set_option('display.max_colwidth', 50)
-
-# Set random seeds for reproducible tests
-np.random.seed(42)
-import random
-random.seed(42)
\ No newline at end of file
diff --git a/services/training/tests/run_tests.py b/services/training/tests/run_tests.py
deleted file mode 100644
index f99ac527..00000000
--- a/services/training/tests/run_tests.py
+++ /dev/null
@@ -1,673 +0,0 @@
-# ================================================================
-# services/training/tests/run_tests.py
-# ================================================================
-"""
-Main test runner script for Training Service
-Executes comprehensive test suite and generates reports
-"""
-
-import os
-import sys
-import asyncio
-import subprocess
-import json
-import time
-from datetime import datetime
-from pathlib import Path
-from typing import Dict, List, Any
-import logging
-
-# Setup logging
-logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-
-class TrainingTestRunner:
- """Main test runner for training service"""
-
- def __init__(self):
- self.test_dir = Path(__file__).parent
- self.results_dir = self.test_dir / "results"
- self.results_dir.mkdir(exist_ok=True)
-
- # Test configuration
- self.test_suites = {
- "unit": {
- "files": ["test_api.py", "test_ml.py", "test_service.py"],
- "description": "Unit tests for individual components",
- "timeout": 300 # 5 minutes
- },
- "integration": {
- "files": ["test_ml_pipeline_integration.py"],
- "description": "Integration tests for ML pipeline with external data",
- "timeout": 600 # 10 minutes
- },
- "performance": {
- "files": ["test_performance.py"],
- "description": "Performance and load testing",
- "timeout": 900 # 15 minutes
- },
- "end_to_end": {
- "files": ["test_end_to_end.py"],
- "description": "End-to-end workflow testing",
- "timeout": 800 # 13 minutes
- }
- }
-
- self.test_results = {}
-
- async def setup_test_environment(self):
- """Setup test environment and dependencies"""
- logger.info("Setting up test environment...")
-
- # Check if we're running in Docker
- if os.path.exists("/.dockerenv"):
- logger.info("Running in Docker environment")
- else:
- logger.info("Running in local environment")
-
- # Verify required files exist
- required_files = [
- "conftest.py",
- "test_ml_pipeline_integration.py",
- "test_performance.py"
- ]
-
- for file in required_files:
- file_path = self.test_dir / file
- if not file_path.exists():
- logger.warning(f"Required test file missing: {file}")
-
- # Create test data if needed
- await self.create_test_data()
-
- # Verify external services (mock or real)
- await self.verify_external_services()
-
- async def create_test_data(self):
- """Create or verify test data exists"""
- logger.info("Creating/verifying test data...")
-
- test_data_dir = self.test_dir / "fixtures" / "test_data"
- test_data_dir.mkdir(parents=True, exist_ok=True)
-
- # Create bakery sales sample if it doesn't exist
- sales_file = test_data_dir / "bakery_sales_sample.csv"
- if not sales_file.exists():
- logger.info("Creating sample sales data...")
- await self.generate_sample_sales_data(sales_file)
-
- # Create weather data sample
- weather_file = test_data_dir / "madrid_weather_sample.json"
- if not weather_file.exists():
- logger.info("Creating sample weather data...")
- await self.generate_sample_weather_data(weather_file)
-
- # Create traffic data sample
- traffic_file = test_data_dir / "madrid_traffic_sample.json"
- if not traffic_file.exists():
- logger.info("Creating sample traffic data...")
- await self.generate_sample_traffic_data(traffic_file)
-
- async def generate_sample_sales_data(self, file_path: Path):
- """Generate sample sales data for testing"""
- import pandas as pd
- import numpy as np
- from datetime import datetime, timedelta
-
- # Generate 6 months of sample data
- start_date = datetime(2023, 6, 1)
- dates = [start_date + timedelta(days=i) for i in range(180)]
-
- products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]
-
- data = []
- for date in dates:
- for product in products:
- base_quantity = np.random.randint(10, 100)
-
- # Weekend boost
- if date.weekday() >= 5:
- base_quantity *= 1.2
-
- # Seasonal variation
- temp = 15 + 10 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": product,
- "quantity": int(base_quantity),
- "revenue": round(base_quantity * np.random.uniform(2.5, 8.0), 2),
- "temperature": round(temp + np.random.normal(0, 3), 1),
- "precipitation": max(0, np.random.exponential(0.5)),
- "is_weekend": date.weekday() >= 5,
- "is_holiday": False
- })
-
- df = pd.DataFrame(data)
- df.to_csv(file_path, index=False)
- logger.info(f"Created sample sales data: {len(df)} records")
-
- async def generate_sample_weather_data(self, file_path: Path):
- """Generate sample weather data"""
- import json
- from datetime import datetime, timedelta
- import numpy as np
-
- start_date = datetime(2023, 6, 1)
- weather_data = []
-
- for i in range(180):
- date = start_date + timedelta(days=i)
- day_of_year = date.timetuple().tm_yday
- base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
-
- weather_data.append({
- "date": date.isoformat(),
- "temperature": round(base_temp + np.random.normal(0, 5), 1),
- "precipitation": max(0, np.random.exponential(1.0)),
- "humidity": np.random.uniform(30, 80),
- "wind_speed": np.random.uniform(5, 25),
- "pressure": np.random.uniform(1000, 1025),
- "description": np.random.choice(["Soleado", "Nuboso", "Lluvioso"]),
- "source": "aemet_test"
- })
-
- with open(file_path, 'w') as f:
- json.dump(weather_data, f, indent=2)
- logger.info(f"Created sample weather data: {len(weather_data)} records")
-
- async def generate_sample_traffic_data(self, file_path: Path):
- """Generate sample traffic data"""
- import json
- from datetime import datetime, timedelta
- import numpy as np
-
- start_date = datetime(2023, 6, 1)
- traffic_data = []
-
- for i in range(180):
- date = start_date + timedelta(days=i)
-
- for hour in [8, 12, 18]: # Three measurements per day
- measurement_time = date.replace(hour=hour)
-
- if hour in [8, 18]: # Rush hours
- volume = np.random.randint(800, 1500)
- congestion = "high"
- else: # Lunch time
- volume = np.random.randint(400, 800)
- congestion = "medium"
-
- traffic_data.append({
- "date": measurement_time.isoformat(),
- "traffic_volume": volume,
- "occupation_percentage": np.random.randint(10, 90),
- "load_percentage": np.random.randint(20, 95),
- "average_speed": np.random.randint(15, 50),
- "congestion_level": congestion,
- "pedestrian_count": np.random.randint(50, 500),
- "measurement_point_id": "TEST_POINT_001",
- "measurement_point_name": "Plaza Mayor",
- "road_type": "URB",
- "source": "madrid_opendata_test"
- })
-
- with open(file_path, 'w') as f:
- json.dump(traffic_data, f, indent=2)
- logger.info(f"Created sample traffic data: {len(traffic_data)} records")
-
- async def verify_external_services(self):
- """Verify external services are available (mock or real)"""
- logger.info("Verifying external services...")
-
- # Check if mock services are available
- mock_services = [
- ("Mock AEMET", "http://localhost:8080/health"),
- ("Mock Madrid OpenData", "http://localhost:8081/health"),
- ("Mock Auth Service", "http://localhost:8082/health"),
- ("Mock Data Service", "http://localhost:8083/health")
- ]
-
- try:
- import httpx
- async with httpx.AsyncClient(timeout=5.0) as client:
- for service_name, url in mock_services:
- try:
- response = await client.get(url)
- if response.status_code == 200:
- logger.info(f"{service_name} is available")
- else:
- logger.warning(f"{service_name} returned status {response.status_code}")
- except Exception as e:
- logger.warning(f"{service_name} is not available: {e}")
- except ImportError:
- logger.warning("httpx not available, skipping service checks")
-
- def run_test_suite(self, suite_name: str) -> Dict[str, Any]:
- """Run a specific test suite"""
- suite_config = self.test_suites[suite_name]
- logger.info(f"Running {suite_name} test suite: {suite_config['description']}")
-
- start_time = time.time()
-
- # Prepare pytest command
- pytest_args = [
- "python", "-m", "pytest",
- "-v",
- "--tb=short",
- "--capture=no",
- f"--junitxml={self.results_dir}/junit_{suite_name}.xml",
- f"--cov=app",
- f"--cov-report=html:{self.results_dir}/coverage_{suite_name}_html",
- f"--cov-report=xml:{self.results_dir}/coverage_{suite_name}.xml",
- "--cov-report=term-missing"
- ]
-
- # Add test files
- for test_file in suite_config["files"]:
- test_path = self.test_dir / test_file
- if test_path.exists():
- pytest_args.append(str(test_path))
- else:
- logger.warning(f"Test file not found: {test_file}")
-
- # Run the tests
- try:
- result = subprocess.run(
- pytest_args,
- cwd=self.test_dir.parent, # Run from training service root
- capture_output=True,
- text=True,
- timeout=suite_config["timeout"]
- )
-
- duration = time.time() - start_time
-
- return {
- "suite": suite_name,
- "status": "passed" if result.returncode == 0 else "failed",
- "return_code": result.returncode,
- "duration": duration,
- "stdout": result.stdout,
- "stderr": result.stderr,
- "timestamp": datetime.now().isoformat()
- }
-
- except subprocess.TimeoutExpired:
- duration = time.time() - start_time
- logger.error(f"Test suite {suite_name} timed out after {duration:.2f}s")
-
- return {
- "suite": suite_name,
- "status": "timeout",
- "return_code": -1,
- "duration": duration,
- "stdout": "",
- "stderr": f"Test suite timed out after {suite_config['timeout']}s",
- "timestamp": datetime.now().isoformat()
- }
-
- except Exception as e:
- duration = time.time() - start_time
- logger.error(f"Error running test suite {suite_name}: {e}")
-
- return {
- "suite": suite_name,
- "status": "error",
- "return_code": -1,
- "duration": duration,
- "stdout": "",
- "stderr": str(e),
- "timestamp": datetime.now().isoformat()
- }
-
- def generate_test_report(self):
- """Generate comprehensive test report"""
- logger.info("Generating test report...")
-
- # Calculate summary statistics
- total_suites = len(self.test_results)
- passed_suites = sum(1 for r in self.test_results.values() if r["status"] == "passed")
- failed_suites = sum(1 for r in self.test_results.values() if r["status"] == "failed")
- error_suites = sum(1 for r in self.test_results.values() if r["status"] == "error")
- timeout_suites = sum(1 for r in self.test_results.values() if r["status"] == "timeout")
-
- total_duration = sum(r["duration"] for r in self.test_results.values())
-
- # Create detailed report
- report = {
- "test_run_summary": {
- "timestamp": datetime.now().isoformat(),
- "total_suites": total_suites,
- "passed_suites": passed_suites,
- "failed_suites": failed_suites,
- "error_suites": error_suites,
- "timeout_suites": timeout_suites,
- "success_rate": (passed_suites / total_suites * 100) if total_suites > 0 else 0,
- "total_duration_seconds": total_duration
- },
- "suite_results": self.test_results,
- "recommendations": self.generate_recommendations()
- }
-
- # Save JSON report
- report_file = self.results_dir / "test_report.json"
- with open(report_file, 'w') as f:
- json.dump(report, f, indent=2)
-
- # Generate HTML report
- self.generate_html_report(report)
-
- # Print summary to console
- self.print_test_summary(report)
-
- return report
-
- def generate_recommendations(self) -> List[str]:
- """Generate recommendations based on test results"""
- recommendations = []
-
- failed_suites = [name for name, result in self.test_results.items() if result["status"] == "failed"]
- timeout_suites = [name for name, result in self.test_results.items() if result["status"] == "timeout"]
-
- if failed_suites:
- recommendations.append(f"Failed test suites: {', '.join(failed_suites)}. Check logs for detailed error messages.")
-
- if timeout_suites:
- recommendations.append(f"Timeout in suites: {', '.join(timeout_suites)}. Consider increasing timeout or optimizing performance.")
-
- # Performance recommendations
- slow_suites = [
- name for name, result in self.test_results.items()
- if result["duration"] > 300 # 5 minutes
- ]
- if slow_suites:
- recommendations.append(f"Slow test suites: {', '.join(slow_suites)}. Consider performance optimization.")
-
- if not recommendations:
- recommendations.append("All tests passed successfully! Consider adding more edge case tests.")
-
- return recommendations
-
- def generate_html_report(self, report: Dict[str, Any]):
- """Generate HTML test report"""
- html_template = """
-
-
-
- Training Service Test Report
-
-
-
-
-
-
-
-
{total_suites}
-
Total Suites
-
-
-
{passed_suites}
-
Passed
-
-
-
{failed_suites}
-
Failed
-
-
-
{timeout_suites}
-
Timeout
-
-
-
{success_rate:.1f}%
-
Success Rate
-
-
-
{duration:.1f}s
-
Total Duration
-
-
-
-
-
Recommendations
-
- {recommendations_html}
-
-
-
- Suite Results
- {suite_results_html}
-
-
-
- """
-
- # Format recommendations
- recommendations_html = '\n'.join(
- f"{rec}" for rec in report["recommendations"]
- )
-
- # Format suite results
- suite_results_html = ""
- for suite_name, result in report["suite_results"].items():
- status_class = result["status"]
- suite_results_html += f"""
-
-
{suite_name.title()} Tests ({result["status"].upper()})
-
Duration: {result["duration"]:.2f}s
-
Return Code: {result["return_code"]}
-
- {f'
Output:
{result["stdout"][:1000]}{"..." if len(result["stdout"]) > 1000 else ""}' if result["stdout"] else ""}
- {f'
Errors:
{result["stderr"][:1000]}{"..." if len(result["stderr"]) > 1000 else ""}' if result["stderr"] else ""}
-
- """
-
- # Fill template
- html_content = html_template.format(
- timestamp=report["test_run_summary"]["timestamp"],
- total_suites=report["test_run_summary"]["total_suites"],
- passed_suites=report["test_run_summary"]["passed_suites"],
- failed_suites=report["test_run_summary"]["failed_suites"],
- timeout_suites=report["test_run_summary"]["timeout_suites"],
- success_rate=report["test_run_summary"]["success_rate"],
- duration=report["test_run_summary"]["total_duration_seconds"],
- recommendations_html=recommendations_html,
- suite_results_html=suite_results_html
- )
-
- # Save HTML report
- html_file = self.results_dir / "test_report.html"
- with open(html_file, 'w') as f:
- f.write(html_content)
-
- logger.info(f"HTML report saved to: {html_file}")
-
- def print_test_summary(self, report: Dict[str, Any]):
- """Print test summary to console"""
- summary = report["test_run_summary"]
-
- print("\n" + "=" * 80)
- print("TRAINING SERVICE TEST RESULTS SUMMARY")
- print("=" * 80)
- print(f"Timestamp: {summary['timestamp']}")
- print(f"Total Suites: {summary['total_suites']}")
- print(f"Passed: {summary['passed_suites']}")
- print(f"Failed: {summary['failed_suites']}")
- print(f"Errors: {summary['error_suites']}")
- print(f"Timeouts: {summary['timeout_suites']}")
- print(f"Success Rate: {summary['success_rate']:.1f}%")
- print(f"Total Duration: {summary['total_duration_seconds']:.2f}s")
-
- print("\nSUITE DETAILS:")
- print("-" * 50)
- for suite_name, result in report["suite_results"].items():
- status_icon = "✅" if result["status"] == "passed" else "❌"
- print(f"{status_icon} {suite_name.ljust(15)}: {result['status'].upper().ljust(10)} ({result['duration']:.2f}s)")
-
- print("\nRECOMMENDATIONS:")
- print("-" * 50)
- for i, rec in enumerate(report["recommendations"], 1):
- print(f"{i}. {rec}")
-
- print("\nFILES GENERATED:")
- print("-" * 50)
- print(f"📄 JSON Report: {self.results_dir}/test_report.json")
- print(f"🌐 HTML Report: {self.results_dir}/test_report.html")
- print(f"📊 Coverage Reports: {self.results_dir}/coverage_*_html/")
- print(f"📋 JUnit XML: {self.results_dir}/junit_*.xml")
- print("=" * 80)
-
- async def run_all_tests(self):
- """Run all test suites"""
- logger.info("Starting comprehensive test run...")
-
- # Setup environment
- await self.setup_test_environment()
-
- # Run each test suite
- for suite_name in self.test_suites.keys():
- logger.info(f"Starting {suite_name} test suite...")
- result = self.run_test_suite(suite_name)
- self.test_results[suite_name] = result
-
- if result["status"] == "passed":
- logger.info(f"✅ {suite_name} tests PASSED ({result['duration']:.2f}s)")
- elif result["status"] == "failed":
- logger.error(f"❌ {suite_name} tests FAILED ({result['duration']:.2f}s)")
- elif result["status"] == "timeout":
- logger.error(f"⏰ {suite_name} tests TIMED OUT ({result['duration']:.2f}s)")
- else:
- logger.error(f"💥 {suite_name} tests ERROR ({result['duration']:.2f}s)")
-
- # Generate final report
- report = self.generate_test_report()
-
- return report
-
- def run_specific_suite(self, suite_name: str):
- """Run a specific test suite"""
- if suite_name not in self.test_suites:
- logger.error(f"Unknown test suite: {suite_name}")
- logger.info(f"Available suites: {', '.join(self.test_suites.keys())}")
- return None
-
- logger.info(f"Running {suite_name} test suite only...")
- result = self.run_test_suite(suite_name)
- self.test_results[suite_name] = result
-
- # Generate report for single suite
- report = self.generate_test_report()
- return report
-
-
-# ================================================================
-# MAIN EXECUTION
-# ================================================================
-
-async def main():
- """Main execution function"""
- import argparse
-
- parser = argparse.ArgumentParser(description="Training Service Test Runner")
- parser.add_argument(
- "--suite",
- choices=list(TrainingTestRunner().test_suites.keys()) + ["all"],
- default="all",
- help="Test suite to run (default: all)"
- )
- parser.add_argument(
- "--verbose", "-v",
- action="store_true",
- help="Verbose output"
- )
- parser.add_argument(
- "--quick",
- action="store_true",
- help="Run quick tests only (skip performance tests)"
- )
-
- args = parser.parse_args()
-
- # Setup logging level
- if args.verbose:
- logging.getLogger().setLevel(logging.DEBUG)
-
- # Create test runner
- runner = TrainingTestRunner()
-
- # Modify test suites for quick run
- if args.quick:
- # Skip performance tests in quick mode
- if "performance" in runner.test_suites:
- del runner.test_suites["performance"]
- logger.info("Quick mode: Skipping performance tests")
-
- try:
- if args.suite == "all":
- report = await runner.run_all_tests()
- else:
- report = runner.run_specific_suite(args.suite)
-
- # Exit with appropriate code
- if report and report["test_run_summary"]["failed_suites"] == 0 and report["test_run_summary"]["error_suites"] == 0:
- logger.info("All tests completed successfully!")
- sys.exit(0)
- else:
- logger.error("Some tests failed!")
- sys.exit(1)
-
- except KeyboardInterrupt:
- logger.info("Test run interrupted by user")
- sys.exit(130)
- except Exception as e:
- logger.error(f"Test run failed with error: {e}")
- sys.exit(1)
-
-
-if __name__ == "__main__":
- # Handle both direct execution and pytest discovery
- if len(sys.argv) > 1:
- # Any CLI arguments (e.g. --suite, --quick, -v, -h) are handled by argparse in main()
- asyncio.run(main())
- else:
- # Running as pytest discovery or direct execution without args
- print("Training Service Test Runner")
- print("=" * 50)
- print("Usage:")
- print(" python run_tests.py --suite all # Run all test suites")
- print(" python run_tests.py --suite unit # Run unit tests only")
- print(" python run_tests.py --suite integration # Run integration tests only")
- print(" python run_tests.py --suite performance # Run performance tests only")
- print(" python run_tests.py --quick # Run quick tests (skip performance)")
- print(" python run_tests.py -v # Verbose output")
- print()
- print("Available test suites:")
- runner = TrainingTestRunner()
- for suite_name, config in runner.test_suites.items():
- print(f" {suite_name.ljust(15)}: {config['description']}")
- print()
-
- # If no arguments provided, run all tests
- if len(sys.argv) == 1:
- print("No arguments provided. Running all tests...")
- asyncio.run(TrainingTestRunner().run_all_tests())
\ No newline at end of file
diff --git a/services/training/tests/test_api.py b/services/training/tests/test_api.py
deleted file mode 100644
index f9b696b6..00000000
--- a/services/training/tests/test_api.py
+++ /dev/null
@@ -1,687 +0,0 @@
-# services/training/tests/test_api.py
-"""
-Tests for training service API endpoints
-"""
-
-import pytest
-from unittest.mock import AsyncMock, patch
-from fastapi import status
-from httpx import AsyncClient
-
-from app.schemas.training import TrainingJobRequest
-
-
-class TestTrainingAPI:
- """Test training API endpoints"""
-
- @pytest.mark.asyncio
- async def test_health_check(self, test_client: AsyncClient):
- """Test health check endpoint"""
- response = await test_client.get("/health")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert data["service"] == "training-service"
- assert data["version"] == "1.0.0"
- assert "status" in data
-
- @pytest.mark.asyncio
- async def test_readiness_check_ready(self, test_client: AsyncClient):
- """Test readiness check when service is ready"""
- # Mock app state as ready
- from app.main import app
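- # create=True lets patch.object set 'ready' even if app.state does not define it yet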
- with patch.object(app.state, 'ready', True, create=True):
- response = await test_client.get("/health/ready")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert data["status"] == "ready"
-
- @pytest.mark.asyncio
- async def test_readiness_check_not_ready(self, test_client: AsyncClient):
- """Test readiness check when service is not ready"""
- with patch('app.main.app.state.ready', False, create=True):
- response = await test_client.get("/health/ready")
-
- assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
- data = response.json()
- assert data["status"] == "not_ready"
-
- @pytest.mark.asyncio
- async def test_liveness_check_healthy(self, test_client: AsyncClient):
- """Test liveness check when service is healthy"""
- with patch('app.core.database.get_db_health', new=AsyncMock(return_value=True)):
- response = await test_client.get("/health/live")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert data["status"] == "alive"
-
- @pytest.mark.asyncio
- async def test_liveness_check_unhealthy(self, test_client: AsyncClient):
- """Test liveness check when database is unhealthy"""
- with patch('app.core.database.get_db_health', new=AsyncMock(return_value=False)):
- response = await test_client.get("/health/live")
-
- assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
- data = response.json()
- assert data["status"] == "unhealthy"
- assert data["reason"] == "database_unavailable"
-
- @pytest.mark.asyncio
- async def test_metrics_endpoint(self, test_client: AsyncClient):
- """Test metrics endpoint"""
- response = await test_client.get("/metrics")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- expected_metrics = [
- "training_jobs_active",
- "training_jobs_completed",
- "training_jobs_failed",
- "models_trained_total",
- "uptime_seconds"
- ]
-
- for metric in expected_metrics:
- assert metric in data
-
- @pytest.mark.asyncio
- async def test_root_endpoint(self, test_client: AsyncClient):
- """Test root endpoint"""
- response = await test_client.get("/")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert data["service"] == "training-service"
- assert data["version"] == "1.0.0"
- assert "description" in data
-
-
-class TestTrainingJobsAPI:
- """Test training jobs API endpoints"""
-
- @pytest.mark.asyncio
- async def test_start_training_job_success(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_ml_trainer,
- mock_data_service
- ):
- """Test starting a training job successfully"""
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30,
- "seasonality_mode": "additive"
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert "job_id" in data
- assert data["status"] == "started"
- assert data["tenant_id"] == "test-tenant"
- assert "estimated_duration_minutes" in data
-
- @pytest.mark.asyncio
- async def test_start_training_job_validation_error(self, test_client: AsyncClient):
- """Test starting training job with validation error"""
- request_data = {
- "seasonality_mode": "invalid_mode", # Invalid value
- "min_data_points": 5 # Too low
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
- @pytest.mark.asyncio
- async def test_get_training_status_existing_job(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test getting status of existing training job"""
- job_id = training_job_in_db.job_id
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert data["job_id"] == job_id
- assert data["status"] == "pending"
- assert "progress" in data
- assert "started_at" in data
-
- @pytest.mark.asyncio
- async def test_get_training_status_nonexistent_job(self, test_client: AsyncClient):
- """Test getting status of non-existent training job"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get("/training/jobs/nonexistent-job/status")
-
- assert response.status_code == status.HTTP_404_NOT_FOUND
-
- @pytest.mark.asyncio
- async def test_list_training_jobs(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test listing training jobs"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get("/training/jobs")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert isinstance(data, list)
- assert len(data) >= 1
-
- # Check first job structure
- job = data[0]
- assert "job_id" in job
- assert "status" in job
- assert "started_at" in job
-
- @pytest.mark.asyncio
- async def test_list_training_jobs_with_status_filter(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test listing training jobs with status filter"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get("/training/jobs?status=pending")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert isinstance(data, list)
- # All jobs should have status "pending"
- for job in data:
- assert job["status"] == "pending"
-
- @pytest.mark.asyncio
- async def test_cancel_training_job_success(
- self,
- test_client: AsyncClient,
- training_job_in_db,
- mock_messaging
- ):
- """Test cancelling a training job successfully"""
- job_id = training_job_in_db.job_id
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(f"/training/jobs/{job_id}/cancel")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert "message" in data
- assert "cancelled" in data["message"].lower()
-
- @pytest.mark.asyncio
- async def test_cancel_nonexistent_job(self, test_client: AsyncClient):
- """Test cancelling a non-existent training job"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs/nonexistent-job/cancel")
-
- assert response.status_code == status.HTTP_404_NOT_FOUND
-
- @pytest.mark.asyncio
- async def test_get_training_logs(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test getting training logs"""
- job_id = training_job_in_db.job_id
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/logs")
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert "job_id" in data
- assert "logs" in data
- assert isinstance(data["logs"], list)
-
- @pytest.mark.asyncio
- async def test_validate_training_data_valid(
- self,
- test_client: AsyncClient,
- mock_data_service
- ):
- """Test validating valid training data"""
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/validate", json=request_data)
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert "is_valid" in data
- assert "issues" in data
- assert "recommendations" in data
- assert "estimated_training_time" in data
-
-
-class TestSingleProductTrainingAPI:
- """Test single product training API endpoints"""
-
- @pytest.mark.asyncio
- async def test_train_single_product_success(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_ml_trainer,
- mock_data_service
- ):
- """Test training a single product successfully"""
- product_name = "Pan Integral"
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "seasonality_mode": "additive"
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- f"/training/products/{product_name}",
- json=request_data
- )
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
-
- assert "job_id" in data
- assert data["status"] == "started"
- assert data["tenant_id"] == "test-tenant"
- assert f"training started for {product_name}" in data["message"].lower()
-
- @pytest.mark.asyncio
- async def test_train_single_product_validation_error(self, test_client: AsyncClient):
- """Test single product training with validation error"""
- product_name = "Pan Integral"
- request_data = {
- "seasonality_mode": "invalid_mode" # Invalid value
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- f"/training/products/{product_name}",
- json=request_data
- )
-
- assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
- @pytest.mark.asyncio
- async def test_train_single_product_special_characters(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_ml_trainer,
- mock_data_service
- ):
- """Test training product with special characters in name"""
- product_name = "Pan Francés" # With accent
- request_data = {
- "include_weather": True,
- "seasonality_mode": "additive"
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- f"/training/products/{product_name}",
- json=request_data
- )
-
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert "job_id" in data
-
-
-class TestModelsAPI:
- """Test models API endpoints"""
-
- @pytest.mark.asyncio
- async def test_list_models(
- self,
- test_client: AsyncClient,
- trained_model_in_db
- ):
- """Test listing trained models"""
- with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
- response = await test_client.get("/models")
-
- # This endpoint might not exist yet, so we expect either 200 or 404
- assert response.status_code in [status.HTTP_200_OK, status.HTTP_404_NOT_FOUND]
-
- if response.status_code == status.HTTP_200_OK:
- data = response.json()
- assert isinstance(data, list)
-
- @pytest.mark.asyncio
- async def test_get_model_details(
- self,
- test_client: AsyncClient,
- trained_model_in_db
- ):
- """Test getting model details"""
- model_id = trained_model_in_db.model_id
-
- with patch('app.api.models.get_current_tenant_id', return_value="test-tenant"):
- response = await test_client.get(f"/models/{model_id}")
-
- # This endpoint might not exist yet
- assert response.status_code in [
- status.HTTP_200_OK,
- status.HTTP_404_NOT_FOUND,
- status.HTTP_501_NOT_IMPLEMENTED
- ]
-
-
-class TestErrorHandling:
- """Test error handling in API endpoints"""
-
- @pytest.mark.asyncio
- async def test_database_error_handling(self, test_client: AsyncClient):
- """Test handling of database errors"""
- with patch('app.services.training_service.TrainingService.create_training_job') as mock_create:
- mock_create.side_effect = Exception("Database connection failed")
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
-
- @pytest.mark.asyncio
- async def test_missing_tenant_id(self, test_client: AsyncClient):
- """Test handling when tenant ID is missing"""
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Don't mock get_current_tenant_id to simulate missing auth
- response = await test_client.post("/training/jobs", json=request_data)
-
- # Should fail due to missing authentication
- assert response.status_code in [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN]
-
- @pytest.mark.asyncio
- async def test_invalid_job_id_format(self, test_client: AsyncClient):
- """Test handling of invalid job ID format"""
- invalid_job_id = "invalid-job-id-with-special-chars@#$"
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{invalid_job_id}/status")
-
- # Should handle gracefully
- assert response.status_code in [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST]
-
- @pytest.mark.asyncio
- async def test_messaging_failure_handling(
- self,
- test_client: AsyncClient,
- mock_data_service
- ):
- """Test handling when messaging fails"""
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")), \
- patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
-
- response = await test_client.post("/training/jobs", json=request_data)
-
- # Should still succeed even if messaging fails
- assert response.status_code == status.HTTP_200_OK
- data = response.json()
- assert "job_id" in data
-
- @pytest.mark.asyncio
- async def test_invalid_json_payload(self, test_client: AsyncClient):
- """Test handling of invalid JSON payload"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- "/training/jobs",
- content="invalid json {{{",
- headers={"Content-Type": "application/json"}
- )
-
- assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
- @pytest.mark.asyncio
- async def test_unsupported_content_type(self, test_client: AsyncClient):
- """Test handling of unsupported content type"""
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- "/training/jobs",
- content="some text data",
- headers={"Content-Type": "text/plain"}
- )
-
- assert response.status_code in [
- status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
- status.HTTP_422_UNPROCESSABLE_ENTITY
- ]
-
-
-class TestAuthenticationIntegration:
- """Test authentication integration"""
-
- @pytest.mark.asyncio
- async def test_endpoints_require_auth(self, test_client: AsyncClient):
- """Test that endpoints require authentication in production"""
- # This test would be more meaningful in a production environment
- # where authentication is actually enforced
-
- endpoints_to_test = [
- ("POST", "/training/jobs"),
- ("GET", "/training/jobs"),
- ("POST", "/training/products/Pan Integral"),
- ("POST", "/training/validate")
- ]
-
- for method, endpoint in endpoints_to_test:
- if method == "POST":
- response = await test_client.post(endpoint, json={})
- else:
- response = await test_client.get(endpoint)
-
- # In test environment with mocked auth, should work
- # In production, would require valid authentication
- assert response.status_code != status.HTTP_500_INTERNAL_SERVER_ERROR
-
- @pytest.mark.asyncio
- async def test_tenant_isolation_in_api(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test tenant isolation at API level"""
- job_id = training_job_in_db.job_id
-
- # Try to access job with different tenant
- with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- # Should not find job for different tenant
- assert response.status_code == status.HTTP_404_NOT_FOUND
-
-
-class TestAPIValidation:
- """Test API validation and input handling"""
-
- @pytest.mark.asyncio
- async def test_training_request_validation(self, test_client: AsyncClient):
- """Test comprehensive training request validation"""
-
- # Test valid request
- valid_request = {
- "include_weather": True,
- "include_traffic": False,
- "min_data_points": 30,
- "seasonality_mode": "additive",
- "daily_seasonality": True,
- "weekly_seasonality": True,
- "yearly_seasonality": True
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=valid_request)
-
- assert response.status_code == status.HTTP_200_OK
-
- # Test invalid seasonality mode
- invalid_request = valid_request.copy()
- invalid_request["seasonality_mode"] = "invalid_mode"
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=invalid_request)
-
- assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
- # Test invalid min_data_points
- invalid_request = valid_request.copy()
- invalid_request["min_data_points"] = 5 # Too low
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=invalid_request)
-
- assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
-
- @pytest.mark.asyncio
- async def test_single_product_request_validation(self, test_client: AsyncClient):
- """Test single product training request validation"""
-
- product_name = "Pan Integral"
-
- # Test valid request
- valid_request = {
- "include_weather": True,
- "include_traffic": True,
- "seasonality_mode": "multiplicative"
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- f"/training/products/{product_name}",
- json=valid_request
- )
-
- assert response.status_code == status.HTTP_200_OK
-
- # Test empty product name
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- "/training/products/",
- json=valid_request
- )
-
- assert response.status_code == status.HTTP_404_NOT_FOUND
-
- @pytest.mark.asyncio
- async def test_query_parameter_validation(self, test_client: AsyncClient):
- """Test query parameter validation"""
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- # Test valid limit parameter
- response = await test_client.get("/training/jobs?limit=5")
- assert response.status_code == status.HTTP_200_OK
-
- # Test invalid limit parameter
- response = await test_client.get("/training/jobs?limit=invalid")
- assert response.status_code in [
- status.HTTP_422_UNPROCESSABLE_ENTITY,
- status.HTTP_400_BAD_REQUEST
- ]
-
- # Test negative limit
- response = await test_client.get("/training/jobs?limit=-1")
- assert response.status_code in [
- status.HTTP_422_UNPROCESSABLE_ENTITY,
- status.HTTP_400_BAD_REQUEST
- ]
-
-
-class TestAPIPerformance:
- """Test API performance characteristics"""
-
- @pytest.mark.asyncio
- async def test_concurrent_requests(self, test_client: AsyncClient):
- """Test handling of concurrent requests"""
- import asyncio
-
- # Create multiple concurrent requests; /health needs no tenant context, so no auth patch is required
- tasks = [test_client.get("/health") for _ in range(10)]
-
- responses = await asyncio.gather(*tasks)
-
- # All requests should succeed
- for response in responses:
- assert response.status_code == status.HTTP_200_OK
-
- @pytest.mark.asyncio
- async def test_large_payload_handling(self, test_client: AsyncClient):
- """Test handling of large request payloads"""
-
- # Create large request payload
- large_request = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30,
- "large_config": {f"key_{i}": f"value_{i}" for i in range(1000)}
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=large_request)
-
- # Should handle large payload gracefully
- assert response.status_code in [
- status.HTTP_200_OK,
- status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
- ]
-
- @pytest.mark.asyncio
- async def test_rapid_successive_requests(self, test_client: AsyncClient):
- """Test rapid successive requests to same endpoint"""
-
- # Make rapid requests
- responses = []
- for _ in range(20):
- response = await test_client.get("/health")
- responses.append(response)
-
- # All should succeed
- for response in responses:
- assert response.status_code == status.HTTP_200_OK
\ No newline at end of file
diff --git a/services/training/tests/test_end_to_end.py b/services/training/tests/test_end_to_end.py
deleted file mode 100644
index 65c75712..00000000
--- a/services/training/tests/test_end_to_end.py
+++ /dev/null
@@ -1,311 +0,0 @@
-# ================================================================
-# services/training/tests/test_end_to_end.py
-# ================================================================
-"""
-End-to-End Testing for Training Service
-Tests complete workflows from API to ML pipeline to results
-"""
-
-import pytest
-import asyncio
-import httpx
-import pandas as pd
-import numpy as np
-import json
-import tempfile
-import time
-from datetime import datetime, timedelta
-from typing import Dict, List, Any
-from unittest.mock import patch, AsyncMock
-import uuid
-
-from app.main import app
-from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
-
-
-class TestTrainingServiceEndToEnd:
- """End-to-end tests for complete training workflows"""
-
- @pytest.fixture
- async def test_client(self):
- """Create test client for the training service"""
- from httpx import AsyncClient
- async with AsyncClient(app=app, base_url="http://test") as client:
- yield client
-
- @pytest.fixture
- def real_bakery_data(self):
- """Use the actual bakery sales data from the uploaded CSV"""
- # This fixture would load the real bakery_sales_2023_2024.csv data
- # For testing, we'll simulate the structure based on the document description
-
- # Generate realistic data matching the CSV structure
- start_date = datetime(2023, 1, 1)
- dates = [start_date + timedelta(days=i) for i in range(365)]
-
- products = [
- "Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
- "Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras"
- ]
-
- data = []
- for date in dates:
- for product in products:
- # Realistic sales patterns for Madrid bakery
- base_quantity = {
- "Pan Integral": 80, "Pan Blanco": 120, "Croissant": 45,
- "Magdalenas": 30, "Empanadas": 25, "Tarta Chocolate": 15,
- "Roscon Reyes": 8, "Palmeras": 12
- }.get(product, 20)
-
- # Seasonal variations
- if date.month == 12 and product == "Roscon Reyes":
- base_quantity *= 5 # Christmas specialty
- elif date.month in [6, 7, 8]: # Summer
- base_quantity *= 0.85
- elif date.month in [11, 12, 1]: # Winter
- base_quantity *= 1.15
-
- # Weekly patterns
- if date.weekday() >= 5: # Weekends
- base_quantity *= 1.3
- elif date.weekday() == 0: # Monday slower
- base_quantity *= 0.8
-
- # Weather influence
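- # Approximate the annual temperature cycle with a sinusoid over the day of the year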
- temp = 15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
- if temp > 30: # Very hot days
- if product in ["Pan Integral", "Pan Blanco"]:
- base_quantity *= 0.7
- elif temp < 5: # Cold days
- base_quantity *= 1.1
-
- # Add realistic noise
- quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.15)))
-
- # Calculate revenue (realistic Spanish bakery prices)
- price_per_unit = {
- "Pan Integral": 2.80, "Pan Blanco": 2.50, "Croissant": 1.50,
- "Magdalenas": 1.20, "Empanadas": 3.50, "Tarta Chocolate": 18.00,
- "Roscon Reyes": 25.00, "Palmeras": 1.80
- }.get(product, 2.00)
-
- revenue = round(quantity * price_per_unit, 2)
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": product,
- "quantity": quantity,
- "revenue": revenue,
- "temperature": round(temp + np.random.normal(0, 3), 1),
- "precipitation": max(0, np.random.exponential(0.8)),
- "is_weekend": date.weekday() >= 5,
- "is_holiday": self._is_spanish_holiday(date)
- })
-
- return pd.DataFrame(data)
-
- def _is_spanish_holiday(self, date: datetime) -> bool:
- """Check if date is a Spanish holiday"""
- spanish_holidays = [
- (1, 1), # Año Nuevo
- (1, 6), # Reyes Magos
- (5, 1), # Día del Trabajo
- (8, 15), # Asunción de la Virgen
- (10, 12), # Fiesta Nacional de España
- (11, 1), # Todos los Santos
- (12, 6), # Día de la Constitución
- (12, 8), # Inmaculada Concepción
- (12, 25), # Navidad
- ]
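- # Fixed-date national holidays only; movable holidays such as Easter are not modeled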
- return (date.month, date.day) in spanish_holidays
-
- @pytest.fixture
- async def mock_external_apis(self):
- """Mock external APIs (AEMET and Madrid OpenData)"""
- with patch('app.external.aemet.AEMETClient') as mock_aemet, \
- patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_madrid:
-
- # Mock AEMET weather data
- mock_aemet_instance = AsyncMock()
- mock_aemet.return_value = mock_aemet_instance
-
- # Generate realistic Madrid weather data
- weather_data = []
- for i in range(365):
- date = datetime(2023, 1, 1) + timedelta(days=i)
- day_of_year = date.timetuple().tm_yday
- # Madrid climate: hot summers, mild winters
- base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
-
- weather_data.append({
- "date": date,
- "temperature": round(base_temp + np.random.normal(0, 4), 1),
- "precipitation": max(0, np.random.exponential(1.2)),
- "humidity": np.random.uniform(25, 75),
- "wind_speed": np.random.uniform(3, 20),
- "pressure": np.random.uniform(995, 1025),
- "description": np.random.choice([
- "Soleado", "Parcialmente nublado", "Nublado",
- "Lluvia ligera", "Despejado"
- ]),
- "source": "aemet"
- })
-
- mock_aemet_instance.get_historical_weather.return_value = weather_data
- mock_aemet_instance.get_current_weather.return_value = weather_data[-1]
-
- # Mock Madrid traffic data
- mock_madrid_instance = AsyncMock()
- mock_madrid.return_value = mock_madrid_instance
-
- traffic_data = []
- for i in range(365):
- date = datetime(2023, 1, 1) + timedelta(days=i)
-
- # Multiple measurements per day
- for hour in range(6, 22, 2): # Every 2 hours from 6 AM to 10 PM
- measurement_time = date.replace(hour=hour)
-
- # Realistic Madrid traffic patterns
- if hour in [7, 8, 9, 18, 19, 20]: # Rush hours
- volume = np.random.randint(1200, 2000)
- congestion = "high"
- speed = np.random.randint(10, 25)
- elif hour in [12, 13, 14]: # Lunch time
- volume = np.random.randint(800, 1200)
- congestion = "medium"
- speed = np.random.randint(20, 35)
- else: # Off-peak
- volume = np.random.randint(300, 800)
- congestion = "low"
- speed = np.random.randint(30, 50)
-
- traffic_data.append({
- "date": measurement_time,
- "traffic_volume": volume,
- "occupation_percentage": np.random.randint(15, 85),
- "load_percentage": np.random.randint(25, 90),
- "average_speed": speed,
- "congestion_level": congestion,
- "pedestrian_count": np.random.randint(100, 800),
- "measurement_point_id": "MADRID_CENTER_001",
- "measurement_point_name": "Puerta del Sol",
- "road_type": "URB",
- "source": "madrid_opendata"
- })
-
- mock_madrid_instance.get_historical_traffic.return_value = traffic_data
- mock_madrid_instance.get_current_traffic.return_value = traffic_data[-1]
-
- yield {
- 'aemet': mock_aemet_instance,
- 'madrid': mock_madrid_instance
- }
-
- @pytest.mark.asyncio
- async def test_complete_training_workflow_api(
- self,
- test_client,
- real_bakery_data,
- mock_external_apis
- ):
- """Test complete training workflow through API endpoints"""
-
- # Step 1: Check service health
- health_response = await test_client.get("/health")
- assert health_response.status_code == 200
- health_data = health_response.json()
- assert health_data["status"] == "healthy"
-
- # Step 2: Validate training data quality
- with patch('app.services.training_service.TrainingService._fetch_sales_data',
- return_value=real_bakery_data):
-
- validation_response = await test_client.post(
- "/training/validate",
- json={
- "tenant_id": "test_bakery_001",
- "include_weather": True,
- "include_traffic": True
- }
- )
-
- assert validation_response.status_code == 200
- validation_data = validation_response.json()
- assert validation_data["is_valid"] is True
- assert validation_data["data_points"] > 1000 # Sufficient data
- assert validation_data["missing_percentage"] < 10
-
- # Step 3: Start training job for multiple products
- training_request = {
- "products": ["Pan Integral", "Croissant", "Magdalenas"],
- "include_weather": True,
- "include_traffic": True,
- "config": {
- "seasonality_mode": "additive",
- "changepoint_prior_scale": 0.05,
- "seasonality_prior_scale": 10.0,
- "validation_enabled": True
- }
- }
-
- with patch('app.services.training_service.TrainingService._fetch_sales_data',
- return_value=real_bakery_data):
-
- start_response = await test_client.post(
- "/training/jobs",
- json=training_request,
- headers={"X-Tenant-ID": "test_bakery_001"}
- )
-
- assert start_response.status_code == 201
- job_data = start_response.json()
- job_id = job_data["job_id"]
- assert job_data["status"] == "pending"
-
- # Step 4: Monitor job progress
- max_wait_time = 300 # 5 minutes
- start_time = time.time()
-
- while time.time() - start_time < max_wait_time:
- status_response = await test_client.get(f"/training/jobs/{job_id}/status")
- assert status_response.status_code == 200
-
- status_data = status_response.json()
-
- if status_data["status"] == "completed":
- # Training completed successfully
- assert "models_trained" in status_data
- assert len(status_data["models_trained"]) == 3 # Three products
-
- # Check model quality
- for model_info in status_data["models_trained"]:
- assert "product_name" in model_info
- assert "model_id" in model_info
- assert "metrics" in model_info
-
- metrics = model_info["metrics"]
- assert "mape" in metrics
- assert "rmse" in metrics
- assert "mae" in metrics
-
- # Quality thresholds for bakery data
- assert metrics["mape"] < 50, f"MAPE too high for {model_info['product_name']}: {metrics['mape']}"
- assert metrics["rmse"] > 0
-
- break
- elif status_data["status"] == "failed":
- pytest.fail(f"Training job failed: {status_data.get('error_message', 'Unknown error')}")
-
- # Wait before checking again
- await asyncio.sleep(10)
- else:
- pytest.fail(f"Training job did not complete within {max_wait_time} seconds")
-
- # Step 5: Get detailed job logs
- logs_response = await test_client.get(f"/training/jobs/{job_id}/logs")
- assert logs_response.status_code == 200
- logs_data = logs_response.json()
- assert "logs" in logs_data
- assert len(logs_data["logs"]) > 0
\ No newline at end of file
diff --git a/services/training/tests/test_integration.py b/services/training/tests/test_integration.py
deleted file mode 100644
index 40ea6d98..00000000
--- a/services/training/tests/test_integration.py
+++ /dev/null
@@ -1,848 +0,0 @@
-# services/training/tests/test_integration.py
-"""
-Integration tests for training service
-Tests complete workflows and service interactions
-"""
-
-import pytest
-import asyncio
-from unittest.mock import AsyncMock, Mock, patch
-from httpx import AsyncClient
-from datetime import datetime, timedelta
-
-from app.main import app
-from app.schemas.training import TrainingJobRequest
-
-
-class TestTrainingWorkflowIntegration:
- """Test complete training workflows end-to-end"""
-
- @pytest.mark.asyncio
- async def test_complete_training_workflow(
- self,
- test_client: AsyncClient,
- test_db_session,
- mock_messaging,
- mock_data_service,
- mock_ml_trainer
- ):
- """Test complete training workflow from API to completion"""
-
- # Step 1: Start training job
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30,
- "seasonality_mode": "additive"
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- assert response.status_code == 200
- job_data = response.json()
- job_id = job_data["job_id"]
-
- # Step 2: Check initial status
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- assert response.status_code == 200
- status_data = response.json()
- assert status_data["status"] in ["pending", "started"]
-
- # Step 3: Simulate background task completion
- # In real scenario, this would be handled by background tasks
- await asyncio.sleep(0.1) # Allow background task to start
-
- # Step 4: Check completion status
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- # The job should exist in database even if not completed yet
- assert response.status_code == 200
-
- @pytest.mark.asyncio
- async def test_single_product_training_workflow(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_data_service,
- mock_ml_trainer
- ):
- """Test single product training complete workflow"""
-
- product_name = "Pan Integral"
- request_data = {
- "include_weather": True,
- "include_traffic": False,
- "seasonality_mode": "additive"
- }
-
- # Start single product training
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(
- f"/training/products/{product_name}",
- json=request_data
- )
-
- assert response.status_code == 200
- job_data = response.json()
- job_id = job_data["job_id"]
- assert f"training started for {product_name}" in job_data["message"].lower()
-
- # Check job status
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- assert response.status_code == 200
- status_data = response.json()
- assert status_data["job_id"] == job_id
-
- @pytest.mark.asyncio
- async def test_training_validation_workflow(
- self,
- test_client: AsyncClient,
- mock_data_service
- ):
- """Test training data validation workflow"""
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Validate training data
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/validate", json=request_data)
-
- assert response.status_code == 200
- validation_data = response.json()
-
- assert "is_valid" in validation_data
- assert "issues" in validation_data
- assert "recommendations" in validation_data
- assert "estimated_training_time" in validation_data
-
- # If validation passes, start actual training
- if validation_data["is_valid"]:
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- assert response.status_code == 200
-
- @pytest.mark.asyncio
- async def test_job_cancellation_workflow(
- self,
- test_client: AsyncClient,
- training_job_in_db,
- mock_messaging
- ):
- """Test job cancellation workflow"""
-
- job_id = training_job_in_db.job_id
-
- # Check initial status
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- assert response.status_code == 200
- initial_status = response.json()
- assert initial_status["status"] == "pending"
-
- # Cancel the job
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post(f"/training/jobs/{job_id}/cancel")
-
- assert response.status_code == 200
- cancel_response = response.json()
- assert "cancelled" in cancel_response["message"].lower()
-
- # Verify cancellation
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- assert response.status_code == 200
- final_status = response.json()
- assert final_status["status"] == "cancelled"
-
-
-class TestServiceInteractionIntegration:
- """Test interactions between training service and external services"""
-
- @pytest.mark.asyncio
- async def test_data_service_integration(self, training_service, mock_data_service):
- """Test integration with data service"""
- from app.schemas.training import TrainingJobRequest
-
- request = TrainingJobRequest(
- include_weather=True,
- include_traffic=True,
- min_data_points=30
- )
-
- # Test sales data fetching
- sales_data = await training_service._fetch_sales_data("test-tenant", request)
- assert isinstance(sales_data, list)
-
- # Test weather data fetching
- weather_data = await training_service._fetch_weather_data("test-tenant", request)
- assert isinstance(weather_data, list)
-
- # Test traffic data fetching
- traffic_data = await training_service._fetch_traffic_data("test-tenant", request)
- assert isinstance(traffic_data, list)
-
- @pytest.mark.asyncio
- async def test_messaging_integration(self, mock_messaging):
- """Test integration with messaging system"""
- from app.services.messaging import (
- publish_job_started,
- publish_job_completed,
- publish_model_trained
- )
-
- # Test various message types
- result1 = await publish_job_started("job-123", "tenant-123", {})
- result2 = await publish_job_completed("job-123", "tenant-123", {"status": "success"})
- result3 = await publish_model_trained("model-123", "tenant-123", "Pan Integral", {"mae": 5.0})
-
- assert result1 is True
- assert result2 is True
- assert result3 is True
-
- @pytest.mark.asyncio
- async def test_database_integration(self, test_db_session, training_service):
- """Test database operations integration"""
-
- # Create a training job
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="integration-test-job",
- config={"test": True}
- )
-
- assert job.job_id == "integration-test-job"
-
- # Update job status
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="running",
- progress=50,
- current_step="Processing data"
- )
-
- # Retrieve updated job
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job.job_id,
- tenant_id="test-tenant"
- )
-
- assert updated_job.status == "running"
- assert updated_job.progress == 50
-
-
-class TestErrorHandlingIntegration:
- """Test error handling across service boundaries"""
-
- @pytest.mark.asyncio
- async def test_data_service_failure_handling(
- self,
- test_client: AsyncClient,
- mock_messaging
- ):
- """Test handling when data service is unavailable"""
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Mock data service failure
- with patch('httpx.AsyncClient') as mock_client:
- mock_client.return_value.__aenter__.return_value.get.side_effect = Exception("Service unavailable")
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- # Should still create job but might fail during execution
- assert response.status_code == 200
-
- @pytest.mark.asyncio
- async def test_messaging_failure_handling(
- self,
- test_client: AsyncClient,
- mock_data_service
- ):
- """Test handling when messaging fails"""
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Mock messaging failure
- with patch('app.services.messaging.publish_job_started', side_effect=Exception("Messaging failed")):
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- # Should still succeed even if messaging fails
- assert response.status_code == 200
-
- @pytest.mark.asyncio
- async def test_ml_training_failure_handling(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_data_service
- ):
- """Test handling when ML training fails"""
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Mock ML training failure
- with patch('app.ml.trainer.BakeryMLTrainer.train_tenant_models', side_effect=Exception("ML training failed")):
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=request_data)
-
- # Job should be created successfully
- assert response.status_code == 200
-
- # Background task would handle the failure
-
-
-class TestPerformanceIntegration:
- """Test performance characteristics of integrated workflows"""
-
- @pytest.mark.asyncio
- async def test_concurrent_training_jobs(
- self,
- test_client: AsyncClient,
- mock_messaging,
- mock_data_service,
- mock_ml_trainer
- ):
- """Test handling multiple concurrent training jobs"""
-
- request_data = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 30
- }
-
- # Start multiple jobs concurrently, keeping the auth patch active while the requests execute
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- tasks = [test_client.post("/training/jobs", json=request_data) for _ in range(5)]
- responses = await asyncio.gather(*tasks)
-
- # All jobs should be created successfully
- for response in responses:
- assert response.status_code == 200
- data = response.json()
- assert "job_id" in data
-
- @pytest.mark.asyncio
- async def test_large_dataset_handling(
- self,
- training_service,
- test_db_session
- ):
- """Test handling of large datasets"""
-
- # Simulate large dataset
- large_config = {
- "include_weather": True,
- "include_traffic": True,
- "min_data_points": 1000, # Large minimum
- "products": [f"Product-{i}" for i in range(100)] # Many products
- }
-
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="large-dataset-job",
- config=large_config
- )
-
- assert job.config == large_config
- assert job.job_id == "large-dataset-job"
-
- @pytest.mark.asyncio
- async def test_rapid_status_checks(
- self,
- test_client: AsyncClient,
- training_job_in_db
- ):
- """Test rapid successive status checks"""
-
- job_id = training_job_in_db.job_id
-
- # Make many rapid status requests, keeping the auth patch active while they execute
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- tasks = [test_client.get(f"/training/jobs/{job_id}/status") for _ in range(20)]
- responses = await asyncio.gather(*tasks)
-
- # All requests should succeed
- for response in responses:
- assert response.status_code == 200
-
-
-class TestSecurityIntegration:
- """Test security aspects of service integration"""
-
- @pytest.mark.asyncio
- async def test_tenant_isolation(
- self,
- test_client: AsyncClient,
- training_job_in_db,
- mock_messaging
- ):
- """Test that tenants cannot access each other's jobs"""
-
- job_id = training_job_in_db.job_id
-
- # Try to access job with different tenant ID
- with patch('app.api.training.get_current_tenant_id', return_value="different-tenant"):
- response = await test_client.get(f"/training/jobs/{job_id}/status")
-
- # Should not find the job (belongs to different tenant)
- assert response.status_code == 404
-
- @pytest.mark.asyncio
- async def test_input_validation_integration(
- self,
- test_client: AsyncClient
- ):
- """Test input validation across API boundaries"""
-
- # Test invalid seasonality mode
- invalid_request = {
- "seasonality_mode": "invalid_mode",
- "min_data_points": -5 # Invalid negative value
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=invalid_request)
-
- assert response.status_code == 422 # Validation error
-
- @pytest.mark.asyncio
- async def test_sql_injection_protection(
- self,
- test_client: AsyncClient
- ):
- """Test protection against SQL injection attempts"""
-
- # Try SQL injection in job ID
- malicious_job_id = "job'; DROP TABLE model_training_logs; --"
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.get(f"/training/jobs/{malicious_job_id}/status")
-
- # Should return 404, not cause database error
- assert response.status_code == 404
-
-
-class TestRecoveryIntegration:
- """Test recovery and resilience scenarios"""
-
- @pytest.mark.asyncio
- async def test_service_restart_recovery(
- self,
- test_db_session,
- training_service,
- training_job_in_db
- ):
- """Test service recovery after restart"""
-
- # Simulate service restart by creating new service instance
- new_training_service = training_service.__class__()
-
- # Should be able to access existing jobs
- existing_job = await new_training_service.get_job_status(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
-
- assert existing_job is not None
- assert existing_job.job_id == training_job_in_db.job_id
-
- @pytest.mark.asyncio
- async def test_partial_failure_recovery(
- self,
- training_service,
- test_db_session
- ):
- """Test recovery from partial failures"""
-
- # Create job that might fail partway through
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="partial-failure-job",
- config={"simulate_failure": True}
- )
-
- # Simulate partial progress
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="running",
- progress=50,
- current_step="Halfway through training"
- )
-
- # Simulate failure
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="failed",
- progress=50,
- current_step="Training failed",
- error_message="Simulated failure"
- )
-
- # Verify failure was recorded
- failed_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job.job_id,
- tenant_id="test-tenant"
- )
-
- assert failed_job.status == "failed"
- assert failed_job.error_message == "Simulated failure"
- assert failed_job.progress == 50
-
-
-class TestComplianceIntegration:
- """Test compliance and audit requirements"""
-
- @pytest.mark.asyncio
- async def test_audit_trail_creation(
- self,
- training_service,
- test_db_session
- ):
- """Test that audit trail is properly created"""
-
- # Create and update job
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="audit-test-job",
- config={"audit_test": True}
- )
-
- # Multiple status updates
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="running",
- progress=25,
- current_step="Started processing"
- )
-
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="running",
- progress=75,
- current_step="Almost complete"
- )
-
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="completed",
- progress=100,
- current_step="Completed successfully"
- )
-
- # Verify audit trail
- logs = await training_service.get_training_logs(
- db=test_db_session,
- job_id=job.job_id,
- tenant_id="test-tenant"
- )
-
- assert logs is not None
- assert len(logs) > 0
-
- # Check final status
- final_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job.job_id,
- tenant_id="test-tenant"
- )
-
- assert final_job.status == "completed"
- assert final_job.progress == 100
-
- @pytest.mark.asyncio
- async def test_data_retention_compliance(
- self,
- training_service,
- test_db_session
- ):
- """Test data retention and cleanup compliance"""
-
- from datetime import datetime, timedelta
-
- # Create old job (simulate old data)
- old_job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="old-job",
- config={"created_long_ago": True}
- )
-
- # Manually set old timestamp
- from sqlalchemy import update
- from app.models.training import ModelTrainingLog
-
- old_timestamp = datetime.now() - timedelta(days=400)
- await test_db_session.execute(
- update(ModelTrainingLog)
- .where(ModelTrainingLog.job_id == old_job.job_id)
- .values(start_time=old_timestamp, created_at=old_timestamp)
- )
- await test_db_session.commit()
-
- # Verify old job exists
- retrieved_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=old_job.job_id,
- tenant_id="test-tenant"
- )
-
- assert retrieved_job is not None
- # In a real implementation, there would be cleanup procedures
-
- @pytest.mark.asyncio
- async def test_gdpr_compliance_features(
- self,
- training_service,
- test_db_session
- ):
- """Test GDPR compliance features"""
-
- # Create job with tenant data
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="gdpr-test-tenant",
- job_id="gdpr-test-job",
- config={"gdpr_test": True}
- )
-
- # Verify job is associated with tenant
- assert job.tenant_id == "gdpr-test-tenant"
-
- # Test data access (right to access)
- tenant_jobs = await training_service.list_training_jobs(
- db=test_db_session,
- tenant_id="gdpr-test-tenant"
- )
-
- assert len(tenant_jobs) >= 1
- assert any(job.job_id == "gdpr-test-job" for job in tenant_jobs)
-
-
-@pytest.mark.slow
-class TestLongRunningIntegration:
- """Test long-running integration scenarios (marked as slow)"""
-
- @pytest.mark.asyncio
- async def test_extended_training_simulation(
- self,
- training_service,
- test_db_session,
- mock_messaging
- ):
- """Test extended training process simulation"""
-
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="long-running-job",
- config={"extended_test": True}
- )
-
- # Simulate progress over time
- progress_steps = [
- (10, "Initializing"),
- (25, "Loading data"),
- (50, "Training models"),
- (75, "Validating results"),
- (90, "Storing models"),
- (100, "Completed")
- ]
-
- for progress, step in progress_steps:
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="running" if progress < 100 else "completed",
- progress=progress,
- current_step=step
- )
-
- # Small delay to simulate real progression
- await asyncio.sleep(0.01)
-
- # Verify final state
- final_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job.job_id,
- tenant_id="test-tenant"
- )
-
- assert final_job.status == "completed"
- assert final_job.progress == 100
- assert final_job.current_step == "Completed"
-
- @pytest.mark.asyncio
- async def test_memory_usage_stability(
- self,
- training_service,
- test_db_session
- ):
- """Test memory usage stability over many operations"""
-
- # Create many jobs to test memory stability
- for i in range(50):
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id=f"tenant-{i % 5}", # 5 different tenants
- job_id=f"memory-test-job-{i}",
- config={"iteration": i}
- )
-
- # Update status
- await training_service._update_job_status(
- db=test_db_session,
- job_id=job.job_id,
- status="completed",
- progress=100,
- current_step="Completed"
- )
-
- # List jobs for each tenant
- for tenant_i in range(5):
- tenant_id = f"tenant-{tenant_i}"
- jobs = await training_service.list_training_jobs(
- db=test_db_session,
- tenant_id=tenant_id,
- limit=20
- )
-
- # Should have 10 jobs per tenant (50 total / 5 tenants)
- assert len(jobs) == 10
-
-
-class TestBackwardCompatibility:
- """Test backward compatibility with existing systems"""
-
- @pytest.mark.asyncio
- async def test_legacy_config_handling(
- self,
- training_service,
- test_db_session
- ):
- """Test handling of legacy configuration formats"""
-
- # Test with old-style configuration
- legacy_config = {
- "weather_enabled": True, # Old key
- "traffic_enabled": True, # Old key
- "minimum_samples": 30, # Old key
- "prophet_config": { # Old nested structure
- "seasonality": "additive"
- }
- }
-
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="legacy-config-job",
- config=legacy_config
- )
-
- assert job.config == legacy_config
- assert job.job_id == "legacy-config-job"
-
- @pytest.mark.asyncio
- async def test_api_version_compatibility(
- self,
- test_client: AsyncClient
- ):
- """Test API version compatibility"""
-
- # Test with minimal request (old API style)
- minimal_request = {
- "include_weather": True
- }
-
- with patch('shared.auth.decorators.get_current_tenant_id_dep', return_value="test-tenant"):
- response = await test_client.post("/training/jobs", json=minimal_request)
-
- # Should work with defaults for missing fields
- assert response.status_code == 200
- data = response.json()
- assert "job_id" in data
-
-
-# Utility functions for integration tests
-async def wait_for_condition(condition_func, timeout=5.0, interval=0.1):
- """Wait for a condition to become true"""
- import time
- start_time = time.time()
-
- while time.time() - start_time < timeout:
- if await condition_func():
- return True
- await asyncio.sleep(interval)
-
- return False
-
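-# Example usage (illustrative sketch, not part of the original suite): poll a status
-# endpoint until a job reaches a terminal state. `client` and `job_id` stand in for
-# whatever fixtures the calling test already has.
-#
-# async def _job_finished():
-# response = await client.get(f"/training/jobs/{job_id}/status")
-# return response.json()["status"] in ("completed", "failed", "cancelled")
-#
-# assert await wait_for_condition(_job_finished, timeout=10.0)
-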
-
-def assert_job_progression(job_updates):
- """Assert that job updates show proper progression"""
- assert len(job_updates) > 0
-
- # Check progress is non-decreasing
- for i in range(1, len(job_updates)):
- assert job_updates[i]["progress"] >= job_updates[i-1]["progress"]
-
- # Check final status
- final_update = job_updates[-1]
- assert final_update["status"] in ["completed", "failed", "cancelled"]
-
-
-def assert_valid_job_structure(job_data):
- """Assert job data has valid structure"""
- required_fields = ["job_id", "status", "tenant_id"]
- for field in required_fields:
- assert field in job_data
-
- assert isinstance(job_data["progress"], int)
- assert 0 <= job_data["progress"] <= 100
- assert job_data["status"] in ["pending", "running", "completed", "failed", "cancelled"]
\ No newline at end of file
diff --git a/services/training/tests/test_messaging.py b/services/training/tests/test_messaging.py
deleted file mode 100644
index 09031a12..00000000
--- a/services/training/tests/test_messaging.py
+++ /dev/null
@@ -1,467 +0,0 @@
-# services/training/tests/test_messaging.py
-"""
-Tests for training service messaging functionality
-"""
-
-import pytest
-from unittest.mock import AsyncMock, Mock, patch
-import json
-
-from app.services import messaging
-
-
-class TestTrainingMessaging:
- """Test training service messaging functions"""
-
- @pytest.fixture
- def mock_publisher(self):
- """Mock the RabbitMQ publisher"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
- mock_pub.connect = AsyncMock(return_value=True)
- mock_pub.disconnect = AsyncMock(return_value=None)
- yield mock_pub
-
- @pytest.mark.asyncio
- async def test_setup_messaging_success(self, mock_publisher):
- """Test successful messaging setup"""
- await messaging.setup_messaging()
-
- mock_publisher.connect.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_setup_messaging_failure(self, mock_publisher):
- """Test messaging setup failure"""
- mock_publisher.connect.return_value = False
-
- await messaging.setup_messaging()
-
- mock_publisher.connect.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_cleanup_messaging(self, mock_publisher):
- """Test messaging cleanup"""
- await messaging.cleanup_messaging()
-
- mock_publisher.disconnect.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_publish_job_started(self, mock_publisher):
- """Test publishing job started event"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
- config = {"include_weather": True}
-
- result = await messaging.publish_job_started(job_id, tenant_id, config)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- # Check call arguments
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["exchange_name"] == "training.events"
- assert call_args[1]["routing_key"] == "training.started"
-
- event_data = call_args[1]["event_data"]
- assert event_data["service_name"] == "training-service"
- assert event_data["data"]["job_id"] == job_id
- assert event_data["data"]["tenant_id"] == tenant_id
- assert event_data["data"]["config"] == config
-
- @pytest.mark.asyncio
- async def test_publish_job_progress(self, mock_publisher):
- """Test publishing job progress event"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
- progress = 50
- step = "Training models"
-
- result = await messaging.publish_job_progress(job_id, tenant_id, progress, step)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.progress"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["progress"] == progress
- assert event_data["data"]["current_step"] == step
-
- @pytest.mark.asyncio
- async def test_publish_job_completed(self, mock_publisher):
- """Test publishing job completed event"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
- results = {
- "products_trained": 3,
- "summary": {"success_rate": 100.0}
- }
-
- result = await messaging.publish_job_completed(job_id, tenant_id, results)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.completed"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["results"] == results
- assert event_data["data"]["models_trained"] == 3
- assert event_data["data"]["success_rate"] == 100.0
-
- @pytest.mark.asyncio
- async def test_publish_job_failed(self, mock_publisher):
- """Test publishing job failed event"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
- error = "Data service unavailable"
-
- result = await messaging.publish_job_failed(job_id, tenant_id, error)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.failed"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["error"] == error
-
- @pytest.mark.asyncio
- async def test_publish_job_cancelled(self, mock_publisher):
- """Test publishing job cancelled event"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
-
- result = await messaging.publish_job_cancelled(job_id, tenant_id)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.cancelled"
-
- @pytest.mark.asyncio
- async def test_publish_product_training_started(self, mock_publisher):
- """Test publishing product training started event"""
- job_id = "test-product-job-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
-
- result = await messaging.publish_product_training_started(job_id, tenant_id, product_name)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.product.started"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["product_name"] == product_name
-
- @pytest.mark.asyncio
- async def test_publish_product_training_completed(self, mock_publisher):
- """Test publishing product training completed event"""
- job_id = "test-product-job-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- model_id = "test-model-123"
-
- result = await messaging.publish_product_training_completed(
- job_id, tenant_id, product_name, model_id
- )
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.product.completed"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["model_id"] == model_id
- assert event_data["data"]["product_name"] == product_name
-
- @pytest.mark.asyncio
- async def test_publish_model_trained(self, mock_publisher):
- """Test publishing model trained event"""
- model_id = "test-model-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- metrics = {"mae": 5.2, "rmse": 7.8, "mape": 12.5}
-
- result = await messaging.publish_model_trained(model_id, tenant_id, product_name, metrics)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.model.trained"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["training_metrics"] == metrics
-
- @pytest.mark.asyncio
- async def test_publish_model_updated(self, mock_publisher):
- """Test publishing model updated event"""
- model_id = "test-model-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- version = 2
-
- result = await messaging.publish_model_updated(model_id, tenant_id, product_name, version)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.model.updated"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["version"] == version
-
- @pytest.mark.asyncio
- async def test_publish_model_validated(self, mock_publisher):
- """Test publishing model validated event"""
- model_id = "test-model-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- validation_results = {"is_valid": True, "accuracy": 0.95}
-
- result = await messaging.publish_model_validated(
- model_id, tenant_id, product_name, validation_results
- )
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.model.validated"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["validation_results"] == validation_results
-
- @pytest.mark.asyncio
- async def test_publish_model_saved(self, mock_publisher):
- """Test publishing model saved event"""
- model_id = "test-model-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- model_path = "/models/test-model-123.pkl"
-
- result = await messaging.publish_model_saved(model_id, tenant_id, product_name, model_path)
-
- assert result is True
- mock_publisher.publish_event.assert_called_once()
-
- call_args = mock_publisher.publish_event.call_args
- assert call_args[1]["routing_key"] == "training.model.saved"
-
- event_data = call_args[1]["event_data"]
- assert event_data["data"]["model_path"] == model_path
-
-
-class TestMessagingErrorHandling:
- """Test error handling in messaging"""
-
- @pytest.fixture
- def failing_publisher(self):
- """Mock publisher that fails"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=False)
- mock_pub.connect = AsyncMock(return_value=False)
- yield mock_pub
-
- @pytest.mark.asyncio
- async def test_publish_event_failure(self, failing_publisher):
- """Test handling of publish event failure"""
- result = await messaging.publish_job_started("job-123", "tenant-123", {})
-
- assert result is False
- failing_publisher.publish_event.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_setup_messaging_connection_failure(self, failing_publisher):
- """Test setup with connection failure"""
- await messaging.setup_messaging()
-
- failing_publisher.connect.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_publish_with_exception(self):
- """Test publishing with exception"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event.side_effect = Exception("Connection lost")
-
- result = await messaging.publish_job_started("job-123", "tenant-123", {})
-
- assert result is False
-
-
-class TestMessagingIntegration:
- """Test messaging integration with shared components"""
-
- @pytest.mark.asyncio
- async def test_event_structure_consistency(self):
- """Test that events follow consistent structure"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Test different event types
- await messaging.publish_job_started("job-123", "tenant-123", {})
- await messaging.publish_job_completed("job-123", "tenant-123", {})
- await messaging.publish_model_trained("model-123", "tenant-123", "Pan", {})
-
- # Verify all calls have consistent structure
- assert mock_pub.publish_event.call_count == 3
-
- for call in mock_pub.publish_event.call_args_list:
- event_data = call[1]["event_data"]
-
- # All events should have these fields
- assert "service_name" in event_data
- assert "event_type" in event_data
- assert "data" in event_data
- assert event_data["service_name"] == "training-service"
-
- @pytest.mark.asyncio
- async def test_shared_event_classes_usage(self):
- """Test that shared event classes are used properly"""
- with patch('shared.messaging.events.TrainingStartedEvent') as mock_event_class:
- mock_event = Mock()
- mock_event.to_dict.return_value = {
- "service_name": "training-service",
- "event_type": "training.started",
- "data": {"job_id": "test-job"}
- }
- mock_event_class.return_value = mock_event
-
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- await messaging.publish_job_started("test-job", "test-tenant", {})
-
- # Verify shared event class was used
- mock_event_class.assert_called_once()
- mock_event.to_dict.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_routing_key_consistency(self):
- """Test that routing keys follow consistent patterns"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Test various event types
- events_and_keys = [
- (messaging.publish_job_started, "training.started"),
- (messaging.publish_job_progress, "training.progress"),
- (messaging.publish_job_completed, "training.completed"),
- (messaging.publish_job_failed, "training.failed"),
- (messaging.publish_job_cancelled, "training.cancelled"),
- (messaging.publish_product_training_started, "training.product.started"),
- (messaging.publish_product_training_completed, "training.product.completed"),
- (messaging.publish_model_trained, "training.model.trained"),
- (messaging.publish_model_updated, "training.model.updated"),
- (messaging.publish_model_validated, "training.model.validated"),
- (messaging.publish_model_saved, "training.model.saved")
- ]
-
- for event_func, expected_key in events_and_keys:
- mock_pub.reset_mock()
-
- # Call event function with appropriate parameters
- if "progress" in expected_key:
- await event_func("job-123", "tenant-123", 50, "step")
- elif "model" in expected_key and "trained" in expected_key:
- await event_func("model-123", "tenant-123", "product", {})
- elif "model" in expected_key and "updated" in expected_key:
- await event_func("model-123", "tenant-123", "product", 1)
- elif "model" in expected_key and "validated" in expected_key:
- await event_func("model-123", "tenant-123", "product", {})
- elif "model" in expected_key and "saved" in expected_key:
- await event_func("model-123", "tenant-123", "product", "/path")
- elif "product" in expected_key and "completed" in expected_key:
- await event_func("job-123", "tenant-123", "product", "model-123")
- elif "product" in expected_key:
- await event_func("job-123", "tenant-123", "product")
- elif "failed" in expected_key:
- await event_func("job-123", "tenant-123", "error")
- elif "cancelled" in expected_key:
- await event_func("job-123", "tenant-123")
- else:
- await event_func("job-123", "tenant-123", {})
-
- # Verify routing key
- call_args = mock_pub.publish_event.call_args
- assert call_args[1]["routing_key"] == expected_key
-
- @pytest.mark.asyncio
- async def test_exchange_consistency(self):
- """Test that all events use the same exchange"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Test multiple events
- await messaging.publish_job_started("job-123", "tenant-123", {})
- await messaging.publish_model_trained("model-123", "tenant-123", "product", {})
- await messaging.publish_product_training_started("job-123", "tenant-123", "product")
-
- # Verify all use same exchange
- for call in mock_pub.publish_event.call_args_list:
- assert call[1]["exchange_name"] == "training.events"
-
-
-class TestMessagingPerformance:
- """Test messaging performance and reliability"""
-
- @pytest.mark.asyncio
- async def test_concurrent_publishing(self):
- """Test concurrent event publishing"""
- import asyncio
-
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Create multiple concurrent publishing tasks
- tasks = []
- for i in range(10):
- task = messaging.publish_job_progress(f"job-{i}", "tenant-123", i * 10, f"step-{i}")
- tasks.append(task)
-
- # Execute all tasks concurrently
- results = await asyncio.gather(*tasks)
-
- # Verify all succeeded
- assert all(results)
- assert mock_pub.publish_event.call_count == 10
-
- @pytest.mark.asyncio
- async def test_large_event_data(self):
- """Test publishing events with large data payloads"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Create large config data
- large_config = {
- "products": [f"Product-{i}" for i in range(1000)],
- "features": [f"feature-{i}" for i in range(100)],
- "hyperparameters": {f"param-{i}": i for i in range(50)}
- }
-
- result = await messaging.publish_job_started("job-123", "tenant-123", large_config)
-
- assert result is True
- mock_pub.publish_event.assert_called_once()
-
- @pytest.mark.asyncio
- async def test_rapid_sequential_publishing(self):
- """Test rapid sequential event publishing"""
- with patch('app.services.messaging.training_publisher') as mock_pub:
- mock_pub.publish_event = AsyncMock(return_value=True)
-
- # Publish many events in sequence
- for i in range(100):
- await messaging.publish_job_progress("job-123", "tenant-123", i, f"step-{i}")
-
- assert mock_pub.publish_event.call_count == 100
\ No newline at end of file
diff --git a/services/training/tests/test_ml_pipeline_integration.py b/services/training/tests/test_ml_pipeline_integration.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/services/training/tests/test_performance.py b/services/training/tests/test_performance.py
deleted file mode 100644
index a7ba60bf..00000000
--- a/services/training/tests/test_performance.py
+++ /dev/null
@@ -1,630 +0,0 @@
-# ================================================================
-# services/training/tests/test_performance.py
-# ================================================================
-"""
-Performance and Load Testing for Training Service
-Tests training performance with real-world data volumes
-"""
-
-import pytest
-import asyncio
-import pandas as pd
-import numpy as np
-import time
-from datetime import datetime, timedelta
-from concurrent.futures import ThreadPoolExecutor
-import psutil
-import gc
-from typing import List, Dict, Any
-import logging
-from unittest.mock import patch
-
-from app.ml.trainer import BakeryMLTrainer
-from app.ml.data_processor import BakeryDataProcessor
-from app.services.training_service import TrainingService
-
-
-class TestTrainingPerformance:
- """Performance tests for training service components"""
-
- @pytest.fixture
- def large_sales_dataset(self):
- """Generate large dataset for performance testing (2 years of data)"""
- start_date = datetime(2022, 1, 1)
- end_date = datetime(2024, 1, 1)
-
- date_range = pd.date_range(start=start_date, end=end_date, freq='D')
- products = [
- "Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
- "Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras",
- "Donuts", "Berlinas", "Napolitanas", "Ensaimadas"
- ]
-
- data = []
- for date in date_range:
- for product in products:
- # Realistic sales simulation
- base_quantity = np.random.randint(5, 150)
-
- # Seasonal patterns
- if date.month in [12, 1]: # Winter/Holiday season
- base_quantity *= 1.4
- elif date.month in [6, 7, 8]: # Summer
- base_quantity *= 0.8
-
- # Weekly patterns
- if date.weekday() >= 5: # Weekends
- base_quantity *= 1.2
- elif date.weekday() == 0: # Monday
- base_quantity *= 0.7
-
- # Add noise
- quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.1)))
-
- data.append({
- "date": date.strftime("%Y-%m-%d"),
- "product": product,
- "quantity": quantity,
- "revenue": round(quantity * np.random.uniform(1.5, 8.0), 2),
- "temperature": round(15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi) + np.random.normal(0, 3), 1),
- "precipitation": max(0, np.random.exponential(0.8)),
- "is_weekend": date.weekday() >= 5,
- "is_holiday": self._is_spanish_holiday(date)
- })
-
- return pd.DataFrame(data)
-
- def _is_spanish_holiday(self, date: datetime) -> bool:
- """Check if date is a Spanish holiday"""
- holidays = [
- (1, 1), # New Year
- (1, 6), # Epiphany
- (5, 1), # Labor Day
- (8, 15), # Assumption
- (10, 12), # National Day
- (11, 1), # All Saints
- (12, 6), # Constitution Day
- (12, 8), # Immaculate Conception
- (12, 25), # Christmas
- ]
- return (date.month, date.day) in holidays
-
- @pytest.mark.asyncio
- async def test_single_product_training_performance(self, large_sales_dataset):
- """Test performance of single product training with large dataset"""
-
- trainer = BakeryMLTrainer()
- product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()
-
- # Measure memory before training
- process = psutil.Process()
- memory_before = process.memory_info().rss / 1024 / 1024 # MB
-
- start_time = time.time()
-
- result = await trainer.train_single_product(
- tenant_id="perf_test_tenant",
- product_name="Pan Integral",
- sales_data=product_data,
- config={
- "include_weather": True,
- "include_traffic": False, # Skip traffic for performance
- "seasonality_mode": "additive"
- }
- )
-
- end_time = time.time()
- training_duration = end_time - start_time
-
- # Measure memory after training
- memory_after = process.memory_info().rss / 1024 / 1024 # MB
- memory_used = memory_after - memory_before
-
- # Performance assertions
- assert training_duration < 120, f"Training took too long: {training_duration:.2f}s"
- assert memory_used < 500, f"Memory usage too high: {memory_used:.2f}MB"
- assert result['status'] == 'completed'
-
- # Quality assertions
- metrics = result['metrics']
- assert metrics['mape'] < 50, f"MAPE too high: {metrics['mape']:.2f}%"
-
- print(f"Performance Results:")
- print(f" Training Duration: {training_duration:.2f}s")
- print(f" Memory Used: {memory_used:.2f}MB")
- print(f" Data Points: {len(product_data)}")
- print(f" MAPE: {metrics['mape']:.2f}%")
- print(f" RMSE: {metrics['rmse']:.2f}")
-
- @pytest.mark.asyncio
- async def test_concurrent_training_performance(self, large_sales_dataset):
- """Test performance of concurrent training jobs"""
-
- trainer = BakeryMLTrainer()
- products = ["Pan Integral", "Croissant", "Magdalenas"]
-
- async def train_product(product_name: str):
- """Train a single product"""
- product_data = large_sales_dataset[large_sales_dataset['product'] == product_name].copy()
-
- start_time = time.time()
- result = await trainer.train_single_product(
- tenant_id=f"concurrent_test_{product_name.replace(' ', '_').lower()}",
- product_name=product_name,
- sales_data=product_data,
- config={"include_weather": True, "include_traffic": False}
- )
- end_time = time.time()
-
- return {
- 'product': product_name,
- 'duration': end_time - start_time,
- 'status': result['status'],
- 'metrics': result.get('metrics', {})
- }
-
- # Run concurrent training
- start_time = time.time()
- tasks = [train_product(product) for product in products]
- results = await asyncio.gather(*tasks)
- total_time = time.time() - start_time
-
- # Verify all trainings completed
- for result in results:
- assert result['status'] == 'completed'
- assert result['duration'] < 120 # Individual training time
-
- # Concurrent execution should be faster than sequential
- sequential_time_estimate = sum(r['duration'] for r in results)
- efficiency = sequential_time_estimate / total_time
-
- assert efficiency > 1.5, f"Concurrency efficiency too low: {efficiency:.2f}x"
-
- print(f"Concurrent Training Results:")
- print(f" Total Time: {total_time:.2f}s")
- print(f" Sequential Estimate: {sequential_time_estimate:.2f}s")
- print(f" Efficiency: {efficiency:.2f}x")
-
- for result in results:
- print(f" {result['product']}: {result['duration']:.2f}s, MAPE: {result['metrics'].get('mape', 'N/A'):.2f}%")
-
- @pytest.mark.asyncio
- async def test_data_processing_scalability(self, large_sales_dataset):
- """Test data processing performance with increasing data sizes"""
-
- data_processor = BakeryDataProcessor()
-
- # Test with different data sizes
- data_sizes = [1000, 5000, 10000, 20000, len(large_sales_dataset)]
- performance_results = []
-
- for size in data_sizes:
- # Take a sample of the specified size
- sample_data = large_sales_dataset.head(size).copy()
-
- start_time = time.time()
-
- # Process the data
- processed_data = await data_processor.prepare_training_data(
- sales_data=sample_data,
- include_weather=True,
- include_traffic=True,
- tenant_id="scalability_test",
- product_name="Pan Integral"
- )
-
- processing_time = time.time() - start_time
-
- performance_results.append({
- 'data_size': size,
- 'processing_time': processing_time,
- 'processed_rows': len(processed_data),
- 'throughput': size / processing_time if processing_time > 0 else 0
- })
-
- # Verify linear or sub-linear scaling
- for i in range(1, len(performance_results)):
- prev_result = performance_results[i-1]
- curr_result = performance_results[i]
-
- size_ratio = curr_result['data_size'] / prev_result['data_size']
- time_ratio = curr_result['processing_time'] / prev_result['processing_time']
-
- # Processing time should scale better than linearly
- assert time_ratio < size_ratio * 1.5, f"Poor scaling at size {curr_result['data_size']}"
-
- print("Data Processing Scalability Results:")
- for result in performance_results:
- print(f" Size: {result['data_size']:,} rows, Time: {result['processing_time']:.2f}s, "
- f"Throughput: {result['throughput']:.0f} rows/s")
-
- @pytest.mark.asyncio
- async def test_memory_usage_optimization(self, large_sales_dataset):
- """Test memory usage optimization during training"""
-
- trainer = BakeryMLTrainer()
- process = psutil.Process()
-
- # Baseline memory
- gc.collect() # Force garbage collection
- baseline_memory = process.memory_info().rss / 1024 / 1024 # MB
-
- memory_snapshots = [{'stage': 'baseline', 'memory_mb': baseline_memory}]
-
- # Load data
- product_data = large_sales_dataset[large_sales_dataset['product'] == 'Pan Integral'].copy()
- current_memory = process.memory_info().rss / 1024 / 1024
- memory_snapshots.append({'stage': 'data_loaded', 'memory_mb': current_memory})
-
- # Train model
- result = await trainer.train_single_product(
- tenant_id="memory_test_tenant",
- product_name="Pan Integral",
- sales_data=product_data,
- config={"include_weather": True, "include_traffic": True}
- )
-
- current_memory = process.memory_info().rss / 1024 / 1024
- memory_snapshots.append({'stage': 'model_trained', 'memory_mb': current_memory})
-
- # Cleanup
- del product_data
- del result
- gc.collect()
-
- final_memory = process.memory_info().rss / 1024 / 1024
- memory_snapshots.append({'stage': 'cleanup', 'memory_mb': final_memory})
-
- # Memory assertions
- peak_memory = max(snapshot['memory_mb'] for snapshot in memory_snapshots)
- memory_increase = peak_memory - baseline_memory
- memory_after_cleanup = final_memory - baseline_memory
-
- assert memory_increase < 800, f"Peak memory increase too high: {memory_increase:.2f}MB"
- assert memory_after_cleanup < 100, f"Memory not properly cleaned up: {memory_after_cleanup:.2f}MB"
-
- print("Memory Usage Analysis:")
- for snapshot in memory_snapshots:
- print(f" {snapshot['stage']}: {snapshot['memory_mb']:.2f}MB")
- print(f" Peak increase: {memory_increase:.2f}MB")
- print(f" After cleanup: {memory_after_cleanup:.2f}MB")
-
- @pytest.mark.asyncio
- async def test_training_service_throughput(self, large_sales_dataset):
- """Test training service throughput with multiple requests"""
-
- training_service = TrainingService()
-
- # Simulate multiple training requests
- num_requests = 5
- products = ["Pan Integral", "Croissant", "Magdalenas", "Empanadas", "Tarta Chocolate"]
-
- async def execute_training_request(request_id: int, product: str):
- """Execute a single training request"""
- product_data = large_sales_dataset[large_sales_dataset['product'] == product].copy()
-
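- # patch.object substitutes an AsyncMock automatically because _fetch_sales_data is
- # async (Python 3.8+), so the awaited call returns product_data directly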
- with patch.object(training_service, '_fetch_sales_data', return_value=product_data):
- start_time = time.time()
-
- result = await training_service.execute_training_job(
- db=None, # Mock DB session
- tenant_id=f"throughput_test_tenant_{request_id}",
- job_id=f"job_{request_id}_{product.replace(' ', '_').lower()}",
- request={
- 'products': [product],
- 'include_weather': True,
- 'include_traffic': False,
- 'config': {'seasonality_mode': 'additive'}
- }
- )
-
- duration = time.time() - start_time
- return {
- 'request_id': request_id,
- 'product': product,
- 'duration': duration,
- 'status': result.get('status', 'unknown'),
- 'models_trained': len(result.get('models_trained', []))
- }
-
- # Execute requests concurrently
- start_time = time.time()
- tasks = [
- execute_training_request(i, products[i % len(products)])
- for i in range(num_requests)
- ]
- results = await asyncio.gather(*tasks)
- total_time = time.time() - start_time
-
- # Calculate throughput metrics
- successful_requests = sum(1 for r in results if r['status'] == 'completed')
- throughput = successful_requests / total_time # requests per second
-
- # Performance assertions
- assert successful_requests >= num_requests * 0.8, "Too many failed requests"
- assert throughput >= 0.1, f"Throughput too low: {throughput:.3f} req/s"
- assert total_time < 300, f"Total time too long: {total_time:.2f}s"
-
- print(f"Training Service Throughput Results:")
- print(f" Total Requests: {num_requests}")
- print(f" Successful: {successful_requests}")
- print(f" Total Time: {total_time:.2f}s")
- print(f" Throughput: {throughput:.3f} req/s")
- print(f" Average Request Time: {total_time/num_requests:.2f}s")
-
- @pytest.mark.asyncio
- async def test_large_dataset_edge_cases(self, large_sales_dataset):
- """Test handling of edge cases with large datasets"""
-
- data_processor = BakeryDataProcessor()
-
- # Test 1: Dataset with many missing values
- corrupted_data = large_sales_dataset.copy()
- # Introduce 30% missing values randomly
- mask = np.random.random(len(corrupted_data)) < 0.3
- corrupted_data.loc[mask, 'quantity'] = np.nan
-
- start_time = time.time()
- result = await data_processor.validate_data_quality(corrupted_data)
- validation_time = time.time() - start_time
-
- assert validation_time < 10, f"Validation too slow: {validation_time:.2f}s"
- assert result['is_valid'] is False
- assert 'high_missing_data' in result['issues']
-
- # Test 2: Dataset with extreme outliers
- outlier_data = large_sales_dataset.copy()
- # Add extreme outliers (100x normal values)
- outlier_indices = np.random.choice(len(outlier_data), size=int(len(outlier_data) * 0.01), replace=False)
- outlier_data.loc[outlier_indices, 'quantity'] *= 100
-
- start_time = time.time()
- cleaned_data = await data_processor.clean_outliers(outlier_data)
- cleaning_time = time.time() - start_time
-
- assert cleaning_time < 15, f"Outlier cleaning too slow: {cleaning_time:.2f}s"
- assert len(cleaned_data) > len(outlier_data) * 0.95 # Should retain most data
-
- # Test 3: Very sparse data (many products with few sales)
- sparse_data = large_sales_dataset.copy()
- # Keep only 10% of data for each product randomly
- sparse_data = sparse_data.groupby('product').apply(
- lambda x: x.sample(n=max(1, int(len(x) * 0.1)))
- ).reset_index(drop=True)
-
- start_time = time.time()
- validation_result = await data_processor.validate_data_quality(sparse_data)
- sparse_validation_time = time.time() - start_time
-
- assert sparse_validation_time < 5, f"Sparse data validation too slow: {sparse_validation_time:.2f}s"
-
- print("Edge Case Performance Results:")
- print(f" Corrupted data validation: {validation_time:.2f}s")
- print(f" Outlier cleaning: {cleaning_time:.2f}s")
- print(f" Sparse data validation: {sparse_validation_time:.2f}s")
-
-
-class TestTrainingServiceLoad:
- """Load testing for training service under stress"""
-
- @pytest.mark.asyncio
- async def test_sustained_load_training(self, large_sales_dataset):
- """Test training service under sustained load"""
-
- trainer = BakeryMLTrainer()
-
- # Define load test parameters
- duration_minutes = 2 # Run for 2 minutes
- requests_per_minute = 3
-
- products = ["Pan Integral", "Croissant", "Magdalenas"]
-
- async def sustained_training_worker(worker_id: int, duration: float):
- """Worker that continuously submits training requests"""
- start_time = time.time()
- completed_requests = 0
- failed_requests = 0
-
- while time.time() - start_time < duration:
- try:
- product = products[completed_requests % len(products)]
- product_data = large_sales_dataset[
- large_sales_dataset['product'] == product
- ].copy()
-
- result = await trainer.train_single_product(
- tenant_id=f"load_test_worker_{worker_id}",
- product_name=product,
- sales_data=product_data,
- config={"include_weather": False, "include_traffic": False} # Minimal config for speed
- )
-
- if result['status'] == 'completed':
- completed_requests += 1
- else:
- failed_requests += 1
-
- except Exception as e:
- failed_requests += 1
- logging.error(f"Training request failed: {e}")
-
- # Wait before next request
- await asyncio.sleep(60 / requests_per_minute)
-
- return {
- 'worker_id': worker_id,
- 'completed': completed_requests,
- 'failed': failed_requests,
- 'duration': time.time() - start_time
- }
-
- # Start multiple workers
- num_workers = 2
- duration_seconds = duration_minutes * 60
-
- start_time = time.time()
- tasks = [
- sustained_training_worker(i, duration_seconds)
- for i in range(num_workers)
- ]
- results = await asyncio.gather(*tasks)
- total_time = time.time() - start_time
-
- # Analyze results
- total_completed = sum(r['completed'] for r in results)
- total_failed = sum(r['failed'] for r in results)
- success_rate = total_completed / (total_completed + total_failed) if (total_completed + total_failed) > 0 else 0
-
- # Performance assertions
- assert success_rate >= 0.8, f"Success rate too low: {success_rate:.2%}"
- assert total_completed >= duration_minutes * requests_per_minute * num_workers * 0.7, "Throughput too low"
-
- print(f"Sustained Load Test Results:")
- print(f" Duration: {total_time:.2f}s")
- print(f" Workers: {num_workers}")
- print(f" Completed Requests: {total_completed}")
- print(f" Failed Requests: {total_failed}")
- print(f" Success Rate: {success_rate:.2%}")
- print(f" Average Throughput: {total_completed/total_time:.2f} req/s")
-
- @pytest.mark.asyncio
- async def test_resource_exhaustion_recovery(self, large_sales_dataset):
- """Test service recovery from resource exhaustion"""
-
- trainer = BakeryMLTrainer()
-
- # Simulate resource exhaustion by running many concurrent requests
- num_concurrent = 10 # High concurrency to stress the system
-
- async def resource_intensive_task(task_id: int):
- """Task designed to consume resources"""
- try:
- # Use all products to increase memory usage
- all_products_data = large_sales_dataset.copy()
-
- result = await trainer.train_tenant_models(
- tenant_id=f"resource_test_{task_id}",
- sales_data=all_products_data,
- config={
- "train_all_products": True,
- "include_weather": True,
- "include_traffic": True
- }
- )
-
- return {'task_id': task_id, 'status': 'completed', 'error': None}
-
- except Exception as e:
- return {'task_id': task_id, 'status': 'failed', 'error': str(e)}
-
- # Launch all tasks simultaneously
- start_time = time.time()
- tasks = [resource_intensive_task(i) for i in range(num_concurrent)]
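- # return_exceptions=True keeps one crashed task from cancelling the rest; raised
- # exceptions come back as values and are tallied below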
- results = await asyncio.gather(*tasks, return_exceptions=True)
- duration = time.time() - start_time
-
- # Analyze results
- completed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'completed')
- failed = sum(1 for r in results if isinstance(r, dict) and r['status'] == 'failed')
- exceptions = sum(1 for r in results if isinstance(r, Exception))
-
- # The system should handle some failures gracefully
- # but should complete at least some requests
- total_processed = completed + failed + exceptions
- processing_rate = total_processed / num_concurrent
-
- assert processing_rate >= 0.5, f"Too many requests not processed: {processing_rate:.2%}"
- assert duration < 600, f"Recovery took too long: {duration:.2f}s" # 10 minutes max
-
- print(f"Resource Exhaustion Test Results:")
- print(f" Concurrent Requests: {num_concurrent}")
- print(f" Completed: {completed}")
- print(f" Failed: {failed}")
- print(f" Exceptions: {exceptions}")
- print(f" Duration: {duration:.2f}s")
- print(f" Processing Rate: {processing_rate:.2%}")
-
-
-# ================================================================
-# BENCHMARK UTILITIES
-# ================================================================
-
-class PerformanceBenchmark:
- """Utility class for performance benchmarking"""
-
- @staticmethod
- def measure_execution_time(func):
- """Decorator to measure execution time"""
- async def wrapper(*args, **kwargs):
- start_time = time.time()
- result = await func(*args, **kwargs)
- execution_time = time.time() - start_time
-
- if hasattr(result, 'update') and isinstance(result, dict):
- result['execution_time'] = execution_time
-
- return result
- return wrapper
-
- @staticmethod
- def memory_profiler(func):
- """Decorator to profile memory usage"""
- async def wrapper(*args, **kwargs):
- process = psutil.Process()
-
- # Memory before
- gc.collect()
- memory_before = process.memory_info().rss / 1024 / 1024
-
- result = await func(*args, **kwargs)
-
- # Memory after
- memory_after = process.memory_info().rss / 1024 / 1024
- memory_used = memory_after - memory_before
-
- if hasattr(result, 'update') and isinstance(result, dict):
- result['memory_used_mb'] = memory_used
-
- return result
- return wrapper
-
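-# Illustrative sketch (hypothetical, not used elsewhere in this suite) of how the
-# decorators above are meant to be stacked; the decorated call then returns a dict
-# that also carries 'execution_time' and 'memory_used_mb':
-#
-# @PerformanceBenchmark.measure_execution_time
-# @PerformanceBenchmark.memory_profiler
-# async def run_benchmark(trainer, data):
-# return await trainer.train_single_product(
-# tenant_id="bench", product_name="Pan Integral", sales_data=data, config={}
-# )
-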
-
-# ================================================================
-# STANDALONE EXECUTION
-# ================================================================
-
-if __name__ == "__main__":
- """
- Run performance tests as standalone script
- Usage: python test_performance.py
- """
- import sys
- import os
- from unittest.mock import patch
-
- # Add the training service root to Python path
- training_service_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- sys.path.insert(0, training_service_root)
-
- print("=" * 60)
- print("TRAINING SERVICE PERFORMANCE TEST SUITE")
- print("=" * 60)
-
- # Setup logging
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
- )
-
- # Run performance tests
- pytest.main([
- __file__,
- "-v",
- "--tb=short",
- "-s", # Don't capture output
- "--durations=10", # Show 10 slowest tests
- "-m", "not slow", # Skip slow tests unless specifically requested
- ])
-
- print("\n" + "=" * 60)
- print("PERFORMANCE TESTING COMPLETE")
- print("=" * 60)
\ No newline at end of file
diff --git a/services/training/tests/test_service.py b/services/training/tests/test_service.py
deleted file mode 100644
index a1c03248..00000000
--- a/services/training/tests/test_service.py
+++ /dev/null
@@ -1,688 +0,0 @@
-# services/training/tests/test_service.py
-"""
-Tests for training service business logic layer
-"""
-
-import pytest
-from unittest.mock import AsyncMock, Mock, patch
-from datetime import datetime, timedelta
-import httpx
-
-from app.services.training_service import TrainingService
-from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
-from app.models.training import ModelTrainingLog, TrainedModel
-
-
-class TestTrainingService:
- """Test the training service business logic"""
-
- @pytest.fixture
- def training_service(self, mock_ml_trainer):
- return TrainingService()
-
- @pytest.mark.asyncio
- async def test_create_training_job_success(
- self,
- training_service,
- test_db_session
- ):
- """Test successful training job creation"""
- job_id = "test-job-123"
- tenant_id = "test-tenant"
- config = {"include_weather": True, "include_traffic": True}
-
- result = await training_service.create_training_job(
- db=test_db_session,
- tenant_id=tenant_id,
- job_id=job_id,
- config=config
- )
-
- assert isinstance(result, ModelTrainingLog)
- assert result.job_id == job_id
- assert result.tenant_id == tenant_id
- assert result.status == "pending"
- assert result.progress == 0
- assert result.config == config
-
- @pytest.mark.asyncio
- async def test_create_single_product_job_success(
- self,
- training_service,
- test_db_session
- ):
- """Test successful single product job creation"""
- job_id = "test-product-job-123"
- tenant_id = "test-tenant"
- product_name = "Pan Integral"
- config = {"include_weather": True}
-
- result = await training_service.create_single_product_job(
- db=test_db_session,
- tenant_id=tenant_id,
- product_name=product_name,
- job_id=job_id,
- config=config
- )
-
- assert isinstance(result, ModelTrainingLog)
- assert result.job_id == job_id
- assert result.tenant_id == tenant_id
- assert result.config["single_product"] == product_name
- assert f"Initializing training for {product_name}" in result.current_step
-
- @pytest.mark.asyncio
- async def test_get_job_status_existing(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test getting status of existing job"""
- result = await training_service.get_job_status(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
-
- assert result is not None
- assert result.job_id == training_job_in_db.job_id
- assert result.status == training_job_in_db.status
-
- @pytest.mark.asyncio
- async def test_get_job_status_nonexistent(
- self,
- training_service,
- test_db_session
- ):
- """Test getting status of non-existent job"""
- result = await training_service.get_job_status(
- db=test_db_session,
- job_id="nonexistent-job",
- tenant_id="test-tenant"
- )
-
- assert result is None
-
- @pytest.mark.asyncio
- async def test_list_training_jobs(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test listing training jobs"""
- result = await training_service.list_training_jobs(
- db=test_db_session,
- tenant_id=training_job_in_db.tenant_id,
- limit=10
- )
-
- assert isinstance(result, list)
- assert len(result) >= 1
- assert result[0].job_id == training_job_in_db.job_id
-
- @pytest.mark.asyncio
- async def test_list_training_jobs_with_filter(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test listing training jobs with status filter"""
- result = await training_service.list_training_jobs(
- db=test_db_session,
- tenant_id=training_job_in_db.tenant_id,
- limit=10,
- status_filter="pending"
- )
-
- assert isinstance(result, list)
- for job in result:
- assert job.status == "pending"
-
- @pytest.mark.asyncio
- async def test_cancel_training_job_success(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test successful job cancellation"""
- result = await training_service.cancel_training_job(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
-
- assert result is True
-
- # Verify status was updated
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
- assert updated_job.status == "cancelled"
-
- @pytest.mark.asyncio
- async def test_cancel_nonexistent_job(
- self,
- training_service,
- test_db_session
- ):
- """Test cancelling non-existent job"""
- result = await training_service.cancel_training_job(
- db=test_db_session,
- job_id="nonexistent-job",
- tenant_id="test-tenant"
- )
-
- assert result is False
-
- @pytest.mark.asyncio
- async def test_validate_training_data_valid(
- self,
- training_service,
- test_db_session,
- mock_data_service
- ):
- """Test validation with valid data"""
- config = {"min_data_points": 30}
-
- result = await training_service.validate_training_data(
- db=test_db_session,
- tenant_id="test-tenant",
- config=config
- )
-
- assert isinstance(result, dict)
- assert "is_valid" in result
- assert "issues" in result
- assert "recommendations" in result
- assert "estimated_time_minutes" in result
-
- @pytest.mark.asyncio
- async def test_validate_training_data_no_data(
- self,
- training_service,
- test_db_session
- ):
- """Test validation with no data"""
- config = {"min_data_points": 30}
-
- with patch('app.services.training_service.TrainingService._fetch_sales_data', new=AsyncMock(return_value=[])):
- result = await training_service.validate_training_data(
- db=test_db_session,
- tenant_id="test-tenant",
- config=config
- )
-
- assert result["is_valid"] is False
- assert "No sales data found" in result["issues"][0]
-
- @pytest.mark.asyncio
- async def test_update_job_status(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test updating job status"""
- await training_service._update_job_status(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- status="running",
- progress=50,
- current_step="Training models"
- )
-
- # Verify update
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
-
- assert updated_job.status == "running"
- assert updated_job.progress == 50
- assert updated_job.current_step == "Training models"
-
- @pytest.mark.asyncio
- async def test_store_trained_models(
- self,
- training_service,
- test_db_session
- ):
- """Test storing trained models"""
- tenant_id = "test-tenant"
- training_results = {
- "training_results": {
- "Pan Integral": {
- "status": "success",
- "model_info": {
- "model_id": "test-model-123",
- "model_path": "/test/models/test-model-123.pkl",
- "type": "prophet",
- "training_samples": 100,
- "features": ["temperature", "humidity"],
- "hyperparameters": {"seasonality_mode": "additive"},
- "training_metrics": {"mae": 5.2, "rmse": 7.8},
- "data_period": {
- "start_date": "2024-01-01T00:00:00",
- "end_date": "2024-01-31T00:00:00"
- }
- }
- }
- }
- }
-
- await training_service._store_trained_models(
- db=test_db_session,
- tenant_id=tenant_id,
- training_results=training_results
- )
-
- # Verify model was stored
- from sqlalchemy import select
- result = await test_db_session.execute(
- select(TrainedModel).where(
- TrainedModel.tenant_id == tenant_id,
- TrainedModel.product_name == "Pan Integral"
- )
- )
-
- stored_model = result.scalar_one_or_none()
- assert stored_model is not None
- assert stored_model.model_id == "test-model-123"
- assert stored_model.is_active is True
-
- @pytest.mark.asyncio
- async def test_get_training_logs(
- self,
- training_service,
- test_db_session,
- training_job_in_db
- ):
- """Test getting training logs"""
- result = await training_service.get_training_logs(
- db=test_db_session,
- job_id=training_job_in_db.job_id,
- tenant_id=training_job_in_db.tenant_id
- )
-
- assert isinstance(result, list)
- assert len(result) > 0
-
- # Check log content
- log_text = " ".join(result)
- assert training_job_in_db.job_id in log_text or "Job started" in log_text
-
-
-class TestTrainingServiceDataFetching:
- """Test external data fetching functionality"""
-
- @pytest.fixture
- def training_service(self):
- return TrainingService()
-
- @pytest.mark.asyncio
- async def test_fetch_sales_data_success(self, training_service):
- """Test successful sales data fetching"""
- mock_request = Mock()
- mock_request.start_date = None
- mock_request.end_date = None
-
- mock_response_data = {
- "sales": [
- {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
- ]
- }
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_response = Mock()
- mock_response.status_code = 200
- mock_response.json.return_value = mock_response_data
-
- # client.get(...) is awaited by the service, so it must be an AsyncMock
- mock_client.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
-
- result = await training_service._fetch_sales_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- assert result == mock_response_data["sales"]
-
- @pytest.mark.asyncio
- async def test_fetch_sales_data_error(self, training_service):
- """Test sales data fetching with API error"""
- mock_request = Mock()
- mock_request.start_date = None
- mock_request.end_date = None
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_response = Mock()
- mock_response.status_code = 500
-
- mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
-
- result = await training_service._fetch_sales_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- assert result == []
-
- @pytest.mark.asyncio
- async def test_fetch_weather_data_success(self, training_service):
- """Test successful weather data fetching"""
- mock_request = Mock()
- mock_request.start_date = None
- mock_request.end_date = None
-
- mock_response_data = {
- "weather": [
- {"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
- ]
- }
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_response = Mock()
- mock_response.status_code = 200
- mock_response.json.return_value = mock_response_data
-
- mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
-
- result = await training_service._fetch_weather_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- assert result == mock_response_data["weather"]
-
- @pytest.mark.asyncio
- async def test_fetch_traffic_data_success(self, training_service):
- """Test successful traffic data fetching"""
- mock_request = Mock()
- mock_request.start_date = None
- mock_request.end_date = None
-
- mock_response_data = {
- "traffic": [
- {"date": "2024-01-01", "traffic_volume": 120}
- ]
- }
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_response = Mock()
- mock_response.status_code = 200
- mock_response.json.return_value = mock_response_data
-
- mock_client.return_value.__aenter__.return_value.get.return_value = mock_response
-
- result = await training_service._fetch_traffic_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- assert result == mock_response_data["traffic"]
-
- @pytest.mark.asyncio
- async def test_fetch_data_with_date_filters(self, training_service):
- """Test data fetching with date filters"""
- from datetime import datetime
-
- mock_request = Mock()
- mock_request.start_date = datetime(2024, 1, 1)
- mock_request.end_date = datetime(2024, 1, 31)
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_response = Mock()
- mock_response.status_code = 200
- mock_response.json.return_value = {"sales": []}
-
- mock_get = mock_client.return_value.__aenter__.return_value.get
- mock_get.return_value = mock_response
-
- await training_service._fetch_sales_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- # Verify dates were passed in params
- call_args = mock_get.call_args
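- # call_args[1] is the kwargs dict of the mocked GET
- # (equivalent to mock_get.call_args.kwargs on Python 3.8+)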
- params = call_args[1]["params"]
- assert "start_date" in params
- assert "end_date" in params
- assert params["start_date"] == "2024-01-01T00:00:00"
- assert params["end_date"] == "2024-01-31T00:00:00"
-
-
-class TestTrainingServiceExecution:
- """Test training execution workflow"""
-
- @pytest.fixture
- def training_service(self, mock_ml_trainer):
- return TrainingService()
-
- @pytest.mark.asyncio
- async def test_execute_training_job_success(
- self,
- training_service,
- test_db_session,
- mock_messaging,
- mock_data_service
- ):
- """Test successful training job execution"""
- # Create job first
- job_id = "test-execution-job"
- training_log = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id=job_id,
- config={"include_weather": True}
- )
-
- request = TrainingJobRequest(
- include_weather=True,
- include_traffic=True,
- min_data_points=30
- )
-
- with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
- patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
- patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
- patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:
-
- mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}]
- mock_fetch_weather.return_value = []
- mock_fetch_traffic.return_value = []
- mock_store.return_value = None
-
- await training_service.execute_training_job(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant",
- request=request
- )
-
- # Verify job was completed
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant"
- )
-
- assert updated_job.status == "completed"
- assert updated_job.progress == 100
-
- @pytest.mark.asyncio
- async def test_execute_training_job_failure(
- self,
- training_service,
- test_db_session,
- mock_messaging
- ):
- """Test training job execution with failure"""
- # Create job first
- job_id = "test-failure-job"
- await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id=job_id,
- config={}
- )
-
- request = TrainingJobRequest(min_data_points=30)
-
- with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
- mock_fetch.side_effect = Exception("Data service unavailable")
-
- with pytest.raises(Exception):
- await training_service.execute_training_job(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant",
- request=request
- )
-
- # Verify job was marked as failed
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant"
- )
-
- assert updated_job.status == "failed"
- assert "Data service unavailable" in updated_job.error_message
-
- @pytest.mark.asyncio
- async def test_execute_single_product_training_success(
- self,
- training_service,
- test_db_session,
- mock_messaging,
- mock_data_service
- ):
- """Test successful single product training execution"""
- job_id = "test-single-product-job"
- product_name = "Pan Integral"
-
- await training_service.create_single_product_job(
- db=test_db_session,
- tenant_id="test-tenant",
- product_name=product_name,
- job_id=job_id,
- config={}
- )
-
- request = SingleProductTrainingRequest(
- include_weather=True,
- include_traffic=False
- )
-
- with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
- patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
- patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:
-
- mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": product_name, "quantity": 45}]
- mock_fetch_weather.return_value = []
- mock_store.return_value = None
-
- await training_service.execute_single_product_training(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant",
- product_name=product_name,
- request=request
- )
-
- # Verify job was completed
- updated_job = await training_service.get_job_status(
- db=test_db_session,
- job_id=job_id,
- tenant_id="test-tenant"
- )
-
- assert updated_job.status == "completed"
- assert updated_job.progress == 100
-
-
-class TestTrainingServiceEdgeCases:
- """Test edge cases and error conditions"""
-
- @pytest.fixture
- def training_service(self):
- return TrainingService()
-
- @pytest.mark.asyncio
- async def test_database_connection_failure(self, training_service):
- """Test handling of database connection failures"""
- from unittest.mock import AsyncMock
-
- # Simulate a session whose commit fails mid-transaction;
- # the error should propagate out of create_training_job.
- mock_session = Mock()
- mock_session.commit = AsyncMock(side_effect=Exception("Database connection failed"))
-
- with pytest.raises(Exception):
- await training_service.create_training_job(
- db=mock_session,
- tenant_id="test-tenant",
- job_id="test-job",
- config={}
- )
-
- @pytest.mark.asyncio
- async def test_external_service_timeout(self, training_service):
- """Test handling of external service timeouts"""
- mock_request = Mock()
- mock_request.start_date = None
- mock_request.end_date = None
-
- import httpx  # TimeoutException is referenced directly below
-
- with patch('httpx.AsyncClient') as mock_client:
- mock_client.return_value.__aenter__.return_value.get.side_effect = httpx.TimeoutException("Request timeout")
-
- result = await training_service._fetch_sales_data(
- tenant_id="test-tenant",
- request=mock_request
- )
-
- # Should return empty list on timeout
- assert result == []
-
- @pytest.mark.asyncio
- async def test_concurrent_job_creation(self, training_service, test_db_session):
- """Test handling of concurrent job creation"""
- # True concurrency would need overlapping sessions (e.g. asyncio.gather with
- # one session per task); here we only verify that several jobs can be created
- # back to back against the same session without conflicts.
-
- job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]
-
- jobs = []
- for job_id in job_ids:
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id=job_id,
- config={}
- )
- jobs.append(job)
-
- assert len(jobs) == 3
- for i, job in enumerate(jobs):
- assert job.job_id == job_ids[i]
-
- @pytest.mark.asyncio
- async def test_malformed_config_handling(self, training_service, test_db_session):
- """Test handling of malformed configuration"""
- malformed_config = {
- "invalid_key": "invalid_value",
- "nested": {"data": None}
- }
-
- # Should not raise exception, just store the config as-is
- job = await training_service.create_training_job(
- db=test_db_session,
- tenant_id="test-tenant",
- job_id="malformed-config-job",
- config=malformed_config
- )
-
- assert job.config == malformed_config
\ No newline at end of file