Files
bakery-ia/services/training/app/services/training_orchestrator.py
2025-08-15 17:53:59 +02:00

896 lines
40 KiB
Python

# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer
Orchestrates data collection, date alignment, and preparation for ML training
"""
import asyncio
import inspect
import math
import random
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timedelta
from datetime import timezone
from typing import Dict, List, Optional, Any, Tuple

import pandas as pd
import structlog

from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_failed
)
# Module-level structlog logger shared by this module. Some calls below pass
# structured context as keyword arguments (e.g. location=..., tenant_id=...).
logger = structlog.get_logger()
@dataclass
class TrainingDataSet:
    """Container for all training data with metadata.

    As populated by ``TrainingDataOrchestrator.prepare_training_data``, every
    list holds plain dict records restricted to ``date_range``.
    """
    # Sales records kept after date-range filtering and per-record validation.
    sales_data: List[Dict[str, Any]]
    # Weather records for the aligned range; the orchestrator may substitute
    # synthetic records when the real source fails or returns invalid data.
    weather_data: List[Dict[str, Any]]
    # Traffic records for the aligned range; may be an empty list.
    traffic_data: List[Dict[str, Any]]
    # Date range shared by all sources after alignment.
    date_range: AlignedDateRange
    # Free-form preparation metadata (tenant/job ids, quality metrics, etc.).
    metadata: Dict[str, Any]
class TrainingDataOrchestrator:
    """
    Enhanced orchestrator for data collection from multiple sources.

    Ensures date alignment, handles data source constraints, and prepares data
    for ML training. Uses the new abstracted traffic service layer for
    multi-city support.
    """

    def __init__(self,
                 date_alignment_service: Optional[DateAlignmentService] = None):
        """
        Create the orchestrator.

        Args:
            date_alignment_service: Alignment service to use. A new
                ``DateAlignmentService()`` is created when omitted (or falsy).
        """
        self.data_client = DataClient()
        self.date_alignment_service = date_alignment_service or DateAlignmentService()
        # NOTE(review): not read anywhere in this class. External collection runs
        # at most two concurrent tasks (weather + traffic). Kept for outside readers.
        self.max_concurrent_requests = 5  # Increased for better performance

    async def prepare_training_data(
        self,
        tenant_id: str,
        bakery_location: Tuple[float, float],  # (lat, lon)
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None,
        job_id: Optional[str] = None
    ) -> TrainingDataSet:
        """
        Main method to prepare all training data with comprehensive date alignment.

        Fetches the tenant's sales data through the data client and aligns its
        date range with the external sources. Filters the sales records to that
        range, collects weather and traffic data concurrently, and attaches
        quality metrics.

        Args:
            tenant_id: Tenant identifier (also used to fetch the sales data).
            bakery_location: Bakery coordinates (lat, lon).
            requested_start: Optional explicit start date.
            requested_end: Optional explicit end date.
            job_id: Training job identifier for logging and failure events.

        Returns:
            TrainingDataSet with all aligned and validated data.

        Raises:
            ValueError: Wraps (and is chained to) any error raised during
                preparation. A job-failed event is published before raising.
        """
        logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
        try:
            sales_data = await self.data_client.fetch_sales_data(tenant_id)

            # Step 1: Extract and validate sales data date range
            sales_date_range = self._extract_sales_date_range(sales_data)
            logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")

            # Step 2: Apply date alignment across all data sources
            aligned_range = self.date_alignment_service.validate_and_align_dates(
                user_sales_range=sales_date_range,
                requested_start=requested_start,
                requested_end=requested_end
            )
            logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
            if aligned_range.constraints:
                logger.info(f"Applied constraints: {aligned_range.constraints}")

            # Step 3: Filter sales data to aligned date range
            filtered_sales = self._filter_sales_data(sales_data, aligned_range)

            # Step 4: Collect external data sources concurrently
            logger.info("Collecting external data sources...")
            weather_data, traffic_data = await self._collect_external_data(
                aligned_range, bakery_location, tenant_id
            )

            # Step 5: Validate data quality
            data_quality_results = self._validate_data_sources(
                filtered_sales, weather_data, traffic_data, aligned_range
            )

            # Step 6: Create comprehensive training dataset
            training_dataset = TrainingDataSet(
                sales_data=filtered_sales,
                weather_data=weather_data,
                traffic_data=traffic_data,
                date_range=aligned_range,
                metadata={
                    "tenant_id": tenant_id,
                    "job_id": job_id,
                    "bakery_location": bakery_location,
                    "data_sources_used": aligned_range.available_sources,
                    "constraints_applied": aligned_range.constraints,
                    "data_quality": data_quality_results,
                    "preparation_timestamp": datetime.now().isoformat(),
                    "original_sales_range": {
                        "start": sales_date_range.start.isoformat(),
                        "end": sales_date_range.end.isoformat()
                    }
                }
            )

            # Step 7: Final validation
            final_validation = self.validate_training_data_quality(training_dataset)
            training_dataset.metadata["final_validation"] = final_validation

            logger.info("Training data preparation completed successfully:")
            logger.info(f" - Sales records: {len(filtered_sales)}")
            logger.info(f" - Weather records: {len(weather_data)}")
            logger.info(f" - Traffic records: {len(traffic_data)}")
            logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
            return training_dataset
        except Exception as e:
            await self._publish_job_failed_safely(job_id, tenant_id, str(e))
            logger.error(f"Training data preparation failed: {str(e)}")
            raise ValueError(f"Failed to prepare training data: {str(e)}") from e

    async def _publish_job_failed_safely(
        self,
        job_id: Optional[str],
        tenant_id: str,
        error_message: str
    ) -> None:
        """
        Publish a job-failed event without letting publishing problems mask the failure.

        The result of ``publish_job_failed`` is awaited when it is awaitable, so
        the event is actually sent whether the messaging helper is a coroutine
        function or a plain function. Errors while publishing are logged and
        swallowed, so the caller can still raise the original error.
        """
        try:
            outcome = publish_job_failed(job_id, tenant_id, error_message)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as publish_error:
            logger.warning(f"Failed to publish job failure event for job {job_id}: {publish_error}")

    def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
        """
        Extract the (earliest, latest) sales dates as timezone-aware datetimes.

        String dates are parsed with ``pd.to_datetime``, and ``datetime`` values
        are used as-is. Naive values are treated as UTC. Some records are
        skipped instead of aborting the whole preparation:

        - records whose ``date`` is missing or falsy;
        - records whose ``date`` is of another type;
        - records whose ``date`` is an unparseable/NaT string.

        ``_filter_sales_data`` rejects the same records later.

        Raises:
            ValueError: If ``sales_data`` is empty or no record has a usable date.
        """
        if not sales_data:
            raise ValueError("No sales data provided")

        dates: List[datetime] = []
        unparseable = 0
        for record in sales_data:
            date_value = record.get('date')
            if not date_value:
                continue
            if isinstance(date_value, str):
                try:
                    parsed = pd.to_datetime(date_value)
                except (ValueError, TypeError, OverflowError):
                    unparseable += 1
                    continue
                if pd.isna(parsed):
                    unparseable += 1
                    continue
                # Naive timestamps are assumed to be UTC so all dates compare safely.
                if parsed.tz is None:
                    parsed = parsed.replace(tzinfo=timezone.utc)
                dates.append(parsed.to_pydatetime())
            elif isinstance(date_value, datetime):
                if date_value.tzinfo is None:
                    date_value = date_value.replace(tzinfo=timezone.utc)
                dates.append(date_value)

        if unparseable:
            logger.warning(f"Skipped {unparseable} sales records with unparseable dates")
        if not dates:
            raise ValueError("No valid dates found in sales data")

        return DateRange(min(dates), max(dates), DataSourceType.BAKERY_SALES)

    @staticmethod
    def _normalize_sales_date(value: Any) -> datetime:
        """
        Return ``value`` as a timezone-aware datetime at the start of its calendar day.

        For strings, only the part before ``'T'`` is parsed with
        ``datetime.fromisoformat``. An ISO timestamp's time and UTC offset are
        therefore ignored, and the calendar date as written is used. Naive results
        and naive ``datetime`` values are treated as UTC. Aware ``datetime`` values
        keep their own tzinfo.

        Raises:
            ValueError: If a string is not an ISO date.
            TypeError: If ``value`` is neither ``str`` nor ``datetime``.
        """
        if isinstance(value, str):
            parsed = datetime.fromisoformat(value.split('T')[0].replace('Z', '+00:00'))
        elif isinstance(value, datetime):
            parsed = value
        else:
            raise TypeError(f"Unsupported sales date type: {type(value).__name__}")
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.replace(hour=0, minute=0, second=0, microsecond=0)

    def _filter_sales_data(
        self,
        sales_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> List[Dict[str, Any]]:
        """
        Filter sales data to the aligned date range with enhanced validation.

        Comparison is at calendar-day granularity. Each record's date is
        truncated to midnight (see ``_normalize_sales_date``), and so is the range
        start. Records on the first day of the range are therefore kept even when
        the range start carries a time of day. Kept records must also pass
        ``_validate_sales_record``. Records that cannot be processed are logged
        and skipped.
        """
        # Range bounds are loop-invariant: make them timezone-aware once
        # (naive bounds are treated as UTC, like naive record dates).
        aligned_start = aligned_range.start
        aligned_end = aligned_range.end
        if aligned_start.tzinfo is None:
            aligned_start = aligned_start.replace(tzinfo=timezone.utc)
        if aligned_end.tzinfo is None:
            aligned_end = aligned_end.replace(tzinfo=timezone.utc)
        # Record dates are truncated to midnight, so the start bound must be too;
        # otherwise e.g. a 08:00 start would exclude every record of the first day.
        aligned_start = aligned_start.replace(hour=0, minute=0, second=0, microsecond=0)

        filtered_data: List[Dict[str, Any]] = []
        out_of_range_count = 0
        invalid_count = 0
        for record in sales_data:
            try:
                if 'date' not in record:
                    # Without a date the record cannot be placed in the range.
                    invalid_count += 1
                    continue
                record_date = self._normalize_sales_date(record['date'])
                if not (aligned_start <= record_date <= aligned_end):
                    out_of_range_count += 1
                elif self._validate_sales_record(record):
                    filtered_data.append(record)
                else:
                    invalid_count += 1
            except Exception as e:
                # Best effort: one malformed record must not abort the whole filter.
                logger.warning(f"Error processing sales record: {str(e)}")
                invalid_count += 1

        logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
        if out_of_range_count or invalid_count:
            logger.warning(
                f"Filtered out {out_of_range_count + invalid_count} records "
                f"({out_of_range_count} outside aligned range, {invalid_count} invalid)"
            )
        return filtered_data

    def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
        """
        Validate an individual sales record.

        A record is valid when ``date`` and ``inventory_product_id`` are present
        and non-None, and the first non-None quantity field satisfies the checks.
        Quantity fields are tried in the order quantity, quantity_sold, sales,
        units_sold. The chosen value must convert to a finite, non-negative
        float.
        """
        required_fields = ['date', 'inventory_product_id']
        quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']

        if any(record.get(field) is None for field in required_fields):
            return False

        # Only the first populated quantity field is checked.
        quantity_value = next(
            (record[field] for field in quantity_fields if record.get(field) is not None),
            None,
        )
        if quantity_value is None:
            return False

        try:
            quantity = float(quantity_value)
        except (ValueError, TypeError, OverflowError):
            return False
        # NaN and inf slip past a plain "< 0" check but are unusable as targets.
        return math.isfinite(quantity) and quantity >= 0

    async def _collect_external_data(
        self,
        aligned_range: AlignedDateRange,
        bakery_location: Tuple[float, float],
        tenant_id: str
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Collect weather and traffic data concurrently with enhanced error handling.

        A task is only started for sources listed in
        ``aligned_range.available_sources``. A failed or skipped source yields
        an empty list. No timeout is applied here, so any timeout has to come
        from the data client.

        Returns:
            ``(weather_data, traffic_data)``.
        """
        lat, lon = bakery_location
        tasks = []

        # Weather data collection
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            weather_task = asyncio.create_task(
                self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("weather", weather_task))

        # Enhanced Traffic data collection (supports multiple cities)
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
            traffic_task = asyncio.create_task(
                self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("traffic", traffic_task))
        else:
            logger.warning(f"🚫 Traffic data source NOT available in sources: {[s.value for s in aligned_range.available_sources]}")

        results: Dict[str, List[Dict[str, Any]]] = {}
        if tasks:
            try:
                completed_tasks = await asyncio.gather(
                    *[task for _, task in tasks],
                    return_exceptions=True
                )
                for (task_name, _), result in zip(tasks, completed_tasks):
                    # With return_exceptions=True a cancelled child comes back as
                    # CancelledError, which is a BaseException (not an Exception)
                    # on Python 3.8+; check BaseException so it is not used as data.
                    if isinstance(result, BaseException):
                        logger.warning(f"{task_name} data collection failed: {result}")
                        results[task_name] = []
                    else:
                        results[task_name] = result
                        logger.info(f"{task_name} data collection completed: {len(result)} records")
            except Exception as e:
                logger.error(f"Error in concurrent data collection: {str(e)}")
                results = {"weather": [], "traffic": []}

        return results.get("weather", []), results.get("traffic", [])

    async def _collect_weather_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Collect weather data, falling back to synthetic data on any problem.

        Synthetic Madrid-pattern data is returned in three cases:

        - the fetched data fails ``_validate_weather_data``;
        - the fetch raises ``asyncio.TimeoutError``;
        - the fetch raises any other exception.

        Note that this method itself enforces no timeout.
        """
        try:
            weather_data = await self.data_client.fetch_weather_data(
                tenant_id=tenant_id,
                start_date=aligned_range.start.isoformat(),
                end_date=aligned_range.end.isoformat(),
                latitude=lat,
                longitude=lon)

            if self._validate_weather_data(weather_data):
                logger.info(f"Collected {len(weather_data)} valid weather records")
                return weather_data
            logger.warning("Invalid weather data received, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)
        except asyncio.TimeoutError:
            logger.warning("Weather data collection timed out, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)
        except Exception as e:
            logger.warning(f"Weather data collection failed: {e}, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)

    async def _collect_traffic_data_with_timeout_enhanced(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Enhanced traffic data collection with multi-city support and improved storage.

        Uses the new abstracted traffic service layer. Returns an empty list in
        these cases:

        - the range end violates the current-month constraint;
        - the data fails ``_validate_traffic_data_enhanced``;
        - the fetch raises.

        This method itself enforces no timeout.
        """
        try:
            # Double-check constraints before making request
            constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
            if constraint_violated:
                logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
                return []
            logger.info(f"✅ Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")

            # The data client is expected to pick the right city client from the coordinates.
            traffic_data = await self.data_client.fetch_traffic_data(
                tenant_id=tenant_id,
                start_date=aligned_range.start.isoformat(),
                end_date=aligned_range.end.isoformat(),
                latitude=lat,
                longitude=lon)

            if self._validate_traffic_data_enhanced(traffic_data):
                logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")
                self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)
                return traffic_data
            logger.warning("Invalid enhanced traffic data received")
            return []
        except asyncio.TimeoutError:
            logger.warning("Enhanced traffic data collection timed out")
            return []
        except Exception as e:
            logger.warning(f"Enhanced traffic data collection failed: {e}")
            return []

    # Keep original method for backwards compatibility
    async def _collect_traffic_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """Legacy traffic data collection method - redirects to enhanced version."""
        return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)

    def _log_enhanced_traffic_data_storage(self,
                                           lat: float,
                                           lon: float,
                                           aligned_range: AlignedDateRange,
                                           record_count: int,
                                           traffic_data: List[Dict[str, Any]]):
        """
        Log one structured summary of the traffic records collected for training.

        The summary includes:

        - cities, sources and districts seen (truthy values only);
        - how many records carry a non-None ``pedestrian_count``.
        """
        cities_detected = {record['city'] for record in traffic_data if record.get('city')}
        data_sources = {record['source'] for record in traffic_data if record.get('source')}
        districts_covered = {record['district'] for record in traffic_data if record.get('district')}
        has_pedestrian_data = sum(
            1 for record in traffic_data if record.get('pedestrian_count') is not None
        )

        logger.info(
            "Enhanced traffic data stored for re-training",
            location=f"{lat:.4f},{lon:.4f}",
            date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
            records_stored=record_count,
            cities_detected=list(cities_detected),
            pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
            data_sources=list(data_sources),
            districts_covered=list(districts_covered),
            storage_timestamp=datetime.now().isoformat(),
            purpose="enhanced_model_training_and_retraining",
            architecture_version="2.0_abstracted"
        )

    def _log_traffic_data_storage(self,
                                  lat: float,
                                  lon: float,
                                  aligned_range: AlignedDateRange,
                                  record_count: int):
        """Legacy logging method - redirects to enhanced version."""
        # One placeholder record (none when record_count <= 0) stands in for the real data.
        minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
        self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)

    async def retrieve_stored_traffic_for_retraining(
        self,
        bakery_location: Tuple[float, float],
        start_date: datetime,
        end_date: datetime,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Retrieve previously stored traffic data for model re-training.

        Calls the data client's stored-traffic endpoint rather than the live
        traffic fetch. Returns an empty list when nothing is found or the call
        fails; failures are logged, not raised.
        """
        lat, lon = bakery_location
        try:
            stored_traffic_data = await self.data_client.fetch_stored_traffic_data_for_training(
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                latitude=lat,
                longitude=lon
            )
            if stored_traffic_data:
                logger.info(
                    f"Retrieved {len(stored_traffic_data)} stored traffic records for re-training",
                    location=f"{lat:.4f},{lon:.4f}",
                    date_range=f"{start_date.isoformat()} to {end_date.isoformat()}",
                    tenant_id=tenant_id
                )
                return stored_traffic_data
            logger.warning(
                "No stored traffic data found for re-training",
                location=f"{lat:.4f},{lon:.4f}",
                date_range=f"{start_date.isoformat()} to {end_date.isoformat()}"
            )
            return []
        except Exception as e:
            logger.error(
                f"Failed to retrieve stored traffic data for re-training: {e}",
                location=f"{lat:.4f},{lon:.4f}",
                tenant_id=tenant_id
            )
            return []

    def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
        """
        Validate weather data quality.

        A record is valid when both of these hold:

        - its ``date`` is present and non-None (same rule as the traffic
          validator);
        - at least one temperature/precipitation field is non-None, using
          English or Spanish keys.

        The data as a whole is valid when at least 50% of records are valid.
        """
        if not weather_data:
            return False

        weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
        valid_records = sum(
            1 for record in weather_data
            if record.get('date') is not None
            and any(record.get(field) is not None for field in weather_fields)
        )

        validity_threshold = 0.5
        is_valid = (valid_records / len(weather_data)) >= validity_threshold
        if not is_valid:
            logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
        return is_valid

    def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """
        Enhanced validation for traffic data, including pedestrian and city-specific fields.

        Each record earns one point for each of these conditions:

        - a non-None ``date``;
        - any non-None traffic field;
        - at least two non-None enhanced fields.

        A record with at least one point counts as valid. The data passes when
        at least 10% of records are valid. Enhancement and city-awareness ratios
        are only logged; they do not affect the result.
        """
        if not traffic_data:
            return False

        required_fields = ['date']
        traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
        enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
        city_specific_fields = ['city', 'measurement_point_id', 'district']

        valid_records = 0
        enhanced_records = 0
        city_aware_records = 0
        for record in traffic_data:
            record_score = 0
            if all(record.get(field) is not None for field in required_fields):
                record_score += 1
            if any(record.get(field) is not None for field in traffic_fields):
                record_score += 1

            enhanced_count = sum(1 for field in enhanced_fields if record.get(field) is not None)
            if enhanced_count >= 2:  # At least 2 enhanced fields
                enhanced_records += 1
                record_score += 1

            if any(record.get(field) is not None for field in city_specific_fields):
                city_aware_records += 1

            # Deliberately lenient: ANY single point (date alone, a traffic field
            # alone, or enough enhanced fields) is enough for the record to count.
            if record_score >= 1:
                valid_records += 1

        total_records = len(traffic_data)
        validity_threshold = 0.1  # Reduced from 0.3 to 0.1 - accept if 10% of records are valid
        enhancement_threshold = 0.1  # Reduced threshold for enhanced features

        basic_validity = (valid_records / total_records) >= validity_threshold
        has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
        has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold

        logger.info("Enhanced traffic data validation results",
                    total_records=total_records,
                    valid_records=valid_records,
                    enhanced_records=enhanced_records,
                    city_aware_records=city_aware_records,
                    basic_validity=basic_validity,
                    has_enhancements=has_enhancements,
                    has_city_awareness=has_city_awareness)

        if not basic_validity:
            logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")
        return basic_validity

    def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Legacy validation method - redirects to enhanced version."""
        return self._validate_traffic_data_enhanced(traffic_data)

    def _validate_data_sources(
        self,
        sales_data: List[Dict[str, Any]],
        weather_data: List[Dict[str, Any]],
        traffic_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> Dict[str, Any]:
        """
        Validate all data sources and provide quality metrics.

        Each source's ``quality_score`` is computed in three steps:

        1. Divide its record count by the number of whole days spanned by
           ``aligned_range``, treating a span of 0 as 1.
        2. Cap the ratio at 1.0.
        3. Multiply by 100.

        ``overall_quality_score`` is the weighted sum (sales 0.7, weather 0.2,
        traffic 0.1), rounded to 2 decimals.
        """
        span_days = (aligned_range.end - aligned_range.start).days
        # A range inside a single day spans 0 whole days. Expect at least one record
        # so the coverage ratios below cannot raise ZeroDivisionError.
        expected_records = max(1, span_days)

        validation_results = {
            "sales_data": {
                "record_count": len(sales_data),
                "is_valid": len(sales_data) > 0,
                "coverage_days": span_days,
                "quality_score": 0.0
            },
            "weather_data": {
                "record_count": len(weather_data),
                "is_valid": self._validate_weather_data(weather_data) if weather_data else False,
                "quality_score": 0.0
            },
            "traffic_data": {
                "record_count": len(traffic_data),
                "is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
                "quality_score": 0.0
            },
            "overall_quality_score": 0.0
        }

        for source in ("sales_data", "weather_data", "traffic_data"):
            record_count = validation_results[source]["record_count"]
            if record_count > 0:
                coverage_ratio = min(1.0, record_count / expected_records)
                validation_results[source]["quality_score"] = coverage_ratio * 100

        # Overall quality score (weighted by importance; sales matter most)
        weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
        overall_score = sum(
            validation_results[source]["quality_score"] * weight
            for source, weight in weights.items()
        )
        validation_results["overall_quality_score"] = round(overall_score, 2)
        return validation_results

    def _generate_synthetic_weather_data(
        self,
        aligned_range: AlignedDateRange,
        seed: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Generate realistic synthetic weather data for Madrid.

        Produces one record per day from ``aligned_range.start`` through
        ``aligned_range.end`` inclusive. It steps one day at a time from the
        start, so every record keeps the start's time of day. Each record's
        values are generated as follows:

        - temperature: a monthly base value plus Gaussian noise (sigma 3 °C),
          floored at 0;
        - precipitation: non-zero on roughly 15% of days;
        - humidity, wind speed and pressure: uniform random values.

        Args:
            aligned_range: Range to cover.
            seed: Optional seed for a private ``random.Random``, making the
                output reproducible. When omitted, the module-level ``random``
                functions are used (the previous behaviour).
        """
        rng = random if seed is None else random.Random(seed)

        # Madrid seasonal base temperature (°C) per calendar month
        seasonal_temps = {
            1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
            7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
        }

        synthetic_data = []
        current_date = aligned_range.start
        while current_date <= aligned_range.end:
            base_temp = seasonal_temps.get(current_date.month, 15)
            temp_variation = rng.gauss(0, 3)  # ±3°C variation
            temperature = max(0, base_temp + temp_variation)

            # Precipitation patterns (Madrid is relatively dry)
            precipitation = 0.0
            if rng.random() < 0.15:  # 15% chance of rain
                precipitation = rng.uniform(0.1, 15.0)

            synthetic_data.append({
                "date": current_date,
                "temperature": round(temperature, 1),
                "precipitation": round(precipitation, 1),
                "humidity": round(rng.uniform(40, 80), 1),
                "wind_speed": round(rng.uniform(2, 15), 1),
                "pressure": round(rng.uniform(1005, 1025), 1),
                "source": "synthetic_madrid_pattern"
            })
            current_date = current_date + timedelta(days=1)

        logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
        return synthetic_data

    def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Enhanced validation of training data quality.

        Starts from a score of 100 and subtracts penalties for the following
        issues:

        - limited sales records;
        - short date coverage;
        - missing or sparse weather data;
        - missing traffic data.

        Any error (no product ids) marks the data invalid and sets the score to
        0. The score is clamped at 0 and mapped to a ``quality_assessment``
        label.
        """
        validation_results = {
            "is_valid": True,
            "warnings": [],
            "errors": [],
            "data_quality_score": 100.0,
            "recommendations": []
        }

        # Check sales data completeness
        sales_count = len(dataset.sales_data)
        if sales_count < 30:
            validation_results["warnings"].append(
                f"Limited sales data: {sales_count} records (recommended: 30+)"
            )
            validation_results["data_quality_score"] -= 20
            validation_results["recommendations"].append("Consider collecting more historical sales data")
        elif sales_count < 90:
            validation_results["warnings"].append(
                f"Moderate sales data: {sales_count} records (optimal: 90+)"
            )
            validation_results["data_quality_score"] -= 10

        # Check date coverage
        date_coverage = (dataset.date_range.end - dataset.date_range.start).days
        if date_coverage < 90:
            validation_results["warnings"].append(
                f"Limited date coverage: {date_coverage} days (recommended: 90+)"
            )
            validation_results["data_quality_score"] -= 15
            validation_results["recommendations"].append("Extend date range for better seasonality detection")

        # Check external data availability
        if not dataset.weather_data:
            validation_results["warnings"].append("No weather data available")
            validation_results["data_quality_score"] -= 10
            validation_results["recommendations"].append("Weather data improves forecast accuracy")
        elif len(dataset.weather_data) < date_coverage * 0.5:
            validation_results["warnings"].append("Sparse weather data coverage")
            validation_results["data_quality_score"] -= 5

        if not dataset.traffic_data:
            validation_results["warnings"].append("No traffic data available")
            validation_results["data_quality_score"] -= 5
            validation_results["recommendations"].append("Traffic data can help with location-based patterns")

        # Check data consistency (products are identified by inventory_product_id)
        unique_products = {
            record['inventory_product_id']
            for record in dataset.sales_data
            if 'inventory_product_id' in record
        }
        if not unique_products:
            validation_results["errors"].append("No product names found in sales data")
            validation_results["is_valid"] = False
        elif len(unique_products) > 50:
            validation_results["warnings"].append(
                f"Many products detected ({len(unique_products)}). Consider training models in batches."
            )
            validation_results["recommendations"].append("Group similar products for better training efficiency")

        # Check for data source constraints
        if dataset.date_range.constraints:
            constraint_info = [
                f"{constraint_type}: {message}"
                for constraint_type, message in dataset.date_range.constraints.items()
            ]
            validation_results["warnings"].append(
                f"Data source constraints applied: {'; '.join(constraint_info)}"
            )

        # Final validation
        if validation_results["errors"]:
            validation_results["is_valid"] = False
            validation_results["data_quality_score"] = 0.0

        # Ensure score doesn't go below 0
        validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])

        score = validation_results["data_quality_score"]
        if score >= 80:
            validation_results["quality_assessment"] = "Excellent"
        elif score >= 60:
            validation_results["quality_assessment"] = "Good"
        elif score >= 40:
            validation_results["quality_assessment"] = "Fair"
        else:
            validation_results["quality_assessment"] = "Poor"
        return validation_results

    def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
        """
        Generate an enhanced data collection plan based on the aligned date range.

        Returns a ``collection_summary``. Under ``data_sources`` it also holds
        one entry per supported source type present in
        ``aligned_range.available_sources``: sales, Madrid traffic and weather.
        """
        start_iso = aligned_range.start.isoformat()
        end_iso = aligned_range.end.isoformat()
        duration_days = (aligned_range.end - aligned_range.start).days
        available = aligned_range.available_sources

        plan = {
            "collection_summary": {
                "start_date": start_iso,
                "end_date": end_iso,
                "duration_days": duration_days,
                "available_sources": [source.value for source in available],
                "constraints": aligned_range.constraints
            },
            "data_sources": {}
        }

        # Bakery Sales Data
        if DataSourceType.BAKERY_SALES in available:
            plan["data_sources"]["sales_data"] = {
                "start_date": start_iso,
                "end_date": end_iso,
                "source": "user_upload",
                "required": True,
                "priority": "high",
                "expected_records": "variable",
                "data_points": ["date", "inventory_product_id", "quantity"],
                "validation": "required_fields_check"
            }

        # Madrid Traffic Data
        if DataSourceType.MADRID_TRAFFIC in available:
            plan["data_sources"]["traffic_data"] = {
                "start_date": start_iso,
                "end_date": end_iso,
                "source": "madrid_opendata",
                "required": False,
                "priority": "medium",
                "expected_records": duration_days,
                "constraint": "Cannot request current month data",
                "data_points": ["date", "traffic_volume", "congestion_level"],
                "validation": "date_constraint_check"
            }

        # Weather Data
        if DataSourceType.WEATHER_FORECAST in available:
            plan["data_sources"]["weather_data"] = {
                "start_date": start_iso,
                "end_date": end_iso,
                "source": "aemet_api",
                "required": False,
                "priority": "high",
                "expected_records": duration_days,
                "constraint": "Available from yesterday backward",
                "data_points": ["date", "temperature", "precipitation", "humidity"],
                "validation": "temporal_constraint_check",
                "fallback": "synthetic_madrid_weather"
            }

        return plan

    def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Generate a comprehensive summary of the orchestration process.

        Combines ``dataset.metadata`` (as written by ``prepare_training_data``)
        with record counts and the aligned range. ``data_sources_successful``
        counts how many of the sales/weather/traffic lists are non-empty.
        """
        date_range = dataset.date_range
        record_lists = (dataset.sales_data, dataset.weather_data, dataset.traffic_data)
        return {
            "tenant_id": dataset.metadata.get("tenant_id"),
            "job_id": dataset.metadata.get("job_id"),
            "orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
            "data_alignment": {
                "original_range": dataset.metadata.get("original_sales_range"),
                "aligned_range": {
                    "start": date_range.start.isoformat(),
                    "end": date_range.end.isoformat(),
                    "duration_days": (date_range.end - date_range.start).days
                },
                "constraints_applied": date_range.constraints,
                "available_sources": [source.value for source in date_range.available_sources]
            },
            "data_collection_results": {
                "sales_records": len(dataset.sales_data),
                "weather_records": len(dataset.weather_data),
                "traffic_records": len(dataset.traffic_data),
                "total_records": sum(len(records) for records in record_lists)
            },
            "data_quality": dataset.metadata.get("data_quality", {}),
            "validation_results": dataset.metadata.get("final_validation", {}),
            "processing_metadata": {
                "bakery_location": dataset.metadata.get("bakery_location"),
                "data_sources_requested": len(date_range.available_sources),
                "data_sources_successful": sum(1 for records in record_lists if len(records) > 0)
            }
        }