# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer

Orchestrates data collection, date alignment, and preparation for ML training.
"""

from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import structlog
from concurrent.futures import ThreadPoolExecutor

import pandas as pd

from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange

from app.services.messaging import (
    publish_job_progress,
    publish_data_validation_started,
    publish_data_validation_completed,
    publish_job_step_completed,
    publish_job_failed
)

logger = structlog.get_logger()


@dataclass
class TrainingDataSet:
    """Container for all training data with metadata"""
    sales_data: List[Dict[str, Any]]
    weather_data: List[Dict[str, Any]]
    traffic_data: List[Dict[str, Any]]
    date_range: AlignedDateRange
    metadata: Dict[str, Any]

class TrainingDataOrchestrator:
    """
    Enhanced orchestrator for data collection from multiple sources.
    Ensures date alignment, handles data source constraints, and prepares data for ML training.
    Uses the new abstracted traffic service layer for multi-city support.
    """

    def __init__(self,
                 date_alignment_service: Optional[DateAlignmentService] = None):
        self.data_client = DataClient()
        self.date_alignment_service = date_alignment_service or DateAlignmentService()
        self.max_concurrent_requests = 5  # Increased for better performance

    async def prepare_training_data(
        self,
        tenant_id: str,
        bakery_location: Tuple[float, float],  # (lat, lon)
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None,
        job_id: Optional[str] = None
    ) -> TrainingDataSet:
        """
        Main method to prepare all training data with comprehensive date alignment.

        Sales data is fetched for the tenant via the data client rather than passed in.

        Args:
            tenant_id: Tenant identifier
            bakery_location: Bakery coordinates (lat, lon)
            requested_start: Optional explicit start date
            requested_end: Optional explicit end date
            job_id: Training job identifier for logging

        Returns:
            TrainingDataSet with all aligned and validated data
        """
        logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")

        try:
            sales_data = await self.data_client.fetch_sales_data(tenant_id)

            # Step 1: Extract and validate sales data date range
            sales_date_range = self._extract_sales_date_range(sales_data)
            logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")

            # Step 2: Apply date alignment across all data sources
            aligned_range = self.date_alignment_service.validate_and_align_dates(
                user_sales_range=sales_date_range,
                requested_start=requested_start,
                requested_end=requested_end
            )

            logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
            if aligned_range.constraints:
                logger.info(f"Applied constraints: {aligned_range.constraints}")

            # Step 3: Filter sales data to aligned date range
            filtered_sales = self._filter_sales_data(sales_data, aligned_range)

            # Step 4: Collect external data sources concurrently
            logger.info("Collecting external data sources...")
            weather_data, traffic_data = await self._collect_external_data(
                aligned_range, bakery_location, tenant_id
            )

            # Step 5: Validate data quality
            data_quality_results = self._validate_data_sources(
                filtered_sales, weather_data, traffic_data, aligned_range
            )

            # Step 6: Create comprehensive training dataset
            training_dataset = TrainingDataSet(
                sales_data=filtered_sales,
                weather_data=weather_data,
                traffic_data=traffic_data,
                date_range=aligned_range,
                metadata={
                    "tenant_id": tenant_id,
                    "job_id": job_id,
                    "bakery_location": bakery_location,
                    "data_sources_used": aligned_range.available_sources,
                    "constraints_applied": aligned_range.constraints,
                    "data_quality": data_quality_results,
                    "preparation_timestamp": datetime.now().isoformat(),
                    "original_sales_range": {
                        "start": sales_date_range.start.isoformat(),
                        "end": sales_date_range.end.isoformat()
                    }
                }
            )

            # Step 7: Final validation
            final_validation = self.validate_training_data_quality(training_dataset)
            training_dataset.metadata["final_validation"] = final_validation

            logger.info("Training data preparation completed successfully:")
            logger.info(f"  - Sales records: {len(filtered_sales)}")
            logger.info(f"  - Weather records: {len(weather_data)}")
            logger.info(f"  - Traffic records: {len(traffic_data)}")
            logger.info(f"  - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")

            return training_dataset

        except Exception as e:
            publish_job_failed(job_id, tenant_id, str(e))
            logger.error(f"Training data preparation failed: {str(e)}")
            raise ValueError(f"Failed to prepare training data: {str(e)}")
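
    # Illustrative usage (a minimal sketch; the tenant id, coordinates, and a
    # reachable data service are assumptions, not part of this module):
    #
    #     orchestrator = TrainingDataOrchestrator()
    #     dataset = await orchestrator.prepare_training_data(
    #         tenant_id="tenant-123",              # hypothetical tenant
    #         bakery_location=(40.4168, -3.7038),  # Madrid city centre
    #         job_id="job-abc",                    # hypothetical job id
    #     )
    #     print(dataset.metadata["final_validation"]["data_quality_score"])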

    def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
        """Extract date range from sales data with timezone handling"""
        if not sales_data:
            raise ValueError("No sales data provided")

        dates = []
        for record in sales_data:
            date_value = record.get('date')
            if date_value:
                # Ensure timezone-aware datetimes
                if isinstance(date_value, str):
                    dt = pd.to_datetime(date_value)
                    if dt.tz is None:
                        dt = dt.replace(tzinfo=timezone.utc)
                    dates.append(dt.to_pydatetime())
                elif isinstance(date_value, datetime):
                    if date_value.tzinfo is None:
                        date_value = date_value.replace(tzinfo=timezone.utc)
                    dates.append(date_value)

        if not dates:
            raise ValueError("No valid dates found in sales data")

        start_date = min(dates)
        end_date = max(dates)

        return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
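
    # For example, a record {'date': '2024-01-15', ...} yields a UTC-aware
    # datetime(2024, 1, 15, tzinfo=timezone.utc); naive datetimes are likewise
    # assumed to be UTC. (Illustrative note, not exercised by the module itself.)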

    def _filter_sales_data(
        self,
        sales_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> List[Dict[str, Any]]:
        """Filter sales data to the aligned date range with enhanced validation"""
        filtered_data = []
        filtered_count = 0

        for record in sales_data:
            try:
                if 'date' in record:
                    record_date = record['date']

                    # Normalize the record date to a timezone-aware datetime
                    if isinstance(record_date, str):
                        if 'T' in record_date:
                            record_date = record_date.replace('Z', '+00:00')
                        # Parse the date portion only, then assume UTC
                        parsed_date = datetime.fromisoformat(record_date.split('T')[0])
                        if parsed_date.tzinfo is None:
                            parsed_date = parsed_date.replace(tzinfo=timezone.utc)
                        record_date = parsed_date
                    elif isinstance(record_date, datetime):
                        # Ensure timezone-aware
                        if record_date.tzinfo is None:
                            record_date = record_date.replace(tzinfo=timezone.utc)
                        # Normalize to start of day
                        record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)

                    # Ensure aligned_range dates are also timezone-aware for comparison
                    aligned_start = aligned_range.start
                    aligned_end = aligned_range.end

                    if aligned_start.tzinfo is None:
                        aligned_start = aligned_start.replace(tzinfo=timezone.utc)
                    if aligned_end.tzinfo is None:
                        aligned_end = aligned_end.replace(tzinfo=timezone.utc)

                    # Check if the date falls within the aligned range (both sides timezone-aware)
                    if aligned_start <= record_date <= aligned_end:
                        # Validate that the record has the required fields
                        if self._validate_sales_record(record):
                            filtered_data.append(record)
                        else:
                            filtered_count += 1
                    else:
                        # Record outside date range
                        filtered_count += 1
            except Exception as e:
                logger.warning(f"Error processing sales record: {str(e)}")
                filtered_count += 1
                continue

        logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
        if filtered_count > 0:
            logger.warning(f"Filtered out {filtered_count} invalid records")

        return filtered_data

    def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
        """Validate individual sales record"""
        required_fields = ['date', 'product_name']
        quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']

        # Check required fields
        for field in required_fields:
            if field not in record or record[field] is None:
                return False

        # Check at least one quantity field exists
        has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
        if not has_quantity:
            return False

        # Validate that the first quantity field present is numeric and non-negative
        for field in quantity_fields:
            if field in record and record[field] is not None:
                try:
                    quantity = float(record[field])
                    if quantity < 0:
                        return False
                except (ValueError, TypeError):
                    return False
                break

        return True
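
    # For instance (illustrative records, not part of the pipeline):
    #   {'date': '2024-01-15', 'product_name': 'croissant', 'quantity': 12}  -> valid
    #   {'date': '2024-01-15', 'product_name': 'croissant', 'quantity': -3}  -> rejected (negative quantity)
    #   {'date': '2024-01-15', 'quantity': 12}                               -> rejected (missing product_name)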

    async def _collect_external_data(
        self,
        aligned_range: AlignedDateRange,
        bakery_location: Tuple[float, float],
        tenant_id: str
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Collect weather and traffic data concurrently with enhanced error handling"""

        lat, lon = bakery_location

        # Create collection tasks with timeout
        tasks = []

        # Weather data collection
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            weather_task = asyncio.create_task(
                self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("weather", weather_task))

        # Enhanced traffic data collection (supports multiple cities)
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
            traffic_task = asyncio.create_task(
                self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
            )
            tasks.append(("traffic", traffic_task))
        else:
            logger.warning(f"🚫 Traffic data source NOT available in sources: {[s.value for s in aligned_range.available_sources]}")

        # Execute tasks concurrently with proper error handling
        results = {}
        if tasks:
            try:
                completed_tasks = await asyncio.gather(
                    *[task for _, task in tasks],
                    return_exceptions=True
                )

                for i, (task_name, _) in enumerate(tasks):
                    result = completed_tasks[i]
                    if isinstance(result, Exception):
                        logger.warning(f"{task_name} data collection failed: {result}")
                        results[task_name] = []
                    else:
                        results[task_name] = result
                        logger.info(f"{task_name} data collection completed: {len(result)} records")

            except Exception as e:
                logger.error(f"Error in concurrent data collection: {str(e)}")
                results = {"weather": [], "traffic": []}

        weather_data = results.get("weather", [])
        traffic_data = results.get("traffic", [])

        return weather_data, traffic_data

    async def _collect_weather_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """Collect weather data with timeout and fallback"""
        try:
            start_date_str = aligned_range.start.isoformat()
            end_date_str = aligned_range.end.isoformat()

            weather_data = await self.data_client.fetch_weather_data(
                tenant_id=tenant_id,
                start_date=start_date_str,
                end_date=end_date_str,
                latitude=lat,
                longitude=lon
            )

            # Validate weather data
            if self._validate_weather_data(weather_data):
                logger.info(f"Collected {len(weather_data)} valid weather records")
                return weather_data
            else:
                logger.warning("Invalid weather data received, using synthetic data")
                return self._generate_synthetic_weather_data(aligned_range)

        except asyncio.TimeoutError:
            logger.warning("Weather data collection timed out, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)
        except Exception as e:
            logger.warning(f"Weather data collection failed: {e}, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)

    async def _collect_traffic_data_with_timeout_enhanced(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Enhanced traffic data collection with multi-city support and improved storage.
        Uses the new abstracted traffic service layer.
        """
        try:
            # Double-check constraints before making the request
            constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
            if constraint_violated:
                logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
                return []
            else:
                logger.info(f"✅ Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")

            start_date_str = aligned_range.start.isoformat()
            end_date_str = aligned_range.end.isoformat()

            # Fetch traffic data using the abstracted service; it automatically
            # detects the appropriate city and uses the right client.
            traffic_data = await self.data_client.fetch_traffic_data(
                tenant_id=tenant_id,
                start_date=start_date_str,
                end_date=end_date_str,
                latitude=lat,
                longitude=lon
            )

            # Enhanced validation including pedestrian inference data
            if self._validate_traffic_data_enhanced(traffic_data):
                logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")

                # Log storage success with enhanced metadata
                self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)

                return traffic_data
            else:
                logger.warning("Invalid enhanced traffic data received")
                return []

        except asyncio.TimeoutError:
            logger.warning("Enhanced traffic data collection timed out")
            return []
        except Exception as e:
            logger.warning(f"Enhanced traffic data collection failed: {e}")
            return []

    # Keep original method for backwards compatibility
    async def _collect_traffic_data_with_timeout(
        self,
        lat: float,
        lon: float,
        aligned_range: AlignedDateRange,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """Legacy traffic data collection method - redirects to enhanced version"""
        return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)

    def _log_enhanced_traffic_data_storage(self,
                                           lat: float,
                                           lon: float,
                                           aligned_range: AlignedDateRange,
                                           record_count: int,
                                           traffic_data: List[Dict[str, Any]]):
        """Enhanced logging for traffic data storage with detailed metadata"""
        # Analyze the stored data for additional insights
        cities_detected = set()
        has_pedestrian_data = 0
        data_sources = set()
        districts_covered = set()

        for record in traffic_data:
            if 'city' in record and record['city']:
                cities_detected.add(record['city'])
            if 'pedestrian_count' in record and record['pedestrian_count'] is not None:
                has_pedestrian_data += 1
            if 'source' in record and record['source']:
                data_sources.add(record['source'])
            if 'district' in record and record['district']:
                districts_covered.add(record['district'])

        logger.info(
            "Enhanced traffic data stored for re-training",
            location=f"{lat:.4f},{lon:.4f}",
            date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
            records_stored=record_count,
            cities_detected=list(cities_detected),
            pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
            data_sources=list(data_sources),
            districts_covered=list(districts_covered),
            storage_timestamp=datetime.now().isoformat(),
            purpose="enhanced_model_training_and_retraining",
            architecture_version="2.0_abstracted"
        )

    def _log_traffic_data_storage(self,
                                  lat: float,
                                  lon: float,
                                  aligned_range: AlignedDateRange,
                                  record_count: int):
        """Legacy logging method - redirects to enhanced version"""
        # Create a minimal traffic data structure for enhanced logging
        minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
        self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)

    async def retrieve_stored_traffic_for_retraining(
        self,
        bakery_location: Tuple[float, float],
        start_date: datetime,
        end_date: datetime,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Retrieve previously stored traffic data for model re-training.
        This method accesses stored traffic data without making new API calls.
        """
        lat, lon = bakery_location

        try:
            # Use the dedicated stored traffic data method for training
            stored_traffic_data = await self.data_client.fetch_stored_traffic_data_for_training(
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                latitude=lat,
                longitude=lon
            )

            if stored_traffic_data:
                logger.info(
                    f"Retrieved {len(stored_traffic_data)} stored traffic records for re-training",
                    location=f"{lat:.4f},{lon:.4f}",
                    date_range=f"{start_date.isoformat()} to {end_date.isoformat()}",
                    tenant_id=tenant_id
                )
                return stored_traffic_data
            else:
                logger.warning(
                    "No stored traffic data found for re-training",
                    location=f"{lat:.4f},{lon:.4f}",
                    date_range=f"{start_date.isoformat()} to {end_date.isoformat()}"
                )
                return []

        except Exception as e:
            logger.error(
                f"Failed to retrieve stored traffic data for re-training: {e}",
                location=f"{lat:.4f},{lon:.4f}",
                tenant_id=tenant_id
            )
            return []

    def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
        """Validate weather data quality"""
        if not weather_data:
            return False

        required_fields = ['date']
        weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']

        valid_records = 0
        for record in weather_data:
            # Check required fields
            if not all(field in record for field in required_fields):
                continue

            # Check at least one weather field exists
            if any(field in record and record[field] is not None for field in weather_fields):
                valid_records += 1

        # Consider valid if at least 50% of records are valid
        validity_threshold = 0.5
        is_valid = (valid_records / len(weather_data)) >= validity_threshold

        if not is_valid:
            logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")

        return is_valid

    def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Enhanced validation for traffic data including pedestrian inference and city-specific fields"""
        if not traffic_data:
            return False

        required_fields = ['date']
        traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
        enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
        city_specific_fields = ['city', 'measurement_point_id', 'district']

        valid_records = 0
        enhanced_records = 0
        city_aware_records = 0

        for record in traffic_data:
            record_score = 0

            # Check required fields
            if all(field in record and record[field] is not None for field in required_fields):
                record_score += 1

            # Check traffic data fields
            if any(field in record and record[field] is not None for field in traffic_fields):
                record_score += 1

            # Check enhanced fields (pedestrian inference, etc.)
            enhanced_count = sum(1 for field in enhanced_fields
                                 if field in record and record[field] is not None)
            if enhanced_count >= 2:  # At least 2 enhanced fields
                enhanced_records += 1
                record_score += 1

            # Check city-specific awareness
            city_count = sum(1 for field in city_specific_fields
                             if field in record and record[field] is not None)
            if city_count >= 1:  # At least some city awareness
                city_aware_records += 1

            # Record is valid if it has the basic requirements
            if record_score >= 2:
                valid_records += 1

        total_records = len(traffic_data)
        validity_threshold = 0.3
        enhancement_threshold = 0.2  # Lower threshold for enhanced features

        basic_validity = (valid_records / total_records) >= validity_threshold
        has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
        has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold

        logger.info("Enhanced traffic data validation results",
                    total_records=total_records,
                    valid_records=valid_records,
                    enhanced_records=enhanced_records,
                    city_aware_records=city_aware_records,
                    basic_validity=basic_validity,
                    has_enhancements=has_enhancements,
                    has_city_awareness=has_city_awareness)

        if not basic_validity:
            logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")

        return basic_validity
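
    # Scoring example (illustrative): a record such as
    #   {'date': ..., 'traffic_volume': 420, 'pedestrian_count': 35, 'source': 'sensor', 'city': 'madrid'}
    # scores 3 (date + a traffic field + two enhanced fields) and counts as valid,
    # enhanced, and city-aware; a bare {'date': ...} record scores 1 and is not counted as valid.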

    def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Legacy validation method - redirects to enhanced version"""
        return self._validate_traffic_data_enhanced(traffic_data)

    def _validate_data_sources(
        self,
        sales_data: List[Dict[str, Any]],
        weather_data: List[Dict[str, Any]],
        traffic_data: List[Dict[str, Any]],
        aligned_range: AlignedDateRange
    ) -> Dict[str, Any]:
        """Validate all data sources and provide quality metrics"""

        # Guard against zero-length ranges to avoid division by zero below
        coverage_days = max(1, (aligned_range.end - aligned_range.start).days)

        validation_results = {
            "sales_data": {
                "record_count": len(sales_data),
                "is_valid": len(sales_data) > 0,
                "coverage_days": coverage_days,
                "quality_score": 0.0
            },
            "weather_data": {
                "record_count": len(weather_data),
                "is_valid": self._validate_weather_data(weather_data) if weather_data else False,
                "quality_score": 0.0
            },
            "traffic_data": {
                "record_count": len(traffic_data),
                "is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
                "quality_score": 0.0
            },
            "overall_quality_score": 0.0
        }

        # Calculate quality scores
        # Sales data quality (most important)
        if validation_results["sales_data"]["record_count"] > 0:
            coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / coverage_days)
            validation_results["sales_data"]["quality_score"] = coverage_ratio * 100

        # Weather data quality
        if validation_results["weather_data"]["record_count"] > 0:
            expected_weather_records = coverage_days
            coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
            validation_results["weather_data"]["quality_score"] = coverage_ratio * 100

        # Traffic data quality
        if validation_results["traffic_data"]["record_count"] > 0:
            expected_traffic_records = coverage_days
            coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
            validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100

        # Overall quality score (weighted by importance)
        weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
        overall_score = sum(
            validation_results[source]["quality_score"] * weight
            for source, weight in weights.items()
        )
        validation_results["overall_quality_score"] = round(overall_score, 2)

        return validation_results
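
    # Worked example (illustrative numbers): with per-source quality scores of
    # 100 (sales), 80 (weather), and 50 (traffic), the weighted overall score is
    # 0.7 * 100 + 0.2 * 80 + 0.1 * 50 = 91.0.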

    def _generate_synthetic_weather_data(
        self,
        aligned_range: AlignedDateRange
    ) -> List[Dict[str, Any]]:
        """Generate realistic synthetic weather data for Madrid"""
        import random

        synthetic_data = []
        current_date = aligned_range.start

        # Madrid seasonal temperature patterns (mean °C by month)
        seasonal_temps = {
            1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
            7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
        }

        while current_date <= aligned_range.end:
            month = current_date.month
            base_temp = seasonal_temps.get(month, 15)

            # Add some realistic variation
            temp_variation = random.gauss(0, 3)  # ±3°C variation
            temperature = max(0, base_temp + temp_variation)

            # Precipitation patterns (Madrid is relatively dry)
            precipitation = 0.0
            if random.random() < 0.15:  # 15% chance of rain
                precipitation = random.uniform(0.1, 15.0)

            synthetic_data.append({
                "date": current_date,
                "temperature": round(temperature, 1),
                "precipitation": round(precipitation, 1),
                "humidity": round(random.uniform(40, 80), 1),
                "wind_speed": round(random.uniform(2, 15), 1),
                "pressure": round(random.uniform(1005, 1025), 1),
                "source": "synthetic_madrid_pattern"
            })

            current_date = current_date + timedelta(days=1)

        logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
        return synthetic_data

    def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """Enhanced validation of training data quality"""
        validation_results = {
            "is_valid": True,
            "warnings": [],
            "errors": [],
            "data_quality_score": 100.0,
            "recommendations": []
        }

        # Check sales data completeness
        sales_count = len(dataset.sales_data)
        if sales_count < 30:
            validation_results["warnings"].append(
                f"Limited sales data: {sales_count} records (recommended: 30+)"
            )
            validation_results["data_quality_score"] -= 20
            validation_results["recommendations"].append("Consider collecting more historical sales data")
        elif sales_count < 90:
            validation_results["warnings"].append(
                f"Moderate sales data: {sales_count} records (optimal: 90+)"
            )
            validation_results["data_quality_score"] -= 10

        # Check date coverage
        date_coverage = (dataset.date_range.end - dataset.date_range.start).days
        if date_coverage < 90:
            validation_results["warnings"].append(
                f"Limited date coverage: {date_coverage} days (recommended: 90+)"
            )
            validation_results["data_quality_score"] -= 15
            validation_results["recommendations"].append("Extend date range for better seasonality detection")

        # Check external data availability
        if not dataset.weather_data:
            validation_results["warnings"].append("No weather data available")
            validation_results["data_quality_score"] -= 10
            validation_results["recommendations"].append("Weather data improves forecast accuracy")
        elif len(dataset.weather_data) < date_coverage * 0.5:
            validation_results["warnings"].append("Sparse weather data coverage")
            validation_results["data_quality_score"] -= 5

        if not dataset.traffic_data:
            validation_results["warnings"].append("No traffic data available")
            validation_results["data_quality_score"] -= 5
            validation_results["recommendations"].append("Traffic data can help with location-based patterns")

        # Check data consistency
        unique_products = set()
        for record in dataset.sales_data:
            if 'product_name' in record:
                unique_products.add(record['product_name'])

        if len(unique_products) == 0:
            validation_results["errors"].append("No product names found in sales data")
            validation_results["is_valid"] = False
        elif len(unique_products) > 50:
            validation_results["warnings"].append(
                f"Many products detected ({len(unique_products)}). Consider training models in batches."
            )
            validation_results["recommendations"].append("Group similar products for better training efficiency")

        # Check for data source constraints
        if dataset.date_range.constraints:
            constraint_info = []
            for constraint_type, message in dataset.date_range.constraints.items():
                constraint_info.append(f"{constraint_type}: {message}")

            validation_results["warnings"].append(
                f"Data source constraints applied: {'; '.join(constraint_info)}"
            )

        # Final validation
        if validation_results["errors"]:
            validation_results["is_valid"] = False
            validation_results["data_quality_score"] = 0.0

        # Ensure score doesn't go below 0
        validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])

        # Add quality assessment
        score = validation_results["data_quality_score"]
        if score >= 80:
            validation_results["quality_assessment"] = "Excellent"
        elif score >= 60:
            validation_results["quality_assessment"] = "Good"
        elif score >= 40:
            validation_results["quality_assessment"] = "Fair"
        else:
            validation_results["quality_assessment"] = "Poor"

        return validation_results
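
    # Scoring example (illustrative): 45 sales records covering 60 days, with
    # daily weather data but no traffic data, scores
    # 100 - 10 (moderate sales) - 15 (coverage < 90 days) - 5 (no traffic) = 70 -> "Good".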

    def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
        """
        Generate an enhanced data collection plan based on the aligned date range.
        """
        plan = {
            "collection_summary": {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "duration_days": (aligned_range.end - aligned_range.start).days,
                "available_sources": [source.value for source in aligned_range.available_sources],
                "constraints": aligned_range.constraints
            },
            "data_sources": {}
        }

        # Bakery Sales Data
        if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
            plan["data_sources"]["sales_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "user_upload",
                "required": True,
                "priority": "high",
                "expected_records": "variable",
                "data_points": ["date", "product_name", "quantity"],
                "validation": "required_fields_check"
            }

        # Madrid Traffic Data
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            plan["data_sources"]["traffic_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "madrid_opendata",
                "required": False,
                "priority": "medium",
                "expected_records": (aligned_range.end - aligned_range.start).days,
                "constraint": "Cannot request current month data",
                "data_points": ["date", "traffic_volume", "congestion_level"],
                "validation": "date_constraint_check"
            }

        # Weather Data
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            plan["data_sources"]["weather_data"] = {
                "start_date": aligned_range.start.isoformat(),
                "end_date": aligned_range.end.isoformat(),
                "source": "aemet_api",
                "required": False,
                "priority": "high",
                "expected_records": (aligned_range.end - aligned_range.start).days,
                "constraint": "Available from yesterday backward",
                "data_points": ["date", "temperature", "precipitation", "humidity"],
                "validation": "temporal_constraint_check",
                "fallback": "synthetic_madrid_weather"
            }

        return plan

    def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
        """
        Generate a comprehensive summary of the orchestration process.
        """
        return {
            "tenant_id": dataset.metadata.get("tenant_id"),
            "job_id": dataset.metadata.get("job_id"),
            "orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
            "data_alignment": {
                "original_range": dataset.metadata.get("original_sales_range"),
                "aligned_range": {
                    "start": dataset.date_range.start.isoformat(),
                    "end": dataset.date_range.end.isoformat(),
                    "duration_days": (dataset.date_range.end - dataset.date_range.start).days
                },
                "constraints_applied": dataset.date_range.constraints,
                "available_sources": [source.value for source in dataset.date_range.available_sources]
            },
            "data_collection_results": {
                "sales_records": len(dataset.sales_data),
                "weather_records": len(dataset.weather_data),
                "traffic_records": len(dataset.traffic_data),
                "total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
            },
            "data_quality": dataset.metadata.get("data_quality", {}),
            "validation_results": dataset.metadata.get("final_validation", {}),
            "processing_metadata": {
                "bakery_location": dataset.metadata.get("bakery_location"),
                "data_sources_requested": len(dataset.date_range.available_sources),
                "data_sources_successful": sum([
                    1 if len(dataset.sales_data) > 0 else 0,
                    1 if len(dataset.weather_data) > 0 else 0,
                    1 if len(dataset.traffic_data) > 0 else 0
                ])
            }
        }
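

# Minimal end-to-end sketch for running the orchestrator directly. The tenant id,
# coordinates, and job id below are illustrative assumptions; the data service
# behind DataClient must be reachable for this to succeed.
if __name__ == "__main__":
    async def _demo() -> None:
        orchestrator = TrainingDataOrchestrator()
        dataset = await orchestrator.prepare_training_data(
            tenant_id="demo-tenant",             # hypothetical tenant
            bakery_location=(40.4168, -3.7038),  # Madrid city centre
            job_id="demo-job",                   # hypothetical job id
        )
        # Print the orchestration summary produced from the prepared dataset
        print(orchestrator.get_orchestration_summary(dataset))

    asyncio.run(_demo())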