# ================================================================
# services/training/tests/test_end_to_end.py
# ================================================================
"""
End-to-End Testing for Training Service

Tests complete workflows from API to ML pipeline to results
"""
|
|
|
|
import asyncio
import json
import tempfile
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch

import httpx
import numpy as np
import pandas as pd
import pytest

from app.main import app
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
|
|
|
|
|
|
class TestTrainingServiceEndToEnd:
|
|
"""End-to-end tests for complete training workflows"""
|
|
|
|
@pytest.fixture
|
|
async def test_client(self):
|
|
"""Create test client for the training service"""
|
|
from httpx import AsyncClient
|
|
async with AsyncClient(app=app, base_url="http://test") as client:
|
|
yield client
|
|
|
|
@pytest.fixture
|
|
def real_bakery_data(self):
|
|
"""Use the actual bakery sales data from the uploaded CSV"""
|
|
# This fixture would load the real bakery_sales_2023_2024.csv data
|
|
# For testing, we'll simulate the structure based on the document description
|
|
|
|
# Generate realistic data matching the CSV structure
|
|
start_date = datetime(2023, 1, 1)
|
|
dates = [start_date + timedelta(days=i) for i in range(365)]
|
|
|
|
products = [
|
|
"Pan Integral", "Pan Blanco", "Croissant", "Magdalenas",
|
|
"Empanadas", "Tarta Chocolate", "Roscon Reyes", "Palmeras"
|
|
]
|
|
|
|
data = []
|
|
for date in dates:
|
|
for product in products:
|
|
# Realistic sales patterns for Madrid bakery
|
|
base_quantity = {
|
|
"Pan Integral": 80, "Pan Blanco": 120, "Croissant": 45,
|
|
"Magdalenas": 30, "Empanadas": 25, "Tarta Chocolate": 15,
|
|
"Roscon Reyes": 8, "Palmeras": 12
|
|
}.get(product, 20)
|
|
|
|
# Seasonal variations
|
|
if date.month == 12 and product == "Roscon Reyes":
|
|
base_quantity *= 5 # Christmas specialty
|
|
elif date.month in [6, 7, 8]: # Summer
|
|
base_quantity *= 0.85
|
|
elif date.month in [11, 12, 1]: # Winter
|
|
base_quantity *= 1.15
|
|
|
|
# Weekly patterns
|
|
if date.weekday() >= 5: # Weekends
|
|
base_quantity *= 1.3
|
|
elif date.weekday() == 0: # Monday slower
|
|
base_quantity *= 0.8
|
|
|
|
# Weather influence
|
|
temp = 15 + 12 * np.sin((date.timetuple().tm_yday / 365) * 2 * np.pi)
|
|
if temp > 30: # Very hot days
|
|
if product in ["Pan Integral", "Pan Blanco"]:
|
|
base_quantity *= 0.7
|
|
elif temp < 5: # Cold days
|
|
base_quantity *= 1.1
|
|
|
|
# Add realistic noise
|
|
import numpy as np
|
|
quantity = max(1, int(base_quantity + np.random.normal(0, base_quantity * 0.15)))
|
|
|
|
# Calculate revenue (realistic Spanish bakery prices)
|
|
price_per_unit = {
|
|
"Pan Integral": 2.80, "Pan Blanco": 2.50, "Croissant": 1.50,
|
|
"Magdalenas": 1.20, "Empanadas": 3.50, "Tarta Chocolate": 18.00,
|
|
"Roscon Reyes": 25.00, "Palmeras": 1.80
|
|
}.get(product, 2.00)
|
|
|
|
revenue = round(quantity * price_per_unit, 2)
|
|
|
|
data.append({
|
|
"date": date.strftime("%Y-%m-%d"),
|
|
"product": product,
|
|
"quantity": quantity,
|
|
"revenue": revenue,
|
|
"temperature": round(temp + np.random.normal(0, 3), 1),
|
|
"precipitation": max(0, np.random.exponential(0.8)),
|
|
"is_weekend": date.weekday() >= 5,
|
|
"is_holiday": self._is_spanish_holiday(date)
|
|
})
|
|
|
|
return pd.DataFrame(data)
|
|
|
|
def _is_spanish_holiday(self, date: datetime) -> bool:
|
|
"""Check if date is a Spanish holiday"""
|
|
spanish_holidays = [
|
|
(1, 1), # Año Nuevo
|
|
(1, 6), # Reyes Magos
|
|
(5, 1), # Día del Trabajo
|
|
(8, 15), # Asunción de la Virgen
|
|
(10, 12), # Fiesta Nacional de España
|
|
(11, 1), # Todos los Santos
|
|
(12, 6), # Día de la Constitución
|
|
(12, 8), # Inmaculada Concepción
|
|
(12, 25), # Navidad
|
|
]
|
|
return (date.month, date.day) in spanish_holidays
|
|
|
|
@pytest.fixture
|
|
async def mock_external_apis(self):
|
|
"""Mock external APIs (AEMET and Madrid OpenData)"""
|
|
with patch('app.external.aemet.AEMETClient') as mock_aemet, \
|
|
patch('app.external.madrid_opendata.MadridOpenDataClient') as mock_madrid:
|
|
|
|
# Mock AEMET weather data
|
|
mock_aemet_instance = AsyncMock()
|
|
mock_aemet.return_value = mock_aemet_instance
|
|
|
|
# Generate realistic Madrid weather data
|
|
weather_data = []
|
|
for i in range(365):
|
|
date = datetime(2023, 1, 1) + timedelta(days=i)
|
|
day_of_year = date.timetuple().tm_yday
|
|
# Madrid climate: hot summers, mild winters
|
|
base_temp = 14 + 12 * np.sin((day_of_year / 365) * 2 * np.pi)
|
|
|
|
weather_data.append({
|
|
"date": date,
|
|
"temperature": round(base_temp + np.random.normal(0, 4), 1),
|
|
"precipitation": max(0, np.random.exponential(1.2)),
|
|
"humidity": np.random.uniform(25, 75),
|
|
"wind_speed": np.random.uniform(3, 20),
|
|
"pressure": np.random.uniform(995, 1025),
|
|
"description": np.random.choice([
|
|
"Soleado", "Parcialmente nublado", "Nublado",
|
|
"Lluvia ligera", "Despejado"
|
|
]),
|
|
"source": "aemet"
|
|
})
|
|
|
|
mock_aemet_instance.get_historical_weather.return_value = weather_data
|
|
mock_aemet_instance.get_current_weather.return_value = weather_data[-1]
|
|
|
|
# Mock Madrid traffic data
|
|
mock_madrid_instance = AsyncMock()
|
|
mock_madrid.return_value = mock_madrid_instance
|
|
|
|
traffic_data = []
|
|
for i in range(365):
|
|
date = datetime(2023, 1, 1) + timedelta(days=i)
|
|
|
|
# Multiple measurements per day
|
|
for hour in range(6, 22, 2): # Every 2 hours from 6 AM to 10 PM
|
|
measurement_time = date.replace(hour=hour)
|
|
|
|
# Realistic Madrid traffic patterns
|
|
if hour in [7, 8, 9, 18, 19, 20]: # Rush hours
|
|
volume = np.random.randint(1200, 2000)
|
|
congestion = "high"
|
|
speed = np.random.randint(10, 25)
|
|
elif hour in [12, 13, 14]: # Lunch time
|
|
volume = np.random.randint(800, 1200)
|
|
congestion = "medium"
|
|
speed = np.random.randint(20, 35)
|
|
else: # Off-peak
|
|
volume = np.random.randint(300, 800)
|
|
congestion = "low"
|
|
speed = np.random.randint(30, 50)
|
|
|
|
traffic_data.append({
|
|
"date": measurement_time,
|
|
"traffic_volume": volume,
|
|
"occupation_percentage": np.random.randint(15, 85),
|
|
"load_percentage": np.random.randint(25, 90),
|
|
"average_speed": speed,
|
|
"congestion_level": congestion,
|
|
"pedestrian_count": np.random.randint(100, 800),
|
|
"measurement_point_id": "MADRID_CENTER_001",
|
|
"measurement_point_name": "Puerta del Sol",
|
|
"road_type": "URB",
|
|
"source": "madrid_opendata"
|
|
})
|
|
|
|
mock_madrid_instance.get_historical_traffic.return_value = traffic_data
|
|
mock_madrid_instance.get_current_traffic.return_value = traffic_data[-1]
|
|
|
|
yield {
|
|
'aemet': mock_aemet_instance,
|
|
'madrid': mock_madrid_instance
|
|
}
|
|
|
|
    @pytest.mark.asyncio
    async def test_complete_training_workflow_api(
        self,
        test_client,
        real_bakery_data,
        mock_external_apis
    ):
        """Test complete training workflow through API endpoints

        Happy-path walkthrough: health check -> data validation ->
        multi-product job submission -> polling until completion ->
        log retrieval.

        NOTE(review): the polling loop in Step 4 uses real wall-clock
        time and can run up to max_wait_time (300 s) with 10 s sleeps —
        this is an integration test, not a unit test.
        """

        # Step 1: Check service health
        health_response = await test_client.get("/health")
        assert health_response.status_code == 200
        health_data = health_response.json()
        assert health_data["status"] == "healthy"

        # Step 2: Validate training data quality
        # Sales data is patched in so no real data store is needed.
        with patch('app.services.training_service.TrainingService._fetch_sales_data',
                   return_value=real_bakery_data):

            validation_response = await test_client.post(
                "/training/validate",
                json={
                    "tenant_id": "test_bakery_001",
                    "include_weather": True,
                    "include_traffic": True
                }
            )

            assert validation_response.status_code == 200
            validation_data = validation_response.json()
            assert validation_data["is_valid"] is True
            assert validation_data["data_points"] > 1000  # Sufficient data
            assert validation_data["missing_percentage"] < 10

        # Step 3: Start training job for multiple products
        training_request = {
            "products": ["Pan Integral", "Croissant", "Magdalenas"],
            "include_weather": True,
            "include_traffic": True,
            "config": {
                "seasonality_mode": "additive",
                "changepoint_prior_scale": 0.05,
                "seasonality_prior_scale": 10.0,
                "validation_enabled": True
            }
        }

        # The sales-data patch stays active for the rest of the test so
        # the background training job reads the fixture data.
        with patch('app.services.training_service.TrainingService._fetch_sales_data',
                   return_value=real_bakery_data):

            start_response = await test_client.post(
                "/training/jobs",
                json=training_request,
                headers={"X-Tenant-ID": "test_bakery_001"}
            )

            assert start_response.status_code == 201
            job_data = start_response.json()
            job_id = job_data["job_id"]
            assert job_data["status"] == "pending"

            # Step 4: Monitor job progress
            max_wait_time = 300  # 5 minutes
            start_time = time.time()

            # while/else: the else branch fires only when the deadline
            # expires without reaching the `break` below.
            while time.time() - start_time < max_wait_time:
                status_response = await test_client.get(f"/training/jobs/{job_id}/status")
                assert status_response.status_code == 200

                status_data = status_response.json()

                if status_data["status"] == "completed":
                    # Training completed successfully
                    assert "models_trained" in status_data
                    assert len(status_data["models_trained"]) == 3  # Three products

                    # Check model quality
                    for model_info in status_data["models_trained"]:
                        assert "product_name" in model_info
                        assert "model_id" in model_info
                        assert "metrics" in model_info

                        metrics = model_info["metrics"]
                        assert "mape" in metrics
                        assert "rmse" in metrics
                        assert "mae" in metrics

                        # Quality thresholds for bakery data
                        assert metrics["mape"] < 50, f"MAPE too high for {model_info['product_name']}: {metrics['mape']}"
                        assert metrics["rmse"] > 0

                    break
                elif status_data["status"] == "failed":
                    pytest.fail(f"Training job failed: {status_data.get('error_message', 'Unknown error')}")

                # Wait before checking again
                await asyncio.sleep(10)
            else:
                pytest.fail(f"Training job did not complete within {max_wait_time} seconds")

            # Step 5: Get detailed job logs
            logs_response = await test_client.get(f"/training/jobs/{job_id}/logs")
            assert logs_response.status_code == 200
            logs_data = logs_response.json()
            assert "logs" in logs_data
            assert len(logs_data["logs"]) > 0
|
assert len(logs_data["logs"]) > 0 |