Improve training code

Urtzi Alfaro
2025-07-28 19:28:39 +02:00
parent 946015b80c
commit 98f546af12
15 changed files with 2534 additions and 2812 deletions

View File

@@ -0,0 +1,240 @@
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class DataSourceType(Enum):
BAKERY_SALES = "bakery_sales"
MADRID_TRAFFIC = "madrid_traffic"
WEATHER_FORECAST = "weather_forecast"
@dataclass
class DateRange:
start: datetime
end: datetime
source: DataSourceType
def duration_days(self) -> int:
return (self.end - self.start).days
def overlaps_with(self, other: 'DateRange') -> bool:
return self.start <= other.end and other.start <= self.end
@dataclass
class AlignedDateRange:
start: datetime
end: datetime
available_sources: List[DataSourceType]
constraints: Dict[str, str]
class DateAlignmentService:
"""
Central service for managing and aligning dates across multiple data sources
for the bakery sales prediction model.
"""
def __init__(self):
self.MAX_TRAINING_RANGE_DAYS = 365 # Maximum training data range
self.MIN_TRAINING_RANGE_DAYS = 30 # Minimum viable training data
def validate_and_align_dates(
self,
user_sales_range: DateRange,
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None
) -> AlignedDateRange:
"""
Main method to validate and align dates across all data sources.
Args:
user_sales_range: Date range of user-provided sales data
requested_start: Optional explicit start date for training
requested_end: Optional explicit end date for training
Returns:
AlignedDateRange with validated start/end dates and available sources
"""
try:
# Step 1: Determine the base date range
base_range = self._determine_base_range(
user_sales_range, requested_start, requested_end
)
# Step 2: Apply data source constraints
aligned_range = self._apply_data_source_constraints(base_range)
# Step 3: Validate final range
self._validate_final_range(aligned_range)
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
return aligned_range
except Exception as e:
logger.error(f"Date alignment failed: {str(e)}")
raise ValueError(f"Unable to align dates: {str(e)}")
def _determine_base_range(
self,
user_sales_range: DateRange,
requested_start: Optional[datetime],
requested_end: Optional[datetime]
) -> DateRange:
"""Determine the base date range for training."""
# Use explicit dates if provided
if requested_start and requested_end:
if requested_end <= requested_start:
raise ValueError("End date must be after start date")
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
# Otherwise, use the user's sales data range as the foundation
start_date = requested_start or user_sales_range.start
end_date = requested_end or user_sales_range.end
# Ensure we don't exceed maximum training range
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
"""Apply constraints from each data source and determine final aligned range."""
current_month = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
available_sources = [DataSourceType.BAKERY_SALES] # Always have sales data
constraints = {}
# Madrid Traffic Data Constraint
madrid_end_date = self._get_madrid_traffic_end_date()
if base_range.end > madrid_end_date:
# If requested end date is in current month, adjust it
new_end = madrid_end_date
constraints["madrid_traffic"] = f"Adjusted end date to {new_end.date()} (latest available traffic data)"
logger.info(f"Madrid traffic constraint: end date adjusted to {new_end.date()}")
else:
new_end = base_range.end
available_sources.append(DataSourceType.MADRID_TRAFFIC)
# Weather Forecast Constraint
# Weather data available from yesterday backward
weather_end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
if base_range.end > weather_end_date:
if new_end > weather_end_date:
new_end = weather_end_date
constraints["weather"] = f"Adjusted end date to {new_end.date()} (latest available weather data)"
logger.info(f"Weather constraint: end date adjusted to {new_end.date()}")
if new_end >= base_range.start:
available_sources.append(DataSourceType.WEATHER_FORECAST)
# Ensure minimum training period
final_start = base_range.start
if (new_end - final_start).days < self.MIN_TRAINING_RANGE_DAYS:
final_start = new_end - timedelta(days=self.MIN_TRAINING_RANGE_DAYS)
constraints["minimum_period"] = f"Adjusted start date to ensure {self.MIN_TRAINING_RANGE_DAYS} day minimum training period"
logger.info(f"Minimum period constraint: start date adjusted to {final_start.date()}")
return AlignedDateRange(
start=final_start,
end=new_end,
available_sources=available_sources,
constraints=constraints
)
def _get_madrid_traffic_end_date(self) -> datetime:
"""
Get the latest available date for Madrid traffic data.
Data for the current month is not published until the following month,
so the latest available date is the last day of the previous month.
"""
now = datetime.now()
# First day of the current month minus one day is the last day of the previous month
return now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
def _validate_final_range(self, aligned_range: AlignedDateRange) -> None:
"""Validate the final aligned date range."""
if aligned_range.start >= aligned_range.end:
raise ValueError("Invalid date range: start date must be before end date")
duration = (aligned_range.end - aligned_range.start).days
if duration < self.MIN_TRAINING_RANGE_DAYS:
raise ValueError(f"Insufficient training data: {duration} days (minimum: {self.MIN_TRAINING_RANGE_DAYS})")
if duration > self.MAX_TRAINING_RANGE_DAYS:
raise ValueError(f"Training period too long: {duration} days (maximum: {self.MAX_TRAINING_RANGE_DAYS})")
# Ensure we have at least sales data
if DataSourceType.BAKERY_SALES not in aligned_range.available_sources:
raise ValueError("No sales data available for the aligned date range")
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
"""
Generate a data collection plan based on the aligned date range.
Returns:
Dictionary with collection plans for each data source
"""
plan = {}
# Bakery Sales Data
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
plan["sales_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "user_upload",
"required": True
}
# Madrid Traffic Data
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
plan["traffic_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "madrid_opendata",
"required": False,
"constraint": "Cannot request current month data"
}
# Weather Data
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
plan["weather_data"] = {
"start_date": aligned_range.start,
"end_date": aligned_range.end,
"source": "aemet_api",
"required": False,
"constraint": "Available from yesterday backward"
}
return plan
def check_madrid_current_month_constraint(self, end_date: datetime) -> bool:
"""
Check if the end date violates the Madrid Open Data current month constraint.
Args:
end_date: The requested end date
Returns:
True if the constraint is violated (end date is in current month)
"""
now = datetime.now()
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return end_date >= current_month_start
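
A minimal usage sketch of the service above (the module path matches the import used by the orchestrator below; the dates are illustrative):

from datetime import datetime
from app.services.date_alignment_service import (
    DateAlignmentService, DateRange, DataSourceType
)

# Roughly five and a half months of historical sales as the foundation range
sales_range = DateRange(
    start=datetime(2025, 1, 1),
    end=datetime(2025, 6, 15),
    source=DataSourceType.BAKERY_SALES,
)

service = DateAlignmentService()
aligned = service.validate_and_align_dates(user_sales_range=sales_range)

print(aligned.start, aligned.end)   # final training window
print(aligned.available_sources)    # sources usable inside that window
print(aligned.constraints)          # adjustments applied, if any
print(service.get_data_collection_plan(aligned))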

View File

@@ -0,0 +1,706 @@
# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer
Orchestrates data collection, date alignment, and preparation for ML training
"""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
logger = logging.getLogger(__name__)
@dataclass
class TrainingDataSet:
"""Container for all training data with metadata"""
sales_data: List[Dict[str, Any]]
weather_data: List[Dict[str, Any]]
traffic_data: List[Dict[str, Any]]
date_range: AlignedDateRange
metadata: Dict[str, Any]
class TrainingDataOrchestrator:
"""
Enhanced orchestrator for data collection from multiple sources.
Ensures date alignment, handles data source constraints, and prepares data for ML training.
"""
def __init__(self,
madrid_client=None,
weather_client=None,
date_alignment_service: DateAlignmentService = None):
self.madrid_client = madrid_client
self.weather_client = weather_client
self.data_client = DataServiceClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.max_concurrent_requests = 3
async def prepare_training_data(
self,
tenant_id: str,
bakery_location: Tuple[float, float], # (lat, lon)
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> TrainingDataSet:
"""
Main method to prepare all training data with comprehensive date alignment.
Args:
tenant_id: Tenant identifier
bakery_location: Bakery coordinates (lat, lon)
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Training job identifier for logging
Returns:
TrainingDataSet with all aligned and validated data
"""
logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
try:
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Step 1: Extract and validate sales data date range
sales_date_range = self._extract_sales_date_range(sales_data)
logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")
# Step 2: Apply date alignment across all data sources
aligned_range = self.date_alignment_service.validate_and_align_dates(
user_sales_range=sales_date_range,
requested_start=requested_start,
requested_end=requested_end
)
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
if aligned_range.constraints:
logger.info(f"Applied constraints: {aligned_range.constraints}")
# Step 3: Filter sales data to aligned date range
filtered_sales = self._filter_sales_data(sales_data, aligned_range)
# Step 4: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data = await self._collect_external_data(
aligned_range, bakery_location
)
# Step 5: Validate data quality
data_quality_results = self._validate_data_sources(
filtered_sales, weather_data, traffic_data, aligned_range
)
# Step 6: Create comprehensive training dataset
training_dataset = TrainingDataSet(
sales_data=filtered_sales,
weather_data=weather_data,
traffic_data=traffic_data,
date_range=aligned_range,
metadata={
"tenant_id": tenant_id,
"job_id": job_id,
"bakery_location": bakery_location,
"data_sources_used": aligned_range.available_sources,
"constraints_applied": aligned_range.constraints,
"data_quality": data_quality_results,
"preparation_timestamp": datetime.now().isoformat(),
"original_sales_range": {
"start": sales_date_range.start.isoformat(),
"end": sales_date_range.end.isoformat()
}
}
)
# Step 7: Final validation
final_validation = self.validate_training_data_quality(training_dataset)
training_dataset.metadata["final_validation"] = final_validation
logger.info(f"Training data preparation completed successfully:")
logger.info(f" - Sales records: {len(filtered_sales)}")
logger.info(f" - Weather records: {len(weather_data)}")
logger.info(f" - Traffic records: {len(traffic_data)}")
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
return training_dataset
except Exception as e:
logger.error(f"Training data preparation failed: {str(e)}")
raise ValueError(f"Failed to prepare training data: {str(e)}")
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
"""Extract and validate the date range from sales data"""
if not sales_data:
raise ValueError("No sales data provided")
dates = []
valid_records = 0
for record in sales_data:
try:
if 'date' in record:
date_val = record['date']
if isinstance(date_val, str):
# Handle various date formats
if 'T' in date_val:
date_val = date_val.replace('Z', '+00:00')
parsed_date = datetime.fromisoformat(date_val.split('T')[0])
elif isinstance(date_val, datetime):
parsed_date = date_val
else:
continue
dates.append(parsed_date)
valid_records += 1
except (ValueError, TypeError) as e:
logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
continue
if not dates:
raise ValueError("No valid dates found in sales data")
logger.info(f"Processed {valid_records} valid date records from {len(sales_data)} total records")
return DateRange(
start=min(dates),
end=max(dates),
source=DataSourceType.BAKERY_SALES
)
def _filter_sales_data(
self,
sales_data: List[Dict[str, Any]],
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Filter sales data to the aligned date range with enhanced validation"""
filtered_data = []
filtered_count = 0
for record in sales_data:
try:
if 'date' in record:
record_date = record['date']
if isinstance(record_date, str):
if 'T' in record_date:
record_date = record_date.replace('Z', '+00:00')
record_date = datetime.fromisoformat(record_date.split('T')[0])
elif isinstance(record_date, datetime):
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
# Check if date falls within aligned range
if aligned_range.start <= record_date <= aligned_range.end:
# Validate that record has required fields
if self._validate_sales_record(record):
filtered_data.append(record)
else:
filtered_count += 1
except Exception as e:
logger.warning(f"Error processing sales record: {str(e)}")
filtered_count += 1
continue
logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
if filtered_count > 0:
logger.warning(f"Filtered out {filtered_count} invalid records")
return filtered_data
def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
"""Validate individual sales record"""
required_fields = ['date', 'product_name']
quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']
# Check required fields
for field in required_fields:
if field not in record or record[field] is None:
return False
# Check at least one quantity field exists
has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
if not has_quantity:
return False
# Validate quantity is numeric and non-negative
for field in quantity_fields:
if field in record and record[field] is not None:
try:
quantity = float(record[field])
if quantity < 0:
return False
except (ValueError, TypeError):
return False
break
return True
async def _collect_external_data(
self,
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Collect weather and traffic data concurrently with enhanced error handling"""
lat, lon = bakery_location
# Create collection tasks with timeout
tasks = []
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
)
tasks.append(("weather", weather_task))
# Traffic data collection
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
)
tasks.append(("traffic", traffic_task))
# Execute tasks concurrently with proper error handling
results = {}
if tasks:
try:
completed_tasks = await asyncio.gather(
*[task for _, task in tasks],
return_exceptions=True
)
for i, (task_name, _) in enumerate(tasks):
result = completed_tasks[i]
if isinstance(result, Exception):
logger.warning(f"{task_name} data collection failed: {result}")
results[task_name] = []
else:
results[task_name] = result
logger.info(f"{task_name} data collection completed: {len(result)} records")
except Exception as e:
logger.error(f"Error in concurrent data collection: {str(e)}")
results = {"weather": [], "traffic": []}
weather_data = results.get("weather", [])
traffic_data = results.get("traffic", [])
return weather_data, traffic_data
async def _collect_weather_data_with_timeout(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Collect weather data with timeout and fallback"""
try:
if not self.weather_client:
logger.info("Weather client not configured, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
timeout_seconds = 30  # assumed default timeout for the external weather request
weather_data = await asyncio.wait_for(
self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
timeout=timeout_seconds
)
# Validate weather data
if self._validate_weather_data(weather_data):
logger.info(f"Collected {len(weather_data)} valid weather records")
return weather_data
else:
logger.warning("Invalid weather data received, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except asyncio.TimeoutError:
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except Exception as e:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
async def _collect_traffic_data_with_timeout(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Collect traffic data with timeout and Madrid constraint validation"""
try:
if not self.madrid_client:
logger.info("Madrid client not configured, no traffic data available")
return []
# Double-check Madrid constraint before making request
if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
logger.warning("Madrid current month constraint violation, no traffic data available")
return []
timeout_seconds = 30  # assumed default timeout for the external traffic request
traffic_data = await asyncio.wait_for(
self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
timeout=timeout_seconds
)
# Validate traffic data
if self._validate_traffic_data(traffic_data):
logger.info(f"Collected {len(traffic_data)} valid traffic records")
return traffic_data
else:
logger.warning("Invalid traffic data received")
return []
except asyncio.TimeoutError:
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
return []
except Exception as e:
logger.warning(f"Traffic data collection failed: {e}")
return []
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
"""Validate weather data quality"""
if not weather_data:
return False
required_fields = ['date']
weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
valid_records = 0
for record in weather_data:
# Check required fields
if not all(field in record for field in required_fields):
continue
# Check at least one weather field exists
if any(field in record and record[field] is not None for field in weather_fields):
valid_records += 1
# Consider valid if at least 50% of records are valid
validity_threshold = 0.5
is_valid = (valid_records / len(weather_data)) >= validity_threshold
if not is_valid:
logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
return is_valid
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Validate traffic data quality"""
if not traffic_data:
return False
required_fields = ['date']
traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
valid_records = 0
for record in traffic_data:
# Check required fields
if not all(field in record for field in required_fields):
continue
# Check at least one traffic field exists
if any(field in record and record[field] is not None for field in traffic_fields):
valid_records += 1
# Consider valid if at least 30% of records are valid (traffic data is often sparse)
validity_threshold = 0.3
is_valid = (valid_records / len(traffic_data)) >= validity_threshold
if not is_valid:
logger.warning(f"Traffic data validation failed: {valid_records}/{len(traffic_data)} valid records")
return is_valid
def _validate_data_sources(
self,
sales_data: List[Dict[str, Any]],
weather_data: List[Dict[str, Any]],
traffic_data: List[Dict[str, Any]],
aligned_range: AlignedDateRange
) -> Dict[str, Any]:
"""Validate all data sources and provide quality metrics"""
validation_results = {
"sales_data": {
"record_count": len(sales_data),
"is_valid": len(sales_data) > 0,
"coverage_days": (aligned_range.end - aligned_range.start).days,
"quality_score": 0.0
},
"weather_data": {
"record_count": len(weather_data),
"is_valid": self._validate_weather_data(weather_data) if weather_data else False,
"quality_score": 0.0
},
"traffic_data": {
"record_count": len(traffic_data),
"is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
"quality_score": 0.0
},
"overall_quality_score": 0.0
}
# Calculate quality scores
# Sales data quality (most important)
if validation_results["sales_data"]["record_count"] > 0:
coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / validation_results["sales_data"]["coverage_days"])
validation_results["sales_data"]["quality_score"] = coverage_ratio * 100
# Weather data quality
if validation_results["weather_data"]["record_count"] > 0:
expected_weather_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
validation_results["weather_data"]["quality_score"] = coverage_ratio * 100
# Traffic data quality
if validation_results["traffic_data"]["record_count"] > 0:
expected_traffic_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100
# Overall quality score (weighted by importance)
weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
overall_score = sum(
validation_results[source]["quality_score"] * weight
for source, weight in weights.items()
)
validation_results["overall_quality_score"] = round(overall_score, 2)
return validation_results
def _generate_synthetic_weather_data(
self,
aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
"""Generate realistic synthetic weather data for Madrid"""
synthetic_data = []
current_date = aligned_range.start
# Madrid seasonal temperature patterns
seasonal_temps = {
1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
}
import random  # used for the synthetic variation below
while current_date <= aligned_range.end:
month = current_date.month
base_temp = seasonal_temps.get(month, 15)
# Add some realistic variation
temp_variation = random.gauss(0, 3) # ±3°C variation
temperature = max(0, base_temp + temp_variation)
# Precipitation patterns (Madrid is relatively dry)
precipitation = 0.0
if random.random() < 0.15: # 15% chance of rain
precipitation = random.uniform(0.1, 15.0)
synthetic_data.append({
"date": current_date,
"temperature": round(temperature, 1),
"precipitation": round(precipitation, 1),
"humidity": round(random.uniform(40, 80), 1),
"wind_speed": round(random.uniform(2, 15), 1),
"pressure": round(random.uniform(1005, 1025), 1),
"source": "synthetic_madrid_pattern"
})
current_date = current_date + timedelta(days=1)
logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
return synthetic_data
def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
"""Enhanced validation of training data quality"""
validation_results = {
"is_valid": True,
"warnings": [],
"errors": [],
"data_quality_score": 100.0,
"recommendations": []
}
# Check sales data completeness
sales_count = len(dataset.sales_data)
if sales_count < 30:
validation_results["warnings"].append(
f"Limited sales data: {sales_count} records (recommended: 30+)"
)
validation_results["data_quality_score"] -= 20
validation_results["recommendations"].append("Consider collecting more historical sales data")
elif sales_count < 90:
validation_results["warnings"].append(
f"Moderate sales data: {sales_count} records (optimal: 90+)"
)
validation_results["data_quality_score"] -= 10
# Check date coverage
date_coverage = (dataset.date_range.end - dataset.date_range.start).days
if date_coverage < 90:
validation_results["warnings"].append(
f"Limited date coverage: {date_coverage} days (recommended: 90+)"
)
validation_results["data_quality_score"] -= 15
validation_results["recommendations"].append("Extend date range for better seasonality detection")
# Check external data availability
if not dataset.weather_data:
validation_results["warnings"].append("No weather data available")
validation_results["data_quality_score"] -= 10
validation_results["recommendations"].append("Weather data improves forecast accuracy")
elif len(dataset.weather_data) < date_coverage * 0.5:
validation_results["warnings"].append("Sparse weather data coverage")
validation_results["data_quality_score"] -= 5
if not dataset.traffic_data:
validation_results["warnings"].append("No traffic data available")
validation_results["data_quality_score"] -= 5
validation_results["recommendations"].append("Traffic data can help with location-based patterns")
# Check data consistency
unique_products = set()
for record in dataset.sales_data:
if 'product_name' in record:
unique_products.add(record['product_name'])
if len(unique_products) == 0:
validation_results["errors"].append("No product names found in sales data")
validation_results["is_valid"] = False
elif len(unique_products) > 50:
validation_results["warnings"].append(
f"Many products detected ({len(unique_products)}). Consider training models in batches."
)
validation_results["recommendations"].append("Group similar products for better training efficiency")
# Check for data source constraints
if dataset.date_range.constraints:
constraint_info = []
for constraint_type, message in dataset.date_range.constraints.items():
constraint_info.append(f"{constraint_type}: {message}")
validation_results["warnings"].append(
f"Data source constraints applied: {'; '.join(constraint_info)}"
)
# Final validation
if validation_results["errors"]:
validation_results["is_valid"] = False
validation_results["data_quality_score"] = 0.0
# Ensure score doesn't go below 0
validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])
# Add quality assessment
score = validation_results["data_quality_score"]
if score >= 80:
validation_results["quality_assessment"] = "Excellent"
elif score >= 60:
validation_results["quality_assessment"] = "Good"
elif score >= 40:
validation_results["quality_assessment"] = "Fair"
else:
validation_results["quality_assessment"] = "Poor"
return validation_results
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
"""
Generate an enhanced data collection plan based on the aligned date range.
"""
plan = {
"collection_summary": {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"duration_days": (aligned_range.end - aligned_range.start).days,
"available_sources": [source.value for source in aligned_range.available_sources],
"constraints": aligned_range.constraints
},
"data_sources": {}
}
# Bakery Sales Data
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
plan["data_sources"]["sales_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "user_upload",
"required": True,
"priority": "high",
"expected_records": "variable",
"data_points": ["date", "product_name", "quantity"],
"validation": "required_fields_check"
}
# Madrid Traffic Data
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
plan["data_sources"]["traffic_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "madrid_opendata",
"required": False,
"priority": "medium",
"expected_records": (aligned_range.end - aligned_range.start).days,
"constraint": "Cannot request current month data",
"data_points": ["date", "traffic_volume", "congestion_level"],
"validation": "date_constraint_check"
}
# Weather Data
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
plan["data_sources"]["weather_data"] = {
"start_date": aligned_range.start.isoformat(),
"end_date": aligned_range.end.isoformat(),
"source": "aemet_api",
"required": False,
"priority": "high",
"expected_records": (aligned_range.end - aligned_range.start).days,
"constraint": "Available from yesterday backward",
"data_points": ["date", "temperature", "precipitation", "humidity"],
"validation": "temporal_constraint_check",
"fallback": "synthetic_madrid_weather"
}
return plan
def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
"""
Generate a comprehensive summary of the orchestration process.
"""
return {
"tenant_id": dataset.metadata.get("tenant_id"),
"job_id": dataset.metadata.get("job_id"),
"orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
"data_alignment": {
"original_range": dataset.metadata.get("original_sales_range"),
"aligned_range": {
"start": dataset.date_range.start.isoformat(),
"end": dataset.date_range.end.isoformat(),
"duration_days": (dataset.date_range.end - dataset.date_range.start).days
},
"constraints_applied": dataset.date_range.constraints,
"available_sources": [source.value for source in dataset.date_range.available_sources]
},
"data_collection_results": {
"sales_records": len(dataset.sales_data),
"weather_records": len(dataset.weather_data),
"traffic_records": len(dataset.traffic_data),
"total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
},
"data_quality": dataset.metadata.get("data_quality", {}),
"validation_results": dataset.metadata.get("final_validation", {}),
"processing_metadata": {
"bakery_location": dataset.metadata.get("bakery_location"),
"data_sources_requested": len(dataset.date_range.available_sources),
"data_sources_successful": sum([
1 if len(dataset.sales_data) > 0 else 0,
1 if len(dataset.weather_data) > 0 else 0,
1 if len(dataset.traffic_data) > 0 else 0
])
}
}
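
A hedged sketch of driving the orchestrator above from an async context (the tenant id, coordinates, and job id are illustrative; in the service itself the orchestrator is constructed and awaited by TrainingService):

import asyncio
from app.services.training_orchestrator import TrainingDataOrchestrator

async def main():
    # Default construction wires in DataServiceClient and a DateAlignmentService
    orchestrator = TrainingDataOrchestrator()
    dataset = await orchestrator.prepare_training_data(
        tenant_id="tenant-123",              # illustrative tenant id
        bakery_location=(40.4168, -3.7038),  # Madrid city centre
        job_id="local-smoke-test",           # illustrative job id
    )
    print(orchestrator.get_orchestration_summary(dataset))

asyncio.run(main())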

View File

@@ -1,721 +1,303 @@
# services/training/app/services/training_service.py
"""
Training service business logic
Orchestrates ML training operations and manages job lifecycle
Main Training Service - Coordinates the complete training process
This is the entry point from the API layer
"""
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime, timedelta
import asyncio
import uuid
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_
import httpx
from app.models.training import ModelTrainingLog, TrainedModel
from app.ml.trainer import BakeryMLTrainer
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.services.messaging import publish_job_completed, publish_job_failed
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.database import get_db_session
logger = logging.getLogger(__name__)
metrics = MetricsCollector("training-service")
class TrainingService:
"""
Main service class for managing ML training operations.
Replaces the old Celery-based training system with clean async implementation.
Main training service that coordinates the complete training pipeline.
Entry point from API layer - handles business logic and orchestration.
"""
def __init__(self):
self.ml_trainer = BakeryMLTrainer()
self.data_client = DataServiceClient()
async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
"""Determine start and end dates from sales data with validation"""
if not sales_data:
raise ValueError("No sales data available to determine date range")
dates = []
for record in sales_data:
if 'date' in record:
try:
if isinstance(record['date'], str):
# Handle various date string formats
date_str = record['date'].replace('Z', '+00:00')
if 'T' in date_str:
parsed_date = datetime.fromisoformat(date_str)
else:
# Handle date-only strings
parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
dates.append(parsed_date)
elif isinstance(record['date'], datetime):
dates.append(record['date'])
except (ValueError, AttributeError) as e:
logger.warning(f"Invalid date format in record: {record['date']} - {e}")
continue
if not dates:
raise ValueError("No valid dates found in sales data")
start_date = min(dates)
end_date = max(dates)
# Validate and adjust date range for external APIs
start_date, end_date = self._adjust_date_range_for_apis(start_date, end_date)
logger.info(f"Determined and adjusted sales date range: {start_date} to {end_date}")
return start_date, end_date
def _adjust_date_range_for_apis(self, start_date: datetime, end_date: datetime) -> tuple[datetime, datetime]:
"""Adjust date range to comply with external API limits"""
# Weather and traffic APIs have a 90-day limit
MAX_DAYS = 90
# Calculate current range
current_range = (end_date - start_date).days
if current_range > MAX_DAYS:
logger.warning(f"Date range ({current_range} days) exceeds API limit ({MAX_DAYS} days). Adjusting...")
# Keep the most recent data
start_date = end_date - timedelta(days=MAX_DAYS)
logger.info(f"Adjusted start_date to {start_date} to fit within {MAX_DAYS} day limit")
# Ensure dates are not in the future
now = datetime.now()
if end_date > now:
end_date = now.replace(hour=0, minute=0, second=0, microsecond=0)
logger.info(f"Adjusted end_date to {end_date} (cannot be in future)")
if start_date > now:
start_date = now.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=30)
logger.info(f"Adjusted start_date to {start_date} (was in future)")
# Ensure start_date is before end_date
if start_date >= end_date:
start_date = end_date - timedelta(days=30) # Default to 30 days of data
logger.warning(f"start_date was not before end_date. Adjusted start_date to {start_date}")
def __init__(self, db_session: AsyncSession = None):
self.db_session = db_session
self.trainer = BakeryMLTrainer(db_session=db_session) # Pass DB session
self.date_alignment_service = DateAlignmentService()
self.orchestrator = TrainingDataOrchestrator(
date_alignment_service=self.date_alignment_service
)
return start_date, end_date
async def start_training_job(
self,
tenant_id: str,
bakery_location: tuple[float, float] = (40.4168, -3.7038), # Default Madrid
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Start a complete training job for a tenant.
Args:
tenant_id: Tenant identifier
bakery_location: Bakery coordinates (lat, lon)
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Optional job identifier
Returns:
Training job results
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
"""Simple wrapper that creates its own database session"""
try:
# Import database_manager locally to avoid circular imports
from app.core.database import database_manager
logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")
# Create new session for background task
async with database_manager.async_session_local() as session:
await self.execute_training_job(session, job_id, tenant_id_str, request)
await session.commit()
except Exception as e:
logger.error(f"Background training job {job_id} failed: {str(e)}")
# Try to update job status to failed
try:
from app.core.database import database_manager
async with database_manager.async_session_local() as error_session:
await self._update_job_status(
error_session, job_id, "failed", 0,
f"Training failed: {str(e)}", error_message=str(e)
)
await error_session.commit()
except Exception as update_error:
logger.error(f"Failed to update job status: {str(update_error)}")
raise
async def create_training_job(self,
db: AsyncSession,
tenant_id: str,
job_id: str,
config: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training job record"""
try:
training_log = ModelTrainingLog(
job_id=job_id,
# Step 1: Prepare training dataset with date alignment and orchestration
logger.info("Step 1: Preparing and aligning training data")
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
status="pending",
progress=0,
current_step="Initializing training job",
start_time=datetime.now(),
config=config
)
db.add(training_log)
await db.commit()
await db.refresh(training_log)
logger.info(f"Created training job {job_id} for tenant {tenant_id}")
return training_log
except Exception as e:
logger.error(f"Failed to create training job: {str(e)}")
await db.rollback()
raise
async def create_single_product_job(self,
db: AsyncSession,
tenant_id: str,
product_name: str,
job_id: str,
config: Dict[str, Any]) -> ModelTrainingLog:
"""Create a training job for a single product"""
try:
config["single_product"] = product_name
training_log = ModelTrainingLog(
job_id=job_id,
tenant_id=tenant_id,
status="pending",
progress=0,
current_step=f"Initializing training for {product_name}",
start_time=datetime.now(),
config=config
)
db.add(training_log)
await db.commit()
await db.refresh(training_log)
logger.info(f"Created single product training job {job_id} for {product_name}")
return training_log
except Exception as e:
logger.error(f"Failed to create single product training job: {str(e)}")
await db.rollback()
raise
async def execute_training_job(self,
db: AsyncSession,
job_id: str,
tenant_id: str,
request: TrainingJobRequest):
"""Execute a complete training job"""
try:
logger.info(f"Starting execution of training job {job_id}")
# Update job status to running
await self._update_job_status(db, job_id, "running", 5, "Fetching training data")
# Fetch sales data from data service
sales_data = await self.data_client.fetch_sales_data(tenant_id)
if not sales_data:
raise ValueError("No sales data found for training")
# Determine date range from sales data
start_date, end_date = await self._determine_sales_date_range(sales_data)
# Convert dates to ISO format strings for API calls
start_date_str = start_date.isoformat()
end_date_str = end_date.isoformat()
logger.info(f"Using date range for external APIs: {start_date_str} to {end_date_str}")
# Fetch external data if requested using the sales date range
weather_data = []
traffic_data = []
await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
try:
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=40.4168, # Madrid coordinates
longitude=-3.7038
)
logger.info(f"Fetched {len(weather_data)} weather records")
except Exception as e:
logger.warning(f"Failed to fetch weather data: {e}. Continuing without weather data.")
weather_data = []
await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
try:
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=40.4168,
longitude=-3.7038
)
logger.info(f"Fetched {len(traffic_data)} traffic records")
except Exception as e:
logger.warning(f"Failed to fetch traffic data: {e}. Continuing without traffic data.")
traffic_data = []
# Execute ML training
await self._update_job_status(db, job_id, "running", 35, "Processing training data")
training_results = await self.ml_trainer.train_tenant_models(
tenant_id=tenant_id,
sales_data=sales_data,
weather_data=weather_data,
traffic_data=traffic_data,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end,
job_id=job_id
)
await self._update_job_status(db, job_id, "running", 85, "Storing trained models")
# Store trained models in database
await self._store_trained_models(db, tenant_id, training_results)
await self._update_job_status(
db, job_id, "completed", 100, "Training completed successfully",
results=training_results
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id
)
# Publish completion event
await publish_job_completed(job_id, tenant_id, training_results)
# Step 3: Compile final results
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"data_summary": {
"sales_records": len(training_dataset.sales_data),
"weather_records": len(training_dataset.weather_data),
"traffic_records": len(training_dataset.traffic_data),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training results {training_results}")
logger.info(f"Training job {job_id} completed successfully")
metrics.increment_counter("training_jobs_completed")
return final_result
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
await self._update_job_status(
db, job_id, "failed", 0, f"Training failed: {str(e)}",
error_message=str(e)
)
# Publish failure event
await publish_job_failed(job_id, tenant_id, str(e))
metrics.increment_counter("training_jobs_failed")
raise
return {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
}
async def execute_single_product_training(self,
db: AsyncSession,
job_id: str,
tenant_id: str,
product_name: str,
request: SingleProductTrainingRequest):
"""Execute training for a single product"""
async def start_single_product_training(
self,
tenant_id: str,
product_name: str,
sales_data: List[Dict[str, Any]],
bakery_location: tuple[float, float] = (40.4168, -3.7038),
weather_data: Optional[List[Dict[str, Any]]] = None,
traffic_data: Optional[List[Dict[str, Any]]] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Train a model for a single product.
Args:
tenant_id: Tenant identifier
product_name: Product name
sales_data: Historical sales data
bakery_location: Bakery coordinates
weather_data: Optional weather data
traffic_data: Optional traffic data
job_id: Optional job identifier
Returns:
Single product training result
"""
if not job_id:
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting single product training {job_id} for {product_name}")
try:
logger.info(f"Starting single product training {job_id} for {product_name}")
# Filter sales data for the specific product
product_sales = [
record for record in sales_data
if record.get('product_name') == product_name
]
# Update job status
await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")
if not product_sales:
raise ValueError(f"No sales data found for product: {product_name}")
# Fetch data
sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)
weather_data = []
traffic_data = []
if request.include_weather:
await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
weather_data = await self.data_client.fetch_weather_data(tenant_id, request)
if request.include_traffic:
await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)
# Execute training
await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")
training_result = await self.ml_trainer.train_single_product(
# Use the same pipeline but for single product
return await self.start_training_job(
tenant_id=tenant_id,
product_name=product_name,
sales_data=sales_data,
sales_data=product_sales,
bakery_location=bakery_location,
weather_data=weather_data,
traffic_data=traffic_data,
job_id=job_id
)
# Store model
await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
await self._store_single_trained_model(db, tenant_id, product_name, training_result)
await self._update_job_status(
db, job_id, "completed", 100, f"Training completed for {product_name}",
results=training_result
)
logger.info(f"Single product training {job_id} completed successfully")
metrics.increment_counter("single_product_training_completed")
except Exception as e:
logger.error(f"Single product training {job_id} failed: {str(e)}")
await self._update_job_status(
db, job_id, "failed", 0, f"Training failed: {str(e)}",
error_message=str(e)
)
metrics.increment_counter("single_product_training_failed")
raise
return {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
}
async def get_job_status(self,
db: AsyncSession,
job_id: str,
tenant_id: str) -> Optional[ModelTrainingLog]:
"""Get training job status"""
try:
result = await db.execute(
select(ModelTrainingLog).where(
and_(
ModelTrainingLog.job_id == job_id,
ModelTrainingLog.tenant_id == tenant_id
)
)
)
return result.scalar_one_or_none()
async def validate_training_data(
self,
tenant_id: str,
sales_data: List[Dict[str, Any]],
products: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Validate training data quality before starting training.
Args:
tenant_id: Tenant identifier
sales_data: Sales data to validate
products: Optional list of specific products to validate
except Exception as e:
logger.error(f"Failed to get job status: {str(e)}")
return None
async def list_training_jobs(self,
db: AsyncSession,
tenant_id: str,
limit: int = 10,
status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
"""List training jobs for a tenant"""
try:
query = select(ModelTrainingLog).where(
ModelTrainingLog.tenant_id == tenant_id
).order_by(ModelTrainingLog.start_time.desc()).limit(limit)
if status_filter:
query = query.where(ModelTrainingLog.status == status_filter)
result = await db.execute(query)
return result.scalars().all()
except Exception as e:
logger.error(f"Failed to list training jobs: {str(e)}")
return []
async def cancel_training_job(self,
db: AsyncSession,
job_id: str,
tenant_id: str) -> bool:
"""Cancel a training job"""
try:
result = await db.execute(
update(ModelTrainingLog)
.where(
and_(
ModelTrainingLog.job_id == job_id,
ModelTrainingLog.tenant_id == tenant_id,
ModelTrainingLog.status.in_(["pending", "running"])
)
)
.values(
status="cancelled",
end_time=datetime.now(),
current_step="Training cancelled by user"
)
)
await db.commit()
if result.rowcount > 0:
logger.info(f"Cancelled training job {job_id}")
return True
else:
logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
return False
except Exception as e:
logger.error(f"Failed to cancel training job: {str(e)}")
await db.rollback()
return False
async def validate_training_data(self,
db: AsyncSession,
tenant_id: str,
config: Dict[str, Any]) -> Dict[str, Any]:
"""Validate training data before starting a job"""
Returns:
Validation results
"""
try:
logger.info(f"Validating training data for tenant {tenant_id}")
issues = []
recommendations = []
# Fetch a sample of sales data to validate
sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)
# Extract sales date range for validation
if not sales_data:
issues.append("No sales data found for tenant")
return {
"is_valid": False,
"issues": issues,
"recommendations": ["Upload sales data before training"],
"estimated_time_minutes": 0
"valid": False,
"errors": ["No sales data provided"],
"warnings": []
}
# Analyze data quality
products = set(item.get("product_name") for item in sales_data)
total_records = len(sales_data)
# Check for sufficient data per product
product_counts = {}
for item in sales_data:
product = item.get("product_name")
if product:
product_counts[product] = product_counts.get(product, 0) + 1
insufficient_products = [
product for product, count in product_counts.items()
if count < config.get("min_data_points", 30)
]
if insufficient_products:
issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
recommendations.append("Collect more historical data for these products")
# Estimate training time
valid_products = len(products) - len(insufficient_products)
estimated_time = max(5, valid_products * 2) # 2 minutes per product minimum
is_valid = len(issues) == 0
return {
"is_valid": is_valid,
"issues": issues,
"recommendations": recommendations,
"estimated_time_minutes": estimated_time,
"products_analyzed": len(products),
"total_data_points": total_records
}
except Exception as e:
logger.error(f"Failed to validate training data: {str(e)}")
return {
"is_valid": False,
"issues": [f"Validation error: {str(e)}"],
"recommendations": ["Check data service connectivity"],
"estimated_time_minutes": 0
}
async def _update_job_status(self,
db: AsyncSession,
job_id: str,
status: str,
progress: int,
current_step: str,
results: Optional[Dict] = None,
error_message: Optional[str] = None):
"""Update training job status"""
try:
update_values = {
"status": status,
"progress": progress,
"current_step": current_step
}
if status == "completed":
update_values["end_time"] = datetime.now()
if results:
update_values["results"] = results
if error_message:
update_values["error_message"] = error_message
update_values["end_time"] = datetime.now()
await db.execute(
update(ModelTrainingLog)
.where(ModelTrainingLog.job_id == job_id)
.values(**update_values)
# Create a mock training dataset to validate
mock_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
bakery_location=(40.4168, -3.7038), # Default Madrid
job_id=f"validation_{uuid.uuid4().hex[:8]}"
)
await db.commit()
# Validate the dataset
validation_results = self.orchestrator.validate_training_data_quality(mock_dataset)
# Add product-specific information
unique_products = list(set(record.get('product_name', 'unknown') for record in sales_data))
product_data_points = {}
for record in sales_data:
product = record.get('product_name', 'unknown')
product_data_points[product] = product_data_points.get(product, 0) + 1
validation_results.update({
"products_found": unique_products,
"product_data_points": product_data_points,
"total_records": len(sales_data),
"date_range_info": {
"start": mock_dataset.date_range.start.isoformat(),
"end": mock_dataset.date_range.end.isoformat(),
"duration_days": (mock_dataset.date_range.end - mock_dataset.date_range.start).days
}
})
return validation_results
except Exception as e:
logger.error(f"Failed to update job status: {str(e)}")
await db.rollback()
logger.error(f"Training data validation failed: {str(e)}")
return {
"valid": False,
"errors": [f"Validation failed: {str(e)}"],
"warnings": []
}
async def _store_trained_models(self,
db: AsyncSession,
tenant_id: str,
training_results: Dict[str, Any]):
"""Store trained models in database"""
async def get_training_recommendations(
self,
tenant_id: str,
sales_data: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Get training recommendations based on data analysis.
Args:
tenant_id: Tenant identifier
sales_data: Historical sales data
Returns:
Training recommendations
"""
try:
models_to_store = []
logger.info(f"Generating training recommendations for tenant {tenant_id}")
for product_name, result in training_results.get("training_results", {}).items():
if result.get("status") == "success":
model_info = result.get("model_info", {})
trained_model = TrainedModel(
tenant_id=tenant_id,
product_name=product_name,
model_id=model_info.get("model_id"),
model_type=model_info.get("type", "prophet"),
model_path=model_info.get("model_path"),
version=1, # Start with version 1
training_samples=model_info.get("training_samples", 0),
features=model_info.get("features", []),
hyperparameters=model_info.get("hyperparameters", {}),
training_metrics=model_info.get("training_metrics", {}),
data_period_start=datetime.fromisoformat(
model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
),
data_period_end=datetime.fromisoformat(
model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
),
created_at=datetime.now(),
is_active=True
)
models_to_store.append(trained_model)
# Analyze the data
validation_results = await self.validate_training_data(tenant_id, sales_data)
# Deactivate old models for these products
if models_to_store:
product_names = [model.product_name for model in models_to_store]
await db.execute(
update(TrainedModel)
.where(
and_(
TrainedModel.tenant_id == tenant_id,
TrainedModel.product_name.in_(product_names),
TrainedModel.is_active == True
)
)
.values(is_active=False)
)
# Add new models
db.add_all(models_to_store)
await db.commit()
logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")
recommendations = {
"should_retrain": True,
"reasons": [],
"recommended_products": [],
"optimal_config": {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"hyperparameter_optimization": True
}
}
# Analyze data quality and provide recommendations
if validation_results.get("data_quality_score", 0) >= 80:
recommendations["reasons"].append("High quality data detected")
else:
recommendations["reasons"].append("Data quality could be improved")
# Recommend products with sufficient data
product_data_points = validation_results.get("product_data_points", {})
for product, points in product_data_points.items():
if points >= 30: # Minimum viable data points
recommendations["recommended_products"].append(product)
if len(recommendations["recommended_products"]) == 0:
recommendations["should_retrain"] = False
recommendations["reasons"].append("Insufficient data for reliable training")
return recommendations
except Exception as e:
logger.error(f"Failed to store trained models: {str(e)}")
await db.rollback()
raise
async def _store_single_trained_model(self,
db: AsyncSession,
tenant_id: str,
product_name: str,
training_result: Dict[str, Any]):
"""Store a single trained model"""
try:
if training_result.get("status") == "success":
model_info = training_result.get("model_info", {})
# Deactivate old model for this product
await db.execute(
update(TrainedModel)
.where(
and_(
TrainedModel.tenant_id == tenant_id,
TrainedModel.product_name == product_name,
TrainedModel.is_active == True
)
)
.values(is_active=False)
)
# Create new model record
trained_model = TrainedModel(
tenant_id=tenant_id,
product_name=product_name,
model_id=model_info.get("model_id"),
model_type=model_info.get("type", "prophet"),
model_path=model_info.get("model_path"),
version=1,
training_samples=model_info.get("training_samples", 0),
features=model_info.get("features", []),
hyperparameters=model_info.get("hyperparameters", {}),
training_metrics=model_info.get("training_metrics", {}),
data_period_start=datetime.fromisoformat(
model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
),
data_period_end=datetime.fromisoformat(
model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
),
created_at=datetime.now(),
is_active=True
)
db.add(trained_model)
await db.commit()
logger.info(f"Stored trained model for {product_name}")
except Exception as e:
logger.error(f"Failed to store trained model: {str(e)}")
await db.rollback()
raise
async def get_training_logs(self,
db: AsyncSession,
job_id: str,
tenant_id: str) -> Optional[List[str]]:
"""Get detailed training logs for a job"""
try:
# For now, return basic log information from the database
# In a production system, you might store detailed logs separately
result = await db.execute(
select(ModelTrainingLog).where(
and_(
ModelTrainingLog.job_id == job_id,
ModelTrainingLog.tenant_id == tenant_id
)
)
)
training_log = result.scalar_one_or_none()
if training_log:
logs = [
f"Job started at: {training_log.start_time}",
f"Current status: {training_log.status}",
f"Progress: {training_log.progress}%",
f"Current step: {training_log.current_step}"
]
if training_log.end_time:
logs.append(f"Job completed at: {training_log.end_time}")
if training_log.error_message:
logs.append(f"Error: {training_log.error_message}")
if training_log.results:
results = training_log.results
logs.append(f"Models trained: {results.get('products_trained', 0)}")
logs.append(f"Models failed: {results.get('products_failed', 0)}")
return logs
return None
except Exception as e:
logger.error(f"Failed to get training logs: {str(e)}")
return None
async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
"""Determine start and end dates from sales data"""
if not sales_data:
raise ValueError("No sales data available to determine date range")
dates = []
for record in sales_data:
if 'date' in record:
if isinstance(record['date'], str):
dates.append(datetime.fromisoformat(record['date'].replace('Z', '+00:00')))
elif isinstance(record['date'], datetime):
dates.append(record['date'])
if not dates:
raise ValueError("No valid dates found in sales data")
start_date = min(dates)
end_date = max(dates)
logger.info(f"Determined sales date range: {start_date} to {end_date}")
return start_date, end_date
logger.error(f"Failed to generate training recommendations: {str(e)}")
return {
"should_retrain": False,
"reasons": [f"Error analyzing data: {str(e)}"],
"recommended_products": [],
"optimal_config": {}
}