Improve training code
240
services/training/app/services/date_alignment_service.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DataSourceType(Enum):
|
||||
BAKERY_SALES = "bakery_sales"
|
||||
MADRID_TRAFFIC = "madrid_traffic"
|
||||
WEATHER_FORECAST = "weather_forecast"
|
||||
|
||||
@dataclass
|
||||
class DateRange:
|
||||
start: datetime
|
||||
end: datetime
|
||||
source: DataSourceType
|
||||
|
||||
def duration_days(self) -> int:
|
||||
return (self.end - self.start).days
|
||||
|
||||
def overlaps_with(self, other: 'DateRange') -> bool:
|
||||
return self.start <= other.end and other.start <= self.end
|
||||
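For reference, the overlap test above treats ranges as closed intervals: two ranges overlap exactly when each one starts no later than the other ends. A minimal illustration (the concrete dates are made up for the example):

a = DateRange(datetime(2024, 1, 1), datetime(2024, 3, 31), DataSourceType.BAKERY_SALES)
b = DateRange(datetime(2024, 3, 31), datetime(2024, 5, 1), DataSourceType.WEATHER_FORECAST)
assert a.overlaps_with(b)  # they share exactly one day, 2024-03-31
assert not a.overlaps_with(DateRange(datetime(2024, 4, 1), datetime(2024, 5, 1), DataSourceType.WEATHER_FORECAST))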
|
||||
@dataclass
|
||||
class AlignedDateRange:
|
||||
start: datetime
|
||||
end: datetime
|
||||
available_sources: List[DataSourceType]
|
||||
constraints: Dict[str, str]
|
||||
|
||||
class DateAlignmentService:
|
||||
"""
|
||||
Central service for managing and aligning dates across multiple data sources
|
||||
for the bakery sales prediction model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.MAX_TRAINING_RANGE_DAYS = 365 # Maximum training data range
|
||||
self.MIN_TRAINING_RANGE_DAYS = 30 # Minimum viable training data
|
||||
|
||||
def validate_and_align_dates(
|
||||
self,
|
||||
user_sales_range: DateRange,
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None
|
||||
) -> AlignedDateRange:
|
||||
"""
|
||||
Main method to validate and align dates across all data sources.
|
||||
|
||||
Args:
|
||||
user_sales_range: Date range of user-provided sales data
|
||||
requested_start: Optional explicit start date for training
|
||||
requested_end: Optional explicit end date for training
|
||||
|
||||
Returns:
|
||||
AlignedDateRange with validated start/end dates and available sources
|
||||
"""
|
||||
try:
|
||||
# Step 1: Determine the base date range
|
||||
base_range = self._determine_base_range(
|
||||
user_sales_range, requested_start, requested_end
|
||||
)
|
||||
|
||||
# Step 2: Apply data source constraints
|
||||
aligned_range = self._apply_data_source_constraints(base_range)
|
||||
|
||||
# Step 3: Validate final range
|
||||
self._validate_final_range(aligned_range)
|
||||
|
||||
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
|
||||
return aligned_range
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Date alignment failed: {str(e)}")
|
||||
raise ValueError(f"Unable to align dates: {str(e)}")
|
||||
|
||||
def _determine_base_range(
|
||||
self,
|
||||
user_sales_range: DateRange,
|
||||
requested_start: Optional[datetime],
|
||||
requested_end: Optional[datetime]
|
||||
) -> DateRange:
|
||||
"""Determine the base date range for training."""
|
||||
|
||||
# Use explicit dates if provided
|
||||
if requested_start and requested_end:
|
||||
if requested_end <= requested_start:
|
||||
raise ValueError("End date must be after start date")
|
||||
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
|
||||
|
||||
# Otherwise, use the user's sales data range as the foundation
|
||||
start_date = requested_start or user_sales_range.start
|
||||
end_date = requested_end or user_sales_range.end
|
||||
|
||||
# Ensure we don't exceed maximum training range
|
||||
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
|
||||
start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
|
||||
logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
|
||||
|
||||
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
|
||||
"""Apply constraints from each data source and determine final aligned range."""
|
||||
|
||||
available_sources = [DataSourceType.BAKERY_SALES]  # Sales data is always present
constraints = {}

# Madrid traffic data constraint: data for the current month is not yet published
madrid_end_date = self._get_madrid_traffic_end_date()
if base_range.end > madrid_end_date:
    # Requested end date falls in the current month, so clamp it
    new_end = madrid_end_date
    constraints["madrid_traffic"] = f"Adjusted end date to {new_end.date()} (latest available traffic data)"
    logger.info(f"Madrid traffic constraint: end date adjusted to {new_end.date()}")
else:
    new_end = base_range.end
available_sources.append(DataSourceType.MADRID_TRAFFIC)

# Weather forecast constraint: observations are available from yesterday backward
weather_end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
if new_end > weather_end_date:
    new_end = weather_end_date
    constraints["weather"] = f"Adjusted end date to {new_end.date()} (latest available weather data)"
    logger.info(f"Weather constraint: end date adjusted to {new_end.date()}")
if new_end >= base_range.start:
    available_sources.append(DataSourceType.WEATHER_FORECAST)

# Ensure the minimum training period is preserved after clamping the end date
final_start = base_range.start
if (new_end - final_start).days < self.MIN_TRAINING_RANGE_DAYS:
    final_start = new_end - timedelta(days=self.MIN_TRAINING_RANGE_DAYS)
    constraints["minimum_period"] = f"Adjusted start date to ensure {self.MIN_TRAINING_RANGE_DAYS} day minimum training period"
    logger.info(f"Minimum period constraint: start date adjusted to {final_start.date()}")

return AlignedDateRange(
    start=final_start,
    end=new_end,
    available_sources=available_sources,
    constraints=constraints
)
|
||||
|
||||
def _get_madrid_traffic_end_date(self) -> datetime:
|
||||
"""
|
||||
Get the latest available date for Madrid traffic data.
|
||||
Data for current month is not available until the following month.
|
||||
"""
|
||||
# Data for the current month is published the following month, so the latest
# available date is the last day of the previous month.
now = datetime.now()
first_of_current_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return first_of_current_month - timedelta(days=1)
|
||||
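As a quick sanity check of the rule above (illustrative dates only): on any day in March 2024 the newest traffic date the service will request is 2024-02-29.

first_of_month = datetime(2024, 3, 10).replace(day=1)
assert first_of_month - timedelta(days=1) == datetime(2024, 2, 29)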
|
||||
def _validate_final_range(self, aligned_range: AlignedDateRange) -> None:
|
||||
"""Validate the final aligned date range."""
|
||||
|
||||
if aligned_range.start >= aligned_range.end:
|
||||
raise ValueError("Invalid date range: start date must be before end date")
|
||||
|
||||
duration = (aligned_range.end - aligned_range.start).days
|
||||
|
||||
if duration < self.MIN_TRAINING_RANGE_DAYS:
|
||||
raise ValueError(f"Insufficient training data: {duration} days (minimum: {self.MIN_TRAINING_RANGE_DAYS})")
|
||||
|
||||
if duration > self.MAX_TRAINING_RANGE_DAYS:
|
||||
raise ValueError(f"Training period too long: {duration} days (maximum: {self.MAX_TRAINING_RANGE_DAYS})")
|
||||
|
||||
# Ensure we have at least sales data
|
||||
if DataSourceType.BAKERY_SALES not in aligned_range.available_sources:
|
||||
raise ValueError("No sales data available for the aligned date range")
|
||||
|
||||
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
|
||||
"""
|
||||
Generate a data collection plan based on the aligned date range.
|
||||
|
||||
Returns:
|
||||
Dictionary with collection plans for each data source
|
||||
"""
|
||||
plan = {}
|
||||
|
||||
# Bakery Sales Data
|
||||
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
|
||||
plan["sales_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "user_upload",
|
||||
"required": True
|
||||
}
|
||||
|
||||
# Madrid Traffic Data
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
plan["traffic_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "madrid_opendata",
|
||||
"required": False,
|
||||
"constraint": "Cannot request current month data"
|
||||
}
|
||||
|
||||
# Weather Data
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
plan["weather_data"] = {
|
||||
"start_date": aligned_range.start,
|
||||
"end_date": aligned_range.end,
|
||||
"source": "aemet_api",
|
||||
"required": False,
|
||||
"constraint": "Available from yesterday backward"
|
||||
}
|
||||
|
||||
return plan
|
||||
|
||||
def check_madrid_current_month_constraint(self, end_date: datetime) -> bool:
|
||||
"""
|
||||
Check if the end date violates the Madrid Open Data current month constraint.
|
||||
|
||||
Args:
|
||||
end_date: The requested end date
|
||||
|
||||
Returns:
|
||||
True if the constraint is violated (end date is in current month)
|
||||
"""
|
||||
now = datetime.now()
|
||||
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
return end_date >= current_month_start
|
||||
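A minimal usage sketch of the service above (the dates and caller context are illustrative, not part of this commit):

service = DateAlignmentService()
sales_range = DateRange(
    start=datetime(2024, 1, 1),
    end=datetime(2024, 6, 30),
    source=DataSourceType.BAKERY_SALES,
)
aligned = service.validate_and_align_dates(user_sales_range=sales_range)
print(aligned.start, aligned.end)                    # clamped to the data-source constraints
print([s.value for s in aligned.available_sources])  # e.g. ['bakery_sales', 'madrid_traffic', ...]
print(aligned.constraints)                           # human-readable adjustment notes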
706
services/training/app/services/training_orchestrator.py
Normal file
@@ -0,0 +1,706 @@
|
||||
# services/training/app/services/training_orchestrator.py
|
||||
"""
|
||||
Training Data Orchestrator - Enhanced Integration Layer
|
||||
Orchestrates data collection, date alignment, and preparation for ML training
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.services.data_client import DataServiceClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TrainingDataSet:
|
||||
"""Container for all training data with metadata"""
|
||||
sales_data: List[Dict[str, Any]]
|
||||
weather_data: List[Dict[str, Any]]
|
||||
traffic_data: List[Dict[str, Any]]
|
||||
date_range: AlignedDateRange
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
class TrainingDataOrchestrator:
|
||||
"""
|
||||
Enhanced orchestrator for data collection from multiple sources.
|
||||
Ensures date alignment, handles data source constraints, and prepares data for ML training.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
madrid_client=None,
|
||||
weather_client=None,
|
||||
date_alignment_service: DateAlignmentService = None):
|
||||
self.madrid_client = madrid_client
|
||||
self.weather_client = weather_client
|
||||
self.data_client = DataServiceClient()
|
||||
self.date_alignment_service = date_alignment_service or DateAlignmentService()
|
||||
self.max_concurrent_requests = 3
|
||||
|
||||
async def prepare_training_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bakery_location: Tuple[float, float], # (lat, lon)
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None,
|
||||
job_id: Optional[str] = None
|
||||
) -> TrainingDataSet:
|
||||
"""
|
||||
Main method to prepare all training data with comprehensive date alignment.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
sales_data: User-provided sales data
|
||||
bakery_location: Bakery coordinates (lat, lon)
|
||||
requested_start: Optional explicit start date
|
||||
requested_end: Optional explicit end date
|
||||
job_id: Training job identifier for logging
|
||||
|
||||
Returns:
|
||||
TrainingDataSet with all aligned and validated data
|
||||
"""
|
||||
logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")
|
||||
|
||||
try:
|
||||
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id)
|
||||
|
||||
# Step 1: Extract and validate sales data date range
|
||||
sales_date_range = self._extract_sales_date_range(sales_data)
|
||||
logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")
|
||||
|
||||
# Step 2: Apply date alignment across all data sources
|
||||
aligned_range = self.date_alignment_service.validate_and_align_dates(
|
||||
user_sales_range=sales_date_range,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
|
||||
if aligned_range.constraints:
|
||||
logger.info(f"Applied constraints: {aligned_range.constraints}")
|
||||
|
||||
# Step 3: Filter sales data to aligned date range
|
||||
filtered_sales = self._filter_sales_data(sales_data, aligned_range)
|
||||
|
||||
# Step 4: Collect external data sources concurrently
|
||||
logger.info("Collecting external data sources...")
|
||||
weather_data, traffic_data = await self._collect_external_data(
|
||||
aligned_range, bakery_location
|
||||
)
|
||||
|
||||
# Step 5: Validate data quality
|
||||
data_quality_results = self._validate_data_sources(
|
||||
filtered_sales, weather_data, traffic_data, aligned_range
|
||||
)
|
||||
|
||||
# Step 6: Create comprehensive training dataset
|
||||
training_dataset = TrainingDataSet(
|
||||
sales_data=filtered_sales,
|
||||
weather_data=weather_data,
|
||||
traffic_data=traffic_data,
|
||||
date_range=aligned_range,
|
||||
metadata={
|
||||
"tenant_id": tenant_id,
|
||||
"job_id": job_id,
|
||||
"bakery_location": bakery_location,
|
||||
"data_sources_used": aligned_range.available_sources,
|
||||
"constraints_applied": aligned_range.constraints,
|
||||
"data_quality": data_quality_results,
|
||||
"preparation_timestamp": datetime.now().isoformat(),
|
||||
"original_sales_range": {
|
||||
"start": sales_date_range.start.isoformat(),
|
||||
"end": sales_date_range.end.isoformat()
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Step 7: Final validation
|
||||
final_validation = self.validate_training_data_quality(training_dataset)
|
||||
training_dataset.metadata["final_validation"] = final_validation
|
||||
|
||||
logger.info(f"Training data preparation completed successfully:")
|
||||
logger.info(f" - Sales records: {len(filtered_sales)}")
|
||||
logger.info(f" - Weather records: {len(weather_data)}")
|
||||
logger.info(f" - Traffic records: {len(traffic_data)}")
|
||||
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
|
||||
|
||||
return training_dataset
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training data preparation failed: {str(e)}")
|
||||
raise ValueError(f"Failed to prepare training data: {str(e)}")
|
||||
|
||||
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
|
||||
"""Extract and validate the date range from sales data"""
|
||||
if not sales_data:
|
||||
raise ValueError("No sales data provided")
|
||||
|
||||
dates = []
|
||||
valid_records = 0
|
||||
|
||||
for record in sales_data:
|
||||
try:
|
||||
if 'date' in record:
|
||||
date_val = record['date']
|
||||
if isinstance(date_val, str):
|
||||
# Handle various date formats
|
||||
if 'T' in date_val:
|
||||
date_val = date_val.replace('Z', '+00:00')
|
||||
parsed_date = datetime.fromisoformat(date_val.split('T')[0])
|
||||
elif isinstance(date_val, datetime):
|
||||
parsed_date = date_val
|
||||
else:
|
||||
continue
|
||||
|
||||
dates.append(parsed_date)
|
||||
valid_records += 1
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
|
||||
continue
|
||||
|
||||
if not dates:
|
||||
raise ValueError("No valid dates found in sales data")
|
||||
|
||||
logger.info(f"Processed {valid_records} valid date records from {len(sales_data)} total records")
|
||||
|
||||
return DateRange(
|
||||
start=min(dates),
|
||||
end=max(dates),
|
||||
source=DataSourceType.BAKERY_SALES
|
||||
)
|
||||
|
||||
def _filter_sales_data(
|
||||
self,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter sales data to the aligned date range with enhanced validation"""
|
||||
filtered_data = []
|
||||
filtered_count = 0
|
||||
|
||||
for record in sales_data:
|
||||
try:
|
||||
if 'date' in record:
|
||||
record_date = record['date']
|
||||
if isinstance(record_date, str):
|
||||
if 'T' in record_date:
|
||||
record_date = record_date.replace('Z', '+00:00')
|
||||
record_date = datetime.fromisoformat(record_date.split('T')[0])
|
||||
elif isinstance(record_date, datetime):
|
||||
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Check if date falls within aligned range
|
||||
if aligned_range.start <= record_date <= aligned_range.end:
|
||||
# Validate that record has required fields
|
||||
if self._validate_sales_record(record):
|
||||
filtered_data.append(record)
|
||||
else:
|
||||
filtered_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing sales record: {str(e)}")
|
||||
filtered_count += 1
|
||||
continue
|
||||
|
||||
logger.info(f"Filtered sales data: {len(filtered_data)} records in aligned range")
|
||||
if filtered_count > 0:
|
||||
logger.warning(f"Filtered out {filtered_count} invalid records")
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
|
||||
"""Validate individual sales record"""
|
||||
required_fields = ['date', 'product_name']
|
||||
quantity_fields = ['quantity', 'quantity_sold', 'sales', 'units_sold']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if field not in record or record[field] is None:
|
||||
return False
|
||||
|
||||
# Check at least one quantity field exists
|
||||
has_quantity = any(field in record and record[field] is not None for field in quantity_fields)
|
||||
if not has_quantity:
|
||||
return False
|
||||
|
||||
# Validate quantity is numeric and non-negative
|
||||
for field in quantity_fields:
|
||||
if field in record and record[field] is not None:
|
||||
try:
|
||||
quantity = float(record[field])
|
||||
if quantity < 0:
|
||||
return False
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
break
|
||||
|
||||
return True
|
||||
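For illustration, records like these are the intended pass/fail cases for the validator above (field values are made up):

ok = {"date": "2024-05-01", "product_name": "croissant", "quantity": 12}
missing_qty = {"date": "2024-05-01", "product_name": "croissant"}
negative = {"date": "2024-05-01", "product_name": "croissant", "quantity_sold": -3}
# _validate_sales_record(ok) -> True; the other two -> False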
|
||||
async def _collect_external_data(
|
||||
self,
|
||||
aligned_range: AlignedDateRange,
|
||||
bakery_location: Tuple[float, float]
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""Collect weather and traffic data concurrently with enhanced error handling"""
|
||||
|
||||
lat, lon = bakery_location
|
||||
|
||||
# Create collection tasks with timeout
|
||||
tasks = []
|
||||
|
||||
# Weather data collection
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
weather_task = asyncio.create_task(
|
||||
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
|
||||
)
|
||||
tasks.append(("weather", weather_task))
|
||||
|
||||
# Traffic data collection
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
traffic_task = asyncio.create_task(
|
||||
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
|
||||
)
|
||||
tasks.append(("traffic", traffic_task))
|
||||
|
||||
# Execute tasks concurrently with proper error handling
|
||||
results = {}
|
||||
if tasks:
|
||||
try:
|
||||
completed_tasks = await asyncio.gather(
|
||||
*[task for _, task in tasks],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
for i, (task_name, _) in enumerate(tasks):
|
||||
result = completed_tasks[i]
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"{task_name} data collection failed: {result}")
|
||||
results[task_name] = []
|
||||
else:
|
||||
results[task_name] = result
|
||||
logger.info(f"{task_name} data collection completed: {len(result)} records")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in concurrent data collection: {str(e)}")
|
||||
results = {"weather": [], "traffic": []}
|
||||
|
||||
weather_data = results.get("weather", [])
|
||||
traffic_data = results.get("traffic", [])
|
||||
|
||||
return weather_data, traffic_data
|
||||
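The gather-with-return_exceptions pattern used above can be read in isolation as follows (a minimal sketch with stand-in coroutines, not code from this commit):

import asyncio

async def fetch_ok():
    return [1, 2, 3]

async def fetch_broken():
    raise RuntimeError("source unavailable")

async def demo():
    results = await asyncio.gather(fetch_ok(), fetch_broken(), return_exceptions=True)
    # Failures come back as exception objects instead of cancelling the whole batch
    for name, result in zip(["ok", "broken"], results):
        if isinstance(result, Exception):
            print(name, "failed:", result)
        else:
            print(name, "returned", len(result), "records")

asyncio.run(demo())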
|
||||
async def _collect_weather_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect weather data with timeout and fallback"""
|
||||
try:
    # Timeout value is not shown in this diff; 30 seconds is an assumed default.
    timeout_seconds = 30

    if not self.weather_client:
        logger.info("Weather client not configured, using synthetic data")
        return self._generate_synthetic_weather_data(aligned_range)

    weather_data = await asyncio.wait_for(
        self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
        timeout=timeout_seconds
    )
|
||||
|
||||
# Validate weather data
|
||||
if self._validate_weather_data(weather_data):
|
||||
logger.info(f"Collected {len(weather_data)} valid weather records")
|
||||
return weather_data
|
||||
else:
|
||||
logger.warning("Invalid weather data received, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
except Exception as e:
|
||||
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
async def _collect_traffic_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect traffic data with timeout and Madrid constraint validation"""
|
||||
try:
    # Timeout value is not shown in this diff; 30 seconds is an assumed default.
    timeout_seconds = 30

    if not self.madrid_client:
        logger.info("Madrid client not configured, no traffic data available")
        return []

    # Double-check the Madrid current-month constraint before making the request
    if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
        logger.warning("Madrid current month constraint violation, no traffic data available")
        return []

    traffic_data = await asyncio.wait_for(
        self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
        timeout=timeout_seconds
    )
|
||||
|
||||
# Validate traffic data
|
||||
if self._validate_traffic_data(traffic_data):
|
||||
logger.info(f"Collected {len(traffic_data)} valid traffic records")
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("Invalid traffic data received")
|
||||
return []
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Traffic data collection failed: {e}")
|
||||
return []
|
||||
|
||||
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate weather data quality"""
|
||||
if not weather_data:
|
||||
return False
|
||||
|
||||
required_fields = ['date']
|
||||
weather_fields = ['temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia']
|
||||
|
||||
valid_records = 0
|
||||
for record in weather_data:
|
||||
# Check required fields
|
||||
if not all(field in record for field in required_fields):
|
||||
continue
|
||||
|
||||
# Check at least one weather field exists
|
||||
if any(field in record and record[field] is not None for field in weather_fields):
|
||||
valid_records += 1
|
||||
|
||||
# Consider valid if at least 50% of records are valid
|
||||
validity_threshold = 0.5
|
||||
is_valid = (valid_records / len(weather_data)) >= validity_threshold
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"Weather data validation failed: {valid_records}/{len(weather_data)} valid records")
|
||||
|
||||
return is_valid
|
||||
|
||||
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
|
||||
"""Validate traffic data quality"""
|
||||
if not traffic_data:
|
||||
return False
|
||||
|
||||
required_fields = ['date']
|
||||
traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
|
||||
|
||||
valid_records = 0
|
||||
for record in traffic_data:
|
||||
# Check required fields
|
||||
if not all(field in record for field in required_fields):
|
||||
continue
|
||||
|
||||
# Check at least one traffic field exists
|
||||
if any(field in record and record[field] is not None for field in traffic_fields):
|
||||
valid_records += 1
|
||||
|
||||
# Consider valid if at least 30% of records are valid (traffic data is often sparse)
|
||||
validity_threshold = 0.3
|
||||
is_valid = (valid_records / len(traffic_data)) >= validity_threshold
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"Traffic data validation failed: {valid_records}/{len(traffic_data)} valid records")
|
||||
|
||||
return is_valid
|
||||
|
||||
def _validate_data_sources(
|
||||
self,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
weather_data: List[Dict[str, Any]],
|
||||
traffic_data: List[Dict[str, Any]],
|
||||
aligned_range: AlignedDateRange
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate all data sources and provide quality metrics"""
|
||||
|
||||
validation_results = {
|
||||
"sales_data": {
|
||||
"record_count": len(sales_data),
|
||||
"is_valid": len(sales_data) > 0,
|
||||
"coverage_days": (aligned_range.end - aligned_range.start).days,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"weather_data": {
|
||||
"record_count": len(weather_data),
|
||||
"is_valid": self._validate_weather_data(weather_data) if weather_data else False,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"traffic_data": {
|
||||
"record_count": len(traffic_data),
|
||||
"is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
|
||||
"quality_score": 0.0
|
||||
},
|
||||
"overall_quality_score": 0.0
|
||||
}
|
||||
|
||||
# Calculate quality scores
|
||||
# Sales data quality (most important)
|
||||
if validation_results["sales_data"]["record_count"] > 0:
|
||||
coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / validation_results["sales_data"]["coverage_days"])
|
||||
validation_results["sales_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Weather data quality
|
||||
if validation_results["weather_data"]["record_count"] > 0:
|
||||
expected_weather_records = (aligned_range.end - aligned_range.start).days
|
||||
coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
|
||||
validation_results["weather_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Traffic data quality
|
||||
if validation_results["traffic_data"]["record_count"] > 0:
|
||||
expected_traffic_records = (aligned_range.end - aligned_range.start).days
|
||||
coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
|
||||
validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100
|
||||
|
||||
# Overall quality score (weighted by importance)
|
||||
weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
|
||||
overall_score = sum(
|
||||
validation_results[source]["quality_score"] * weight
|
||||
for source, weight in weights.items()
|
||||
)
|
||||
validation_results["overall_quality_score"] = round(overall_score, 2)
|
||||
|
||||
return validation_results
|
||||
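Worked example of the weighted overall score above (per-source scores are made up): with sales at 90, weather at 50 and traffic at 0, the overall score is 0.7*90 + 0.2*50 + 0.1*0 = 73.0.

weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
scores = {"sales_data": 90.0, "weather_data": 50.0, "traffic_data": 0.0}
assert round(sum(scores[k] * w for k, w in weights.items()), 2) == 73.0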
|
||||
def _generate_synthetic_weather_data(
|
||||
self,
|
||||
aligned_range: AlignedDateRange
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate realistic synthetic weather data for Madrid"""
|
||||
synthetic_data = []
|
||||
current_date = aligned_range.start
|
||||
|
||||
# Madrid seasonal temperature patterns
|
||||
seasonal_temps = {
|
||||
1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
|
||||
7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
|
||||
}
|
||||
|
||||
while current_date <= aligned_range.end:
|
||||
month = current_date.month
|
||||
base_temp = seasonal_temps.get(month, 15)
|
||||
|
||||
# Add some realistic variation
|
||||
import random
|
||||
temp_variation = random.gauss(0, 3) # ±3°C variation
|
||||
temperature = max(0, base_temp + temp_variation)
|
||||
|
||||
# Precipitation patterns (Madrid is relatively dry)
|
||||
precipitation = 0.0
|
||||
if random.random() < 0.15: # 15% chance of rain
|
||||
precipitation = random.uniform(0.1, 15.0)
|
||||
|
||||
synthetic_data.append({
|
||||
"date": current_date,
|
||||
"temperature": round(temperature, 1),
|
||||
"precipitation": round(precipitation, 1),
|
||||
"humidity": round(random.uniform(40, 80), 1),
|
||||
"wind_speed": round(random.uniform(2, 15), 1),
|
||||
"pressure": round(random.uniform(1005, 1025), 1),
|
||||
"source": "synthetic_madrid_pattern"
|
||||
})
|
||||
|
||||
current_date = current_date + timedelta(days=1)
|
||||
|
||||
logger.info(f"Generated {len(synthetic_data)} synthetic weather records with Madrid patterns")
|
||||
return synthetic_data
|
||||
|
||||
def validate_training_data_quality(self, dataset: TrainingDataSet) -> Dict[str, Any]:
|
||||
"""Enhanced validation of training data quality"""
|
||||
validation_results = {
|
||||
"is_valid": True,
|
||||
"warnings": [],
|
||||
"errors": [],
|
||||
"data_quality_score": 100.0,
|
||||
"recommendations": []
|
||||
}
|
||||
|
||||
# Check sales data completeness
|
||||
sales_count = len(dataset.sales_data)
|
||||
if sales_count < 30:
|
||||
validation_results["warnings"].append(
|
||||
f"Limited sales data: {sales_count} records (recommended: 30+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 20
|
||||
validation_results["recommendations"].append("Consider collecting more historical sales data")
|
||||
elif sales_count < 90:
|
||||
validation_results["warnings"].append(
|
||||
f"Moderate sales data: {sales_count} records (optimal: 90+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 10
|
||||
|
||||
# Check date coverage
|
||||
date_coverage = (dataset.date_range.end - dataset.date_range.start).days
|
||||
if date_coverage < 90:
|
||||
validation_results["warnings"].append(
|
||||
f"Limited date coverage: {date_coverage} days (recommended: 90+)"
|
||||
)
|
||||
validation_results["data_quality_score"] -= 15
|
||||
validation_results["recommendations"].append("Extend date range for better seasonality detection")
|
||||
|
||||
# Check external data availability
|
||||
if not dataset.weather_data:
|
||||
validation_results["warnings"].append("No weather data available")
|
||||
validation_results["data_quality_score"] -= 10
|
||||
validation_results["recommendations"].append("Weather data improves forecast accuracy")
|
||||
elif len(dataset.weather_data) < date_coverage * 0.5:
|
||||
validation_results["warnings"].append("Sparse weather data coverage")
|
||||
validation_results["data_quality_score"] -= 5
|
||||
|
||||
if not dataset.traffic_data:
|
||||
validation_results["warnings"].append("No traffic data available")
|
||||
validation_results["data_quality_score"] -= 5
|
||||
validation_results["recommendations"].append("Traffic data can help with location-based patterns")
|
||||
|
||||
# Check data consistency
|
||||
unique_products = set()
|
||||
for record in dataset.sales_data:
|
||||
if 'product_name' in record:
|
||||
unique_products.add(record['product_name'])
|
||||
|
||||
if len(unique_products) == 0:
|
||||
validation_results["errors"].append("No product names found in sales data")
|
||||
validation_results["is_valid"] = False
|
||||
elif len(unique_products) > 50:
|
||||
validation_results["warnings"].append(
|
||||
f"Many products detected ({len(unique_products)}). Consider training models in batches."
|
||||
)
|
||||
validation_results["recommendations"].append("Group similar products for better training efficiency")
|
||||
|
||||
# Check for data source constraints
|
||||
if dataset.date_range.constraints:
|
||||
constraint_info = []
|
||||
for constraint_type, message in dataset.date_range.constraints.items():
|
||||
constraint_info.append(f"{constraint_type}: {message}")
|
||||
|
||||
validation_results["warnings"].append(
|
||||
f"Data source constraints applied: {'; '.join(constraint_info)}"
|
||||
)
|
||||
|
||||
# Final validation
|
||||
if validation_results["errors"]:
|
||||
validation_results["is_valid"] = False
|
||||
validation_results["data_quality_score"] = 0.0
|
||||
|
||||
# Ensure score doesn't go below 0
|
||||
validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])
|
||||
|
||||
# Add quality assessment
|
||||
score = validation_results["data_quality_score"]
|
||||
if score >= 80:
|
||||
validation_results["quality_assessment"] = "Excellent"
|
||||
elif score >= 60:
|
||||
validation_results["quality_assessment"] = "Good"
|
||||
elif score >= 40:
|
||||
validation_results["quality_assessment"] = "Fair"
|
||||
else:
|
||||
validation_results["quality_assessment"] = "Poor"
|
||||
|
||||
return validation_results
|
||||
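As an illustration of how the deductions above combine (hypothetical dataset): 25 sales records over 60 days with no weather and no traffic data loses 20 + 15 + 10 + 5 points.

score = 100.0 - 20 - 15 - 10 - 5  # limited sales, short coverage, no weather, no traffic
assert score == 50.0              # falls in the "Fair" band (40-59)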
|
||||
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
|
||||
"""
|
||||
Generate an enhanced data collection plan based on the aligned date range.
|
||||
"""
|
||||
plan = {
|
||||
"collection_summary": {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"duration_days": (aligned_range.end - aligned_range.start).days,
|
||||
"available_sources": [source.value for source in aligned_range.available_sources],
|
||||
"constraints": aligned_range.constraints
|
||||
},
|
||||
"data_sources": {}
|
||||
}
|
||||
|
||||
# Bakery Sales Data
|
||||
if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
|
||||
plan["data_sources"]["sales_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "user_upload",
|
||||
"required": True,
|
||||
"priority": "high",
|
||||
"expected_records": "variable",
|
||||
"data_points": ["date", "product_name", "quantity"],
|
||||
"validation": "required_fields_check"
|
||||
}
|
||||
|
||||
# Madrid Traffic Data
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
plan["data_sources"]["traffic_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "madrid_opendata",
|
||||
"required": False,
|
||||
"priority": "medium",
|
||||
"expected_records": (aligned_range.end - aligned_range.start).days,
|
||||
"constraint": "Cannot request current month data",
|
||||
"data_points": ["date", "traffic_volume", "congestion_level"],
|
||||
"validation": "date_constraint_check"
|
||||
}
|
||||
|
||||
# Weather Data
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
plan["data_sources"]["weather_data"] = {
|
||||
"start_date": aligned_range.start.isoformat(),
|
||||
"end_date": aligned_range.end.isoformat(),
|
||||
"source": "aemet_api",
|
||||
"required": False,
|
||||
"priority": "high",
|
||||
"expected_records": (aligned_range.end - aligned_range.start).days,
|
||||
"constraint": "Available from yesterday backward",
|
||||
"data_points": ["date", "temperature", "precipitation", "humidity"],
|
||||
"validation": "temporal_constraint_check",
|
||||
"fallback": "synthetic_madrid_weather"
|
||||
}
|
||||
|
||||
return plan
|
||||
|
||||
def get_orchestration_summary(self, dataset: TrainingDataSet) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a comprehensive summary of the orchestration process.
|
||||
"""
|
||||
return {
|
||||
"tenant_id": dataset.metadata.get("tenant_id"),
|
||||
"job_id": dataset.metadata.get("job_id"),
|
||||
"orchestration_completed_at": dataset.metadata.get("preparation_timestamp"),
|
||||
"data_alignment": {
|
||||
"original_range": dataset.metadata.get("original_sales_range"),
|
||||
"aligned_range": {
|
||||
"start": dataset.date_range.start.isoformat(),
|
||||
"end": dataset.date_range.end.isoformat(),
|
||||
"duration_days": (dataset.date_range.end - dataset.date_range.start).days
|
||||
},
|
||||
"constraints_applied": dataset.date_range.constraints,
|
||||
"available_sources": [source.value for source in dataset.date_range.available_sources]
|
||||
},
|
||||
"data_collection_results": {
|
||||
"sales_records": len(dataset.sales_data),
|
||||
"weather_records": len(dataset.weather_data),
|
||||
"traffic_records": len(dataset.traffic_data),
|
||||
"total_records": len(dataset.sales_data) + len(dataset.weather_data) + len(dataset.traffic_data)
|
||||
},
|
||||
"data_quality": dataset.metadata.get("data_quality", {}),
|
||||
"validation_results": dataset.metadata.get("final_validation", {}),
|
||||
"processing_metadata": {
|
||||
"bakery_location": dataset.metadata.get("bakery_location"),
|
||||
"data_sources_requested": len(dataset.date_range.available_sources),
|
||||
"data_sources_successful": sum([
|
||||
1 if len(dataset.sales_data) > 0 else 0,
|
||||
1 if len(dataset.weather_data) > 0 else 0,
|
||||
1 if len(dataset.traffic_data) > 0 else 0
|
||||
])
|
||||
}
|
||||
}
|
||||
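A minimal usage sketch for the orchestrator above (the tenant id, coordinates and surrounding event loop are illustrative; it assumes the data service is reachable):

import asyncio

async def run_example():
    orchestrator = TrainingDataOrchestrator()
    dataset = await orchestrator.prepare_training_data(
        tenant_id="tenant-123",
        bakery_location=(40.4168, -3.7038),  # Madrid city centre
        job_id="training-job-demo",
    )
    print(len(dataset.sales_data), "sales records")
    print(orchestrator.get_orchestration_summary(dataset)["data_collection_results"])

# asyncio.run(run_example())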
@@ -1,721 +1,303 @@
|
||||
# services/training/app/services/training_service.py
|
||||
"""
|
||||
Training service business logic
|
||||
Orchestrates ML training operations and manages job lifecycle
|
||||
Main Training Service - Coordinates the complete training process
|
||||
This is the entry point from the API layer
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, and_
|
||||
import httpx
|
||||
|
||||
from app.models.training import ModelTrainingLog, TrainedModel
|
||||
from app.ml.trainer import BakeryMLTrainer
|
||||
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
|
||||
from app.services.messaging import publish_job_completed, publish_job_failed
|
||||
from app.core.config import settings
|
||||
from shared.monitoring.metrics import MetricsCollector
|
||||
from app.services.data_client import DataServiceClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
|
||||
from app.services.training_orchestrator import TrainingDataOrchestrator
|
||||
|
||||
from app.core.database import get_db_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
metrics = MetricsCollector("training-service")
|
||||
|
||||
class TrainingService:
|
||||
"""
|
||||
Main service class for managing ML training operations.
|
||||
Replaces the old Celery-based training system with clean async implementation.
|
||||
Main training service that coordinates the complete training pipeline.
|
||||
Entry point from API layer - handles business logic and orchestration.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.ml_trainer = BakeryMLTrainer()
|
||||
self.data_client = DataServiceClient()
|
||||
|
||||
async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
|
||||
"""Determine start and end dates from sales data with validation"""
|
||||
if not sales_data:
|
||||
raise ValueError("No sales data available to determine date range")
|
||||
|
||||
dates = []
|
||||
for record in sales_data:
|
||||
if 'date' in record:
|
||||
try:
|
||||
if isinstance(record['date'], str):
|
||||
# Handle various date string formats
|
||||
date_str = record['date'].replace('Z', '+00:00')
|
||||
if 'T' in date_str:
|
||||
parsed_date = datetime.fromisoformat(date_str)
|
||||
else:
|
||||
# Handle date-only strings
|
||||
parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
dates.append(parsed_date)
|
||||
elif isinstance(record['date'], datetime):
|
||||
dates.append(record['date'])
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"Invalid date format in record: {record['date']} - {e}")
|
||||
continue
|
||||
|
||||
if not dates:
|
||||
raise ValueError("No valid dates found in sales data")
|
||||
|
||||
start_date = min(dates)
|
||||
end_date = max(dates)
|
||||
|
||||
# Validate and adjust date range for external APIs
|
||||
start_date, end_date = self._adjust_date_range_for_apis(start_date, end_date)
|
||||
|
||||
logger.info(f"Determined and adjusted sales date range: {start_date} to {end_date}")
|
||||
return start_date, end_date
|
||||
|
||||
def _adjust_date_range_for_apis(self, start_date: datetime, end_date: datetime) -> tuple[datetime, datetime]:
|
||||
"""Adjust date range to comply with external API limits"""
|
||||
|
||||
# Weather and traffic APIs have a 90-day limit
|
||||
MAX_DAYS = 90
|
||||
|
||||
# Calculate current range
|
||||
current_range = (end_date - start_date).days
|
||||
|
||||
if current_range > MAX_DAYS:
|
||||
logger.warning(f"Date range ({current_range} days) exceeds API limit ({MAX_DAYS} days). Adjusting...")
|
||||
|
||||
# Keep the most recent data
|
||||
start_date = end_date - timedelta(days=MAX_DAYS)
|
||||
logger.info(f"Adjusted start_date to {start_date} to fit within {MAX_DAYS} day limit")
|
||||
|
||||
# Ensure dates are not in the future
|
||||
now = datetime.now()
|
||||
if end_date > now:
|
||||
end_date = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
logger.info(f"Adjusted end_date to {end_date} (cannot be in future)")
|
||||
|
||||
if start_date > now:
|
||||
start_date = now.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=30)
|
||||
logger.info(f"Adjusted start_date to {start_date} (was in future)")
|
||||
|
||||
# Ensure start_date is before end_date
|
||||
if start_date >= end_date:
|
||||
start_date = end_date - timedelta(days=30) # Default to 30 days of data
|
||||
logger.warning(f"start_date was not before end_date. Adjusted start_date to {start_date}")
|
||||
def __init__(self, db_session: AsyncSession = None):
|
||||
self.db_session = db_session
|
||||
self.trainer = BakeryMLTrainer(db_session=db_session) # Pass DB session
|
||||
self.date_alignment_service = DateAlignmentService()
|
||||
self.orchestrator = TrainingDataOrchestrator(
|
||||
date_alignment_service=self.date_alignment_service
|
||||
)
|
||||
|
||||
return start_date, end_date
|
||||
async def start_training_job(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bakery_location: tuple[float, float] = (40.4168, -3.7038), # Default Madrid
|
||||
requested_start: Optional[datetime] = None,
|
||||
requested_end: Optional[datetime] = None,
|
||||
job_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Start a complete training job for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
sales_data: Historical sales data
|
||||
bakery_location: Bakery coordinates (lat, lon)
|
||||
weather_data: Optional weather data
|
||||
traffic_data: Optional traffic data
|
||||
requested_start: Optional explicit start date
|
||||
requested_end: Optional explicit end date
|
||||
job_id: Optional job identifier
|
||||
|
||||
Returns:
|
||||
Training job results
|
||||
"""
|
||||
if not job_id:
|
||||
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
|
||||
|
||||
async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
|
||||
"""Simple wrapper that creates its own database session"""
|
||||
try:
|
||||
# Import database_manager locally to avoid circular imports
|
||||
from app.core.database import database_manager
|
||||
|
||||
logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")
|
||||
|
||||
# Create new session for background task
|
||||
async with database_manager.async_session_local() as session:
|
||||
await self.execute_training_job(session, job_id, tenant_id_str, request)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background training job {job_id} failed: {str(e)}")
|
||||
|
||||
# Try to update job status to failed
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
async with database_manager.async_session_local() as error_session:
|
||||
await self._update_job_status(
|
||||
error_session, job_id, "failed", 0,
|
||||
f"Training failed: {str(e)}", error_message=str(e)
|
||||
)
|
||||
await error_session.commit()
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update job status: {str(update_error)}")
|
||||
|
||||
raise
|
||||
|
||||
async def create_training_job(self,
|
||||
db: AsyncSession,
|
||||
tenant_id: str,
|
||||
job_id: str,
|
||||
config: Dict[str, Any]) -> ModelTrainingLog:
|
||||
"""Create a new training job record"""
|
||||
try:
|
||||
training_log = ModelTrainingLog(
|
||||
job_id=job_id,
|
||||
# Step 1: Prepare training dataset with date alignment and orchestration
|
||||
logger.info("Step 1: Preparing and aligning training data")
|
||||
training_dataset = await self.orchestrator.prepare_training_data(
|
||||
tenant_id=tenant_id,
|
||||
status="pending",
|
||||
progress=0,
|
||||
current_step="Initializing training job",
|
||||
start_time=datetime.now(),
|
||||
config=config
|
||||
)
|
||||
|
||||
db.add(training_log)
|
||||
await db.commit()
|
||||
await db.refresh(training_log)
|
||||
|
||||
logger.info(f"Created training job {job_id} for tenant {tenant_id}")
|
||||
return training_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create training job: {str(e)}")
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
async def create_single_product_job(self,
|
||||
db: AsyncSession,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
job_id: str,
|
||||
config: Dict[str, Any]) -> ModelTrainingLog:
|
||||
"""Create a training job for a single product"""
|
||||
try:
|
||||
config["single_product"] = product_name
|
||||
|
||||
training_log = ModelTrainingLog(
|
||||
job_id=job_id,
|
||||
tenant_id=tenant_id,
|
||||
status="pending",
|
||||
progress=0,
|
||||
current_step=f"Initializing training for {product_name}",
|
||||
start_time=datetime.now(),
|
||||
config=config
|
||||
)
|
||||
|
||||
db.add(training_log)
|
||||
await db.commit()
|
||||
await db.refresh(training_log)
|
||||
|
||||
logger.info(f"Created single product training job {job_id} for {product_name}")
|
||||
return training_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create single product training job: {str(e)}")
|
||||
await db.rollback()
|
||||
raise
|
||||
|
||||
async def execute_training_job(self,
|
||||
db: AsyncSession,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
request: TrainingJobRequest):
|
||||
"""Execute a complete training job"""
|
||||
try:
|
||||
logger.info(f"Starting execution of training job {job_id}")
|
||||
|
||||
# Update job status to running
|
||||
await self._update_job_status(db, job_id, "running", 5, "Fetching training data")
|
||||
|
||||
# Fetch sales data from data service
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id)
|
||||
|
||||
if not sales_data:
|
||||
raise ValueError("No sales data found for training")
|
||||
|
||||
# Determine date range from sales data
|
||||
start_date, end_date = await self._determine_sales_date_range(sales_data)
|
||||
|
||||
# Convert dates to ISO format strings for API calls
|
||||
start_date_str = start_date.isoformat()
|
||||
end_date_str = end_date.isoformat()
|
||||
|
||||
logger.info(f"Using date range for external APIs: {start_date_str} to {end_date_str}")
|
||||
|
||||
# Fetch external data if requested using the sales date range
|
||||
weather_data = []
|
||||
traffic_data = []
|
||||
|
||||
await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
|
||||
try:
|
||||
weather_data = await self.data_client.fetch_weather_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=40.4168, # Madrid coordinates
|
||||
longitude=-3.7038
|
||||
)
|
||||
logger.info(f"Fetched {len(weather_data)} weather records")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch weather data: {e}. Continuing without weather data.")
|
||||
weather_data = []
|
||||
|
||||
await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
|
||||
try:
|
||||
traffic_data = await self.data_client.fetch_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=40.4168,
|
||||
longitude=-3.7038
|
||||
)
|
||||
logger.info(f"Fetched {len(traffic_data)} traffic records")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch traffic data: {e}. Continuing without traffic data.")
|
||||
traffic_data = []
|
||||
|
||||
# Execute ML training
|
||||
await self._update_job_status(db, job_id, "running", 35, "Processing training data")
|
||||
|
||||
training_results = await self.ml_trainer.train_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
sales_data=sales_data,
|
||||
weather_data=weather_data,
|
||||
traffic_data=traffic_data,
|
||||
bakery_location=bakery_location,
|
||||
requested_start=requested_start,
|
||||
requested_end=requested_end,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
await self._update_job_status(db, job_id, "running", 85, "Storing trained models")
|
||||
|
||||
# Store trained models in database
|
||||
await self._store_trained_models(db, tenant_id, training_results)
|
||||
|
||||
await self._update_job_status(
|
||||
db, job_id, "completed", 100, "Training completed successfully",
|
||||
results=training_results
|
||||
# Step 2: Execute ML training pipeline
|
||||
logger.info("Step 2: Starting ML training pipeline")
|
||||
training_results = await self.trainer.train_tenant_models(
|
||||
tenant_id=tenant_id,
|
||||
training_dataset=training_dataset,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
# Publish completion event
|
||||
await publish_job_completed(job_id, tenant_id, training_results)
|
||||
# Step 3: Compile final results
|
||||
final_result = {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "completed",
|
||||
"training_results": training_results,
|
||||
"data_summary": {
|
||||
"sales_records": len(training_dataset.sales_data),
|
||||
"weather_records": len(training_dataset.weather_data),
|
||||
"traffic_records": len(training_dataset.traffic_data),
|
||||
"date_range": {
|
||||
"start": training_dataset.date_range.start.isoformat(),
|
||||
"end": training_dataset.date_range.end.isoformat()
|
||||
},
|
||||
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
|
||||
"constraints_applied": training_dataset.date_range.constraints
|
||||
},
|
||||
"completed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
logger.info(f"Training results {training_results}")
|
||||
logger.info(f"Training job {job_id} completed successfully")
|
||||
metrics.increment_counter("training_jobs_completed")
|
||||
return final_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training job {job_id} failed: {str(e)}")
|
||||
await self._update_job_status(
|
||||
db, job_id, "failed", 0, f"Training failed: {str(e)}",
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# Publish failure event
|
||||
await publish_job_failed(job_id, tenant_id, str(e))
|
||||
|
||||
metrics.increment_counter("training_jobs_failed")
|
||||
raise
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"status": "failed",
|
||||
"error_message": str(e),
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def execute_single_product_training(self,
|
||||
db: AsyncSession,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
request: SingleProductTrainingRequest):
|
||||
"""Execute training for a single product"""
|
||||
async def start_single_product_training(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
bakery_location: tuple[float, float] = (40.4168, -3.7038),
|
||||
weather_data: Optional[List[Dict[str, Any]]] = None,
|
||||
traffic_data: Optional[List[Dict[str, Any]]] = None,
|
||||
job_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Train a model for a single product.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier
|
||||
product_name: Product name
|
||||
sales_data: Historical sales data
|
||||
bakery_location: Bakery coordinates
|
||||
weather_data: Optional weather data
|
||||
traffic_data: Optional traffic data
|
||||
job_id: Optional job identifier
|
||||
|
||||
Returns:
|
||||
Single product training result
|
||||
"""
|
||||
if not job_id:
|
||||
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
logger.info(f"Starting single product training {job_id} for {product_name}")
|
||||
|
||||
try:
|
||||
logger.info(f"Starting single product training {job_id} for {product_name}")
|
||||
# Filter sales data for the specific product
|
||||
product_sales = [
|
||||
record for record in sales_data
|
||||
if record.get('product_name') == product_name
|
||||
]
|
||||
|
||||
# Update job status
|
||||
await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")
|
||||
if not product_sales:
|
||||
raise ValueError(f"No sales data found for product: {product_name}")
|
||||
|
||||
# Fetch data
|
||||
sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)
|
||||
weather_data = []
|
||||
traffic_data = []
|
||||
|
||||
if request.include_weather:
|
||||
await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
|
||||
weather_data = await self.data_client.fetch_weather_data(tenant_id, request)
|
||||
|
||||
if request.include_traffic:
|
||||
await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
|
||||
traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)
|
||||
|
||||
# Execute training
|
||||
await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")
|
||||
|
||||
training_result = await self.ml_trainer.train_single_product(
|
||||
# Use the same pipeline but for single product
|
||||
return await self.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
sales_data=sales_data,
|
||||
sales_data=product_sales,
|
||||
bakery_location=bakery_location,
|
||||
weather_data=weather_data,
|
||||
traffic_data=traffic_data,
|
||||
job_id=job_id
|
||||
)
|
||||
|
||||
# Store model
|
||||
await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
|
||||
await self._store_single_trained_model(db, tenant_id, product_name, training_result)
|
||||
|
||||
await self._update_job_status(
|
||||
db, job_id, "completed", 100, f"Training completed for {product_name}",
|
||||
results=training_result
|
||||
)
|
||||
|
||||
logger.info(f"Single product training {job_id} completed successfully")
|
||||
metrics.increment_counter("single_product_training_completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Single product training {job_id} failed: {str(e)}")
|
||||
await self._update_job_status(
|
||||
db, job_id, "failed", 0, f"Training failed: {str(e)}",
|
||||
error_message=str(e)
|
||||
)
|
||||
metrics.increment_counter("single_product_training_failed")
|
||||
raise
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name,
|
||||
"status": "failed",
|
||||
"error_message": str(e),
|
||||
"failed_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
    async def get_job_status(self,
                             db: AsyncSession,
                             job_id: str,
                             tenant_id: str) -> Optional[ModelTrainingLog]:
        """Get training job status"""
        try:
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )
            return result.scalar_one_or_none()

        except Exception as e:
            logger.error(f"Failed to get job status: {str(e)}")
            return None

    async def validate_training_data(
        self,
        tenant_id: str,
        sales_data: List[Dict[str, Any]],
        products: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Validate training data quality before starting training.

        Args:
            tenant_id: Tenant identifier
            sales_data: Sales data to validate
            products: Optional list of specific products to validate

        Returns:
            Validation results
        """
        try:
            if not sales_data:
                return {
                    "valid": False,
                    "errors": ["No sales data provided"],
                    "warnings": []
                }

            # Create a mock training dataset to validate
            mock_dataset = await self.orchestrator.prepare_training_data(
                tenant_id=tenant_id,
                sales_data=sales_data,
                bakery_location=(40.4168, -3.7038),  # Default Madrid
                job_id=f"validation_{uuid.uuid4().hex[:8]}"
            )

            # Validate the dataset
            validation_results = self.orchestrator.validate_training_data_quality(mock_dataset)

            # Add product-specific information
            unique_products = list(set(record.get('product_name', 'unknown') for record in sales_data))
            product_data_points = {}

            for record in sales_data:
                product = record.get('product_name', 'unknown')
                product_data_points[product] = product_data_points.get(product, 0) + 1

            # Extract sales date range for validation
            validation_results.update({
                "products_found": unique_products,
                "product_data_points": product_data_points,
                "total_records": len(sales_data),
                "date_range_info": {
                    "start": mock_dataset.date_range.start.isoformat(),
                    "end": mock_dataset.date_range.end.isoformat(),
                    "duration_days": (mock_dataset.date_range.end - mock_dataset.date_range.start).days
                }
            })

            return validation_results

        except Exception as e:
            logger.error(f"Training data validation failed: {str(e)}")
            return {
                "valid": False,
                "errors": [f"Validation failed: {str(e)}"],
                "warnings": []
            }

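    # Usage sketch (illustrative only, inside an async context): polling a job and
    # pre-validating data. `service` and the sample identifiers are assumptions.
    #
    #     job = await service.get_job_status(db, "job-abc", "tenant-123")
    #     if job is not None:
    #         print(job.status, job.progress, job.current_step)
    #
    #     report = await service.validate_training_data("tenant-123", sales_records)
    #     if not report.get("valid", True):
    #         print(report["errors"])
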
    async def list_training_jobs(self,
                                 db: AsyncSession,
                                 tenant_id: str,
                                 limit: int = 10,
                                 status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
        """List training jobs for a tenant"""
        try:
            query = select(ModelTrainingLog).where(
                ModelTrainingLog.tenant_id == tenant_id
            )

            if status_filter:
                query = query.where(ModelTrainingLog.status == status_filter)

            # Apply ordering and limit after all filters
            query = query.order_by(ModelTrainingLog.start_time.desc()).limit(limit)

            result = await db.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error(f"Failed to list training jobs: {str(e)}")
            return []

    async def cancel_training_job(self,
                                  db: AsyncSession,
                                  job_id: str,
                                  tenant_id: str) -> bool:
        """Cancel a training job"""
        try:
            result = await db.execute(
                update(ModelTrainingLog)
                .where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id,
                        ModelTrainingLog.status.in_(["pending", "running"])
                    )
                )
                .values(
                    status="cancelled",
                    end_time=datetime.now(),
                    current_step="Training cancelled by user"
                )
            )

            await db.commit()

            if result.rowcount > 0:
                logger.info(f"Cancelled training job {job_id}")
                return True
            else:
                logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
                return False

        except Exception as e:
            logger.error(f"Failed to cancel training job: {str(e)}")
            await db.rollback()
            return False

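    # Usage sketch (illustrative only): cancelling a pending or running job. Names
    # other than the method itself are assumptions.
    #
    #     cancelled = await service.cancel_training_job(db, "job-abc", "tenant-123")
    #     if not cancelled:
    #         # the job already finished, failed, or belongs to another tenant
    #         ...
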
    async def validate_training_data(self,
                                     db: AsyncSession,
                                     tenant_id: str,
                                     config: Dict[str, Any]) -> Dict[str, Any]:
        """Validate training data before starting a job"""
        try:
            logger.info(f"Validating training data for tenant {tenant_id}")

            issues = []
            recommendations = []

            # Fetch a sample of sales data to validate
            sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)

            if not sales_data:
                issues.append("No sales data found for tenant")
                return {
                    "is_valid": False,
                    "issues": issues,
                    "recommendations": ["Upload sales data before training"],
                    "estimated_time_minutes": 0
                }

            # Analyze data quality
            products = set(item.get("product_name") for item in sales_data)
            total_records = len(sales_data)

            # Check for sufficient data per product
            product_counts = {}
            for item in sales_data:
                product = item.get("product_name")
                if product:
                    product_counts[product] = product_counts.get(product, 0) + 1

            insufficient_products = [
                product for product, count in product_counts.items()
                if count < config.get("min_data_points", 30)
            ]

            if insufficient_products:
                issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
                recommendations.append("Collect more historical data for these products")

            # Estimate training time
            valid_products = len(products) - len(insufficient_products)
            estimated_time = max(5, valid_products * 2)  # 2 minutes per product minimum

            is_valid = len(issues) == 0

            return {
                "is_valid": is_valid,
                "issues": issues,
                "recommendations": recommendations,
                "estimated_time_minutes": estimated_time,
                "products_analyzed": len(products),
                "total_data_points": total_records
            }

        except Exception as e:
            logger.error(f"Failed to validate training data: {str(e)}")
            return {
                "is_valid": False,
                "issues": [f"Validation error: {str(e)}"],
                "recommendations": ["Check data service connectivity"],
                "estimated_time_minutes": 0
            }

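    # Usage sketch (illustrative only): gating a training run on the validation
    # report returned above. `service`, `db` and `config` are assumptions.
    #
    #     report = await service.validate_training_data(db, "tenant-123", config)
    #     if report["is_valid"]:
    #         print(f"Expected duration: ~{report['estimated_time_minutes']} min")
    #     else:
    #         print("Fix these issues first:", report["issues"])
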
    async def _update_job_status(self,
                                 db: AsyncSession,
                                 job_id: str,
                                 status: str,
                                 progress: int,
                                 current_step: str,
                                 results: Optional[Dict] = None,
                                 error_message: Optional[str] = None):
        """Update training job status"""
        try:
            update_values = {
                "status": status,
                "progress": progress,
                "current_step": current_step
            }

            if status == "completed":
                update_values["end_time"] = datetime.now()

            if results:
                update_values["results"] = results

            if error_message:
                update_values["error_message"] = error_message
                update_values["end_time"] = datetime.now()

            await db.execute(
                update(ModelTrainingLog)
                .where(ModelTrainingLog.job_id == job_id)
                .values(**update_values)
            )

            await db.commit()

        except Exception as e:
            logger.error(f"Failed to update job status: {str(e)}")
            await db.rollback()

    async def _store_trained_models(self,
                                    db: AsyncSession,
                                    tenant_id: str,
                                    training_results: Dict[str, Any]):
        """Store trained models in database"""
        try:
            models_to_store = []

            for product_name, result in training_results.get("training_results", {}).items():
                if result.get("status") == "success":
                    model_info = result.get("model_info", {})

                    trained_model = TrainedModel(
                        tenant_id=tenant_id,
                        product_name=product_name,
                        model_id=model_info.get("model_id"),
                        model_type=model_info.get("type", "prophet"),
                        model_path=model_info.get("model_path"),
                        version=1,  # Start with version 1
                        training_samples=model_info.get("training_samples", 0),
                        features=model_info.get("features", []),
                        hyperparameters=model_info.get("hyperparameters", {}),
                        training_metrics=model_info.get("training_metrics", {}),
                        data_period_start=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                        ),
                        data_period_end=datetime.fromisoformat(
                            model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                        ),
                        created_at=datetime.now(),
                        is_active=True
                    )

                    models_to_store.append(trained_model)

            # Deactivate old models for these products
            if models_to_store:
                product_names = [model.product_name for model in models_to_store]

                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name.in_(product_names),
                            TrainedModel.is_active == True
                        )
                    )
                    .values(is_active=False)
                )

                # Add new models
                db.add_all(models_to_store)
                await db.commit()

                logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")

        except Exception as e:
            logger.error(f"Failed to store trained models: {str(e)}")
            await db.rollback()
            raise

    async def get_training_recommendations(
        self,
        tenant_id: str,
        sales_data: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Get training recommendations based on data analysis.

        Args:
            tenant_id: Tenant identifier
            sales_data: Historical sales data

        Returns:
            Training recommendations
        """
        try:
            logger.info(f"Generating training recommendations for tenant {tenant_id}")

            # Analyze the data
            validation_results = await self.validate_training_data(tenant_id, sales_data)

            recommendations = {
                "should_retrain": True,
                "reasons": [],
                "recommended_products": [],
                "optimal_config": {
                    "include_weather": True,
                    "include_traffic": True,
                    "min_data_points": 30,
                    "hyperparameter_optimization": True
                }
            }

            # Analyze data quality and provide recommendations
            if validation_results.get("data_quality_score", 0) >= 80:
                recommendations["reasons"].append("High quality data detected")
            else:
                recommendations["reasons"].append("Data quality could be improved")

            # Recommend products with sufficient data
            product_data_points = validation_results.get("product_data_points", {})
            for product, points in product_data_points.items():
                if points >= 30:  # Minimum viable data points
                    recommendations["recommended_products"].append(product)

            if len(recommendations["recommended_products"]) == 0:
                recommendations["should_retrain"] = False
                recommendations["reasons"].append("Insufficient data for reliable training")

            return recommendations

        except Exception as e:
            logger.error(f"Failed to generate training recommendations: {str(e)}")
            return {
                "should_retrain": False,
                "reasons": [f"Error analyzing data: {str(e)}"],
                "recommended_products": [],
                "optimal_config": {}
            }

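    # Usage sketch (illustrative only): acting on the recommendations payload.
    # `service` and `sales_records` are assumptions.
    #
    #     recs = await service.get_training_recommendations("tenant-123", sales_records)
    #     if recs["should_retrain"]:
    #         products = recs["recommended_products"]
    #         # feed `products` and recs["optimal_config"] into the training request
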
    async def _store_single_trained_model(self,
                                          db: AsyncSession,
                                          tenant_id: str,
                                          product_name: str,
                                          training_result: Dict[str, Any]):
        """Store a single trained model"""
        try:
            if training_result.get("status") == "success":
                model_info = training_result.get("model_info", {})

                # Deactivate old model for this product
                await db.execute(
                    update(TrainedModel)
                    .where(
                        and_(
                            TrainedModel.tenant_id == tenant_id,
                            TrainedModel.product_name == product_name,
                            TrainedModel.is_active == True
                        )
                    )
                    .values(is_active=False)
                )

                # Create new model record
                trained_model = TrainedModel(
                    tenant_id=tenant_id,
                    product_name=product_name,
                    model_id=model_info.get("model_id"),
                    model_type=model_info.get("type", "prophet"),
                    model_path=model_info.get("model_path"),
                    version=1,
                    training_samples=model_info.get("training_samples", 0),
                    features=model_info.get("features", []),
                    hyperparameters=model_info.get("hyperparameters", {}),
                    training_metrics=model_info.get("training_metrics", {}),
                    data_period_start=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
                    ),
                    data_period_end=datetime.fromisoformat(
                        model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
                    ),
                    created_at=datetime.now(),
                    is_active=True
                )

                db.add(trained_model)
                await db.commit()

                logger.info(f"Stored trained model for {product_name}")

        except Exception as e:
            logger.error(f"Failed to store trained model: {str(e)}")
            await db.rollback()
            raise

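    # Usage sketch (illustrative only): reading back the currently active model for a
    # product after it has been stored. The identifiers are assumptions.
    #
    #     result = await db.execute(
    #         select(TrainedModel).where(
    #             and_(
    #                 TrainedModel.tenant_id == "tenant-123",
    #                 TrainedModel.product_name == "baguette",
    #                 TrainedModel.is_active == True
    #             )
    #         )
    #     )
    #     active_model = result.scalar_one_or_none()
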
    async def get_training_logs(self,
                                db: AsyncSession,
                                job_id: str,
                                tenant_id: str) -> Optional[List[str]]:
        """Get detailed training logs for a job"""
        try:
            # For now, return basic log information from the database
            # In a production system, you might store detailed logs separately
            result = await db.execute(
                select(ModelTrainingLog).where(
                    and_(
                        ModelTrainingLog.job_id == job_id,
                        ModelTrainingLog.tenant_id == tenant_id
                    )
                )
            )

            training_log = result.scalar_one_or_none()

            if training_log:
                logs = [
                    f"Job started at: {training_log.start_time}",
                    f"Current status: {training_log.status}",
                    f"Progress: {training_log.progress}%",
                    f"Current step: {training_log.current_step}"
                ]

                if training_log.end_time:
                    logs.append(f"Job completed at: {training_log.end_time}")

                if training_log.error_message:
                    logs.append(f"Error: {training_log.error_message}")

                if training_log.results:
                    results = training_log.results
                    logs.append(f"Models trained: {results.get('products_trained', 0)}")
                    logs.append(f"Models failed: {results.get('products_failed', 0)}")

                return logs

            return None

        except Exception as e:
            logger.error(f"Failed to get training logs: {str(e)}")
            return None

    async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
        """Determine start and end dates from sales data"""
        if not sales_data:
            raise ValueError("No sales data available to determine date range")

        dates = []
        for record in sales_data:
            if 'date' in record:
                if isinstance(record['date'], str):
                    dates.append(datetime.fromisoformat(record['date'].replace('Z', '+00:00')))
                elif isinstance(record['date'], datetime):
                    dates.append(record['date'])

        if not dates:
            raise ValueError("No valid dates found in sales data")

        start_date = min(dates)
        end_date = max(dates)

        logger.info(f"Determined sales date range: {start_date} to {end_date}")
        return start_date, end_date

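    # Usage sketch (illustrative only): deriving the range from raw records. The
    # sample records are assumptions.
    #
    #     records = [
    #         {"date": "2024-01-01T08:00:00Z", "product_name": "baguette", "quantity": 12},
    #         {"date": "2024-03-31T08:00:00Z", "product_name": "baguette", "quantity": 9},
    #     ]
    #     start, end = await service._determine_sales_date_range(records)
    #     # start/end can then seed the DateRange consumed by the date alignment service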