# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer
Orchestrates data collection, date alignment, and preparation for ML training
"""
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import logging
import random
from concurrent.futures import ThreadPoolExecutor

import pandas as pd

from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_failed,
)

logger = logging.getLogger(__name__)

@dataclass
class TrainingDataSet:
    """Container for all training data with metadata"""
    sales_data: List[Dict[str, Any]]
    weather_data: List[Dict[str, Any]]
    traffic_data: List[Dict[str, Any]]
    date_range: AlignedDateRange
    metadata: Dict[str, Any]

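# A minimal sketch of the record shapes this container expects (field names
# inferred from the validation helpers below; treat them as illustrative):
#   sales_data:   [{"date": "2024-01-15", "product_name": "croissant", "quantity": 42}, ...]
#   weather_data: [{"date": "2024-01-15", "temperature": 9.5, "precipitation": 0.0}, ...]
#   traffic_data: [{"date": "2024-01-15", "traffic_volume": 1200}, ...]
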
class TrainingDataOrchestrator:
    """
    Enhanced orchestrator for data collection from multiple sources.
    Ensures date alignment, handles data source constraints, and prepares data for ML training.
    """

    def __init__(self,
                 madrid_client=None,
                 weather_client=None,
                 date_alignment_service: DateAlignmentService = None):
        # NOTE: madrid_client and weather_client are accepted but not used here;
        # all data access goes through DataClient.
        self.data_client = DataClient()
        self.date_alignment_service = date_alignment_service or DateAlignmentService()
        self.max_concurrent_requests = 3

    async def prepare_training_data(
        self,
        tenant_id: str,
        bakery_location: Tuple[float, float],  # (lat, lon)
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None,
        job_id: Optional[str] = None
    ) -> TrainingDataSet:
        """
        Main method to prepare all training data with comprehensive date alignment.
        Sales data is fetched from the data service, not passed in by the caller.

        Args:
            tenant_id: Tenant identifier
            bakery_location: Bakery coordinates (lat, lon)
            requested_start: Optional explicit start date
            requested_end: Optional explicit end date
            job_id: Training job identifier for logging

        Returns:
            TrainingDataSet with all aligned and validated data
        """
logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
try:
await publish_job_progress(job_id, tenant_id, 5, "Extrayendo datos de ventas",
step_details="Conectando con servicio de datos")
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Step 1: Extract and validate sales data date range
await publish_job_progress(job_id, tenant_id, 10, "Validando fechas de datos de venta",
step_details="Aplicando restricciones de fuentes de datos")
sales_date_range = self._extract_sales_date_range(sales_data)
logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")
# Step 2: Apply date alignment across all data sources
await publish_job_progress(job_id, tenant_id, 15, "Alinear el rango de fechas",
step_details="Aplicar la alineación de fechas en todas las fuentes de datos")
aligned_range = self.date_alignment_service.validate_and_align_dates(
user_sales_range=sales_date_range,
requested_start=requested_start,
requested_end=requested_end
)
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
if aligned_range.constraints:
logger.info(f"Applied constraints: {aligned_range.constraints}")
# Step 3: Filter sales data to aligned date range
await publish_job_progress(job_id, tenant_id, 20, "Alinear el rango de las ventas",
step_details="Aplicar la alineación de fechas de las ventas")
filtered_sales = self._filter_sales_data(sales_data, aligned_range)
# Step 4: Collect external data sources concurrently
logger.info("Collecting external data sources...")
await publish_job_progress(job_id, tenant_id, 25, "Recopilación de fuentes de datos externas",
step_details="Recopilación de fuentes de datos externas")
weather_data, traffic_data = await self._collect_external_data(
aligned_range, bakery_location, tenant_id
)
# Step 5: Validate data quality
await publish_job_progress(job_id, tenant_id, 30, "Validando la calidad de los datos",
step_details="Validando la calidad de los datos")
data_quality_results = self._validate_data_sources(
filtered_sales, weather_data, traffic_data, aligned_range
)
# Step 6: Create comprehensive training dataset
training_dataset = TrainingDataSet(
sales_data=filtered_sales,
weather_data=weather_data,
traffic_data=traffic_data,
date_range=aligned_range,
metadata={
"tenant_id": tenant_id,
"job_id": job_id,
"bakery_location": bakery_location,
"data_sources_used": aligned_range.available_sources,
"constraints_applied": aligned_range.constraints,
"data_quality": data_quality_results,
"preparation_timestamp": datetime.now().isoformat(),
"original_sales_range": {
"start": sales_date_range.start.isoformat(),
"end": sales_date_range.end.isoformat()
}
}
)
# Step 7: Final validation
await publish_job_progress(job_id, tenant_id, 35, "Validancion final de los datos",
step_details="Validancion final de los datos")
final_validation = self.validate_training_data_quality(training_dataset)
training_dataset.metadata["final_validation"] = final_validation
logger.info(f"Training data preparation completed successfully:")
logger.info(f" - Sales records: {len(filtered_sales)}")
logger.info(f" - Weather records: {len(weather_data)}")
logger.info(f" - Traffic records: {len(traffic_data)}")
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
return training_dataset
except Exception as e:
publish_job_failed(job_id, tenant_id, str(e))
logger.error(f"Training data preparation failed: {str(e)}")
raise ValueError(f"Failed to prepare training data: {str(e)}")
    def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
        """Extract date range from sales data with timezone handling"""
        if not sales_data:
            raise ValueError("No sales data provided")

        dates = []
        for record in sales_data:
            date_value = record.get('date')
            if date_value:
                # Ensure timezone-aware datetimes so min/max comparisons are safe
                if isinstance(date_value, str):
                    dt = pd.to_datetime(date_value)
                    if dt.tz is None:
                        dt = dt.replace(tzinfo=timezone.utc)
                    dates.append(dt.to_pydatetime())
                elif isinstance(date_value, datetime):
                    if date_value.tzinfo is None:
                        date_value = date_value.replace(tzinfo=timezone.utc)
                    dates.append(date_value)

        if not dates:
            raise ValueError("No valid dates found in sales data")

        start_date = min(dates)
        end_date = max(dates)
        return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)

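    # For illustration: records dated "2024-01-05" and "2024-03-20" yield a
    # DateRange spanning those two days, tagged DataSourceType.BAKERY_SALES.
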
    def _filter_sales_data(
        self,
        sales_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> List[Dict[str, Any]]:
        """Filter sales data to the aligned date range with enhanced validation"""
        filtered_data = []
        filtered_count = 0

        # Make the aligned range bounds timezone-aware once, outside the loop
        aligned_start = aligned_range.start
        aligned_end = aligned_range.end
        if aligned_start.tzinfo is None:
            aligned_start = aligned_start.replace(tzinfo=timezone.utc)
        if aligned_end.tzinfo is None:
            aligned_end = aligned_end.replace(tzinfo=timezone.utc)

        for record in sales_data:
            try:
                if 'date' in record:
                    record_date = record['date']
                    # Parse strings into timezone-aware datetimes
                    if isinstance(record_date, str):
                        if 'T' in record_date:
                            record_date = record_date.replace('Z', '+00:00')
                        # Parse the date portion only; time-of-day is normalized below
                        parsed_date = datetime.fromisoformat(record_date.split('T')[0])
                        if parsed_date.tzinfo is None:
                            parsed_date = parsed_date.replace(tzinfo=timezone.utc)
                        record_date = parsed_date
                    elif isinstance(record_date, datetime):
                        if record_date.tzinfo is None:
                            record_date = record_date.replace(tzinfo=timezone.utc)

                    # Normalize to start of day
                    record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)

                    # Check if the date falls within the aligned range (both sides timezone-aware)
                    if aligned_start <= record_date <= aligned_end:
                        # Keep only records that carry the required fields
                        if self._validate_sales_record(record):
                            filtered_data.append(record)
                        else:
                            filtered_count += 1
                    else:
                        # Record outside date range
                        filtered_count += 1
            except Exception as e:
                logger.warning(f"Error processing sales record: {str(e)}")
                filtered_count += 1
                continue

        logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
        if filtered_count > 0:
            logger.warning(f"Filtered out {filtered_count} invalid records")
        return filtered_data

    def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
        """Validate individual sales record"""
        required_fields = ['date', 'product_name']
        quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']

        # Check required fields
        for field in required_fields:
            if field not in record or record[field] is None:
                return False

        # Check at least one quantity field exists
        has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
        if not has_quantity:
            return False

        # Validate the first populated quantity field is numeric and non-negative
        for field in quantity_fields:
            if field in record and record[field] is not None:
                try:
                    quantity = float(record[field])
                    if quantity < 0:
                        return False
                except (ValueError, TypeError):
                    return False
                break
        return True

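    # For illustration: {"date": "2024-01-15", "product_name": "croissant", "quantity": 12}
    # passes; {"date": "2024-01-15", "product_name": "croissant"} fails (no quantity field);
    # {"date": "2024-01-15", "product_name": "croissant", "quantity": -3} fails (negative).
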
    async def _collect_external_data(
        self,
        aligned_range: AlignedDateRange,
        bakery_location: Tuple[float, float],
        tenant_id: str
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Collect weather and traffic data concurrently with enhanced error handling"""
        lat, lon = bakery_location

        # Create collection tasks with timeout
        tasks = []

        # Weather data collection
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            weather_task = asyncio.create_task(
                self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("weather", weather_task))

        # Traffic data collection
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            traffic_task = asyncio.create_task(
                self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("traffic", traffic_task))

        # Execute tasks concurrently; return_exceptions=True keeps one failure
        # from cancelling the other collection
        results = {}
        if tasks:
            try:
                completed_tasks = await asyncio.gather(
                    *[task for _, task in tasks],
                    return_exceptions=True
                )
                for i, (task_name, _) in enumerate(tasks):
                    result = completed_tasks[i]
                    if isinstance(result, Exception):
                        logger.warning(f"{task_name} data collection failed: {result}")
                        results[task_name] = []
                    else:
                        results[task_name] = result
                        logger.info(f"{task_name} data collection completed: {len(result)} records")
            except Exception as e:
                logger.error(f"Error in concurrent data collection: {str(e)}")
                results = {"weather": [], "traffic": []}

        weather_data = results.get("weather", [])
        traffic_data = results.get("traffic", [])
        return weather_data, traffic_data

    async def _collect_weather_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """Collect weather data with timeout and fallback"""
        try:
            start_date_str = aligned_range.start.isoformat()
            end_date_str = aligned_range.end.isoformat()
            # Enforce the timeout this method's name promises; the 60s value is
            # an assumed default, not taken from the original code
            weather_data = await asyncio.wait_for(
                self.data_client.fetch_weather_data(
                    tenant_id=tenant_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    latitude=lat,
                    longitude=lon
                ),
                timeout=60.0
            )

            # Validate weather data
            if self._validate_weather_data(weather_data):
                logger.info(f"Collected {len(weather_data)} valid weather records")
                return weather_data
            else:
                logger.warning("Invalid weather data received, using synthetic data")
                return self._generate_synthetic_weather_data(aligned_range)
        except asyncio.TimeoutError:
            logger.warning("Weather data collection timed out, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)
        except Exception as e:
            logger.warning(f"Weather data collection failed: {e}, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)

    async def _collect_traffic_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """Collect traffic data with timeout and Madrid constraint validation"""
        try:
            # Double-check Madrid constraint before making the request
            if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
                logger.warning("Madrid current month constraint violation, no traffic data available")
                return []

            start_date_str = aligned_range.start.isoformat()
            end_date_str = aligned_range.end.isoformat()
            # As above, the 60s timeout is an assumed default
            traffic_data = await asyncio.wait_for(
                self.data_client.fetch_traffic_data(
                    tenant_id=tenant_id,
                    start_date=start_date_str,
                    end_date=end_date_str,
                    latitude=lat,
                    longitude=lon
                ),
                timeout=60.0
            )

            # Validate traffic data
            if self._validate_traffic_data(traffic_data):
                logger.info(f"Collected {len(traffic_data)} valid traffic records")
                return traffic_data
            else:
                logger.warning("Invalid traffic data received")
                return []
        except asyncio.TimeoutError:
            logger.warning("Traffic data collection timed out")
            return []
        except Exception as e:
            logger.warning(f"Traffic data collection failed: {e}")
            return []

    def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
        """Validate weather data quality"""
        if not weather_data:
            return False

        required_fields = ['date']
        weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
        valid_records = 0
        for record in weather_data:
            # Check required fields
            if not all(field in record for field in required_fields):
                continue
            # Check at least one weather field exists
            if any(field in record and record[field] is not None for field in weather_fields):
                valid_records += 1

        # Consider valid if at least 50% of records are valid
        validity_threshold = 0.5
        is_valid = (valid_records / len(weather_data)) >= validity_threshold
        if not is_valid:
            logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
        return is_valid

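    # For illustration: 10 records of which 4 carry a weather value give
    # 4/10 = 0.4 < 0.5, so the batch is rejected and synthetic data is used instead.
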
    def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Validate traffic data quality"""
        if not traffic_data:
            return False

        required_fields = ['date']
        traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
        valid_records = 0
        for record in traffic_data:
            # Check required fields
            if not all(field in record for field in required_fields):
                continue
            # Check at least one traffic field exists
            if any(field in record and record[field] is not None for field in traffic_fields):
                valid_records += 1

        # Consider valid if at least 30% of records are valid (traffic data is often sparse)
        validity_threshold = 0.3
        is_valid = (valid_records / len(traffic_data)) >= validity_threshold
        if not is_valid:
            logger.warning(f"Traffic data validation failed: {valid_records}/{len(traffic_data)} valid records")
        return is_valid

    def _validate_data_sources(
        self,
        sales_data: List[Dict[str, Any]],
        weather_data: List[Dict[str, Any]],
        traffic_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> Dict[str, Any]:
        """Validate all data sources and provide quality metrics"""
        # Guard against zero-length ranges (same-day start and end) to avoid
        # division by zero in the coverage ratios below
        coverage_days = max(1, (aligned_range.end - aligned_range.start).days)
        validation_results = {
            "sales_data": {
                "record_count": len(sales_data),
                "is_valid": len(sales_data) > 0,
                "coverage_days": coverage_days,
                "quality_score": 0.0
            },
            "weather_data": {
                "record_count": len(weather_data),
                "is_valid": self._validate_weather_data(weather_data) if weather_data else False,
                "quality_score": 0.0
            },
            "traffic_data": {
                "record_count": len(traffic_data),
                "is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
                "quality_score": 0.0
            },
            "overall_quality_score": 0.0
        }

        # Calculate quality scores
        # Sales data quality (most important)
        if validation_results["sales_data"]["record_count"] > 0:
            coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / coverage_days)
            validation_results["sales_data"]["quality_score"] = coverage_ratio * 100

        # Weather data quality: roughly one record per day is expected
        if validation_results["weather_data"]["record_count"] > 0:
            coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / coverage_days)
            validation_results["weather_data"]["quality_score"] = coverage_ratio * 100

        # Traffic data quality: same daily expectation
        if validation_results["traffic_data"]["record_count"] > 0:
            coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / coverage_days)
            validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100

        # Overall quality score (weighted by importance)
        weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
        overall_score = sum(
            validation_results[source]["quality_score"] * weight
            for source, weight in weights.items()
        )
        validation_results["overall_quality_score"] = round(overall_score, 2)
        return validation_results

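    # Worked example of the weighted score: with sales at 100, weather at 50 and
    # traffic at 0, the overall score is 0.7*100 + 0.2*50 + 0.1*0 = 80.0.
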
    def _generate_synthetic_weather_data(
        self,
        aligned_range: AlignedDateRange
    ) -> List[Dict[str, Any]]:
        """Generate realistic synthetic weather data for Madrid"""
        synthetic_data = []
        current_date = aligned_range.start

        # Madrid seasonal temperature patterns (mean °C per month)
        seasonal_temps = {
            1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
            7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
        }

        while current_date <= aligned_range.end:
            month = current_date.month
            base_temp = seasonal_temps.get(month, 15)

            # Add some realistic variation
            temp_variation = random.gauss(0, 3)  # ±3°C variation
            temperature = max(0, base_temp + temp_variation)

            # Precipitation patterns (Madrid is relatively dry)
            precipitation = 0.0
            if random.random() < 0.15:  # 15% chance of rain
                precipitation = random.uniform(0.1, 15.0)

            synthetic_data.append({
                "date": current_date,
                "temperature": round(temperature, 1),
                "precipitation": round(precipitation, 1),
                "humidity": round(random.uniform(40, 80), 1),
                "wind_speed": round(random.uniform(2, 15), 1),
                "pressure": round(random.uniform(1005, 1025), 1),
                "source": "synthetic_madrid_pattern"
            })
            current_date = current_date + timedelta(days=1)

        logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
        return synthetic_data

    def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """Enhanced validation of training data quality"""
        validation_results = {
            "is_valid": True,
            "warnings": [],
            "errors": [],
            "data_quality_score": 100.0,
            "recommendations": []
        }

        # Check sales data completeness
        sales_count = len(dataset.sales_data)
        if sales_count < 30:
            validation_results["warnings"].append(
                f"Limited sales data: {sales_count} records (recommended: 30+)"
            )
            validation_results["data_quality_score"] -= 20
            validation_results["recommendations"].append("Consider collecting more historical sales data")
        elif sales_count < 90:
            validation_results["warnings"].append(
                f"Moderate sales data: {sales_count} records (optimal: 90+)"
            )
            validation_results["data_quality_score"] -= 10

        # Check date coverage
        date_coverage = (dataset.date_range.end - dataset.date_range.start).days
        if date_coverage < 90:
            validation_results["warnings"].append(
                f"Limited date coverage: {date_coverage} days (recommended: 90+)"
            )
            validation_results["data_quality_score"] -= 15
            validation_results["recommendations"].append("Extend date range for better seasonality detection")

        # Check external data availability
        if not dataset.weather_data:
            validation_results["warnings"].append("No weather data available")
            validation_results["data_quality_score"] -= 10
            validation_results["recommendations"].append("Weather data improves forecast accuracy")
        elif len(dataset.weather_data) < date_coverage * 0.5:
            validation_results["warnings"].append("Sparse weather data coverage")
            validation_results["data_quality_score"] -= 5

        if not dataset.traffic_data:
            validation_results["warnings"].append("No traffic data available")
            validation_results["data_quality_score"] -= 5
            validation_results["recommendations"].append("Traffic data can help with location-based patterns")

        # Check data consistency
        unique_products = set()
        for record in dataset.sales_data:
            if 'product_name' in record:
                unique_products.add(record['product_name'])

        if len(unique_products) == 0:
            validation_results["errors"].append("No product names found in sales data")
            validation_results["is_valid"] = False
        elif len(unique_products) > 50:
            validation_results["warnings"].append(
                f"Many products detected ({len(unique_products)}). Consider training models in batches."
            )
            validation_results["recommendations"].append("Group similar products for better training efficiency")

        # Check for data source constraints
        if dataset.date_range.constraints:
            constraint_info = []
            for constraint_type, message in dataset.date_range.constraints.items():
                constraint_info.append(f"{constraint_type}: {message}")
            validation_results["warnings"].append(
                f"Data source constraints applied: {'; '.join(constraint_info)}"
            )

        # Final validation
        if validation_results["errors"]:
            validation_results["is_valid"] = False
            validation_results["data_quality_score"] = 0.0

        # Ensure score doesn't go below 0
        validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])

        # Add quality assessment
        score = validation_results["data_quality_score"]
        if score >= 80:
            validation_results["quality_assessment"] = "Excellent"
        elif score >= 60:
            validation_results["quality_assessment"] = "Good"
        elif score >= 40:
            validation_results["quality_assessment"] = "Fair"
        else:
            validation_results["quality_assessment"] = "Poor"
        return validation_results

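    # Worked example of the deductions: 25 sales records over 60 days with no
    # weather and no traffic data scores 100 - 20 - 15 - 10 - 5 = 50, i.e. "Fair".
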
    def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
        """
        Generate an enhanced data collection plan based on the aligned date range.
        """
        plan = {
            "collection_summary": {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "duration_days": (aligned_range.end - aligned_range.start).days,
                "available_sources": [source.value for source in aligned_range.available_sources],
                "constraints": aligned_range.constraints
            },
            "data_sources": {}
        }

        # Bakery Sales Data
        if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
            plan["data_sources"]["sales_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "user_upload",
                "required": True,
                "priority": "high",
                "expected_records": "variable",
                "data_points": ["date", "product_name", "quantity"],
                "validation": "required_fields_check"
            }

        # Madrid Traffic Data
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            plan["data_sources"]["traffic_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "madrid_opendata",
                "required": False,
                "priority": "medium",
                "expected_records": (aligned_range.end - aligned_range.start).days,
                "constraint": "Cannot request current month data",
                "data_points": ["date", "traffic_volume", "congestion_level"],
                "validation": "date_constraint_check"
            }

        # Weather Data
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            plan["data_sources"]["weather_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "aemet_api",
                "required": False,
                "priority": "high",
                "expected_records": (aligned_range.end - aligned_range.start).days,
                "constraint": "Available from yesterday backward",
                "data_points": ["date", "temperature", "precipitation", "humidity"],
                "validation": "temporal_constraint_check",
                "fallback": "synthetic_madrid_weather"
            }
        return plan

    def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Generate a comprehensive summary of the orchestration process.
        """
        return {
            "tenant_id": dataset.metadata.get("tenant_id"),
            "job_id": dataset.metadata.get("job_id"),
            "orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
            "data_alignment": {
                "original_range": dataset.metadata.get("original_sales_range"),
                "aligned_range": {
                    "start": dataset.date_range.start.isoformat(),
                    "end": dataset.date_range.end.isoformat(),
                    "duration_days": (dataset.date_range.end - dataset.date_range.start).days
                },
                "constraints_applied": dataset.date_range.constraints,
                "available_sources": [source.value for source in dataset.date_range.available_sources]
            },
            "data_collection_results": {
                "sales_records": len(dataset.sales_data),
                "weather_records": len(dataset.weather_data),
                "traffic_records": len(dataset.traffic_data),
                "total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
            },
            "data_quality": dataset.metadata.get("data_quality", {}),
            "validation_results": dataset.metadata.get("final_validation", {}),
            "processing_metadata": {
                "bakery_location": dataset.metadata.get("bakery_location"),
                "data_sources_requested": len(dataset.date_range.available_sources),
                "data_sources_successful": sum([
                    1 if len(dataset.sales_data) > 0 else 0,
                    1 if len(dataset.weather_data) > 0 else 0,
                    1 if len(dataset.traffic_data) > 0 else 0
                ])
            }
        }
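

# A minimal usage sketch (hypothetical, kept as a comment so the module stays
# side-effect free): the tenant id, job id and coordinates are illustrative;
# in the real service these come from the training job API.
#
#     async def _demo():
#         orchestrator = TrainingDataOrchestrator()
#         dataset = await orchestrator.prepare_training_data(
#             tenant_id="tenant-123",              # hypothetical tenant
#             bakery_location=(40.4168, -3.7038),  # Madrid city centre
#             job_id="job-abc",                    # hypothetical job id
#         )
#         print(orchestrator.get_orchestration_summary(dataset))
#
#     asyncio.run(_demo())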