Improve training code 2

This commit is contained in:
Urtzi Alfaro
2025-07-28 20:20:54 +02:00
parent 98f546af12
commit 7cd595df81
6 changed files with 229 additions and 153 deletions

View File

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
from datetime import datetime, timedelta, timezone
logger = logging.getLogger(__name__)
@@ -84,15 +85,24 @@ class DateAlignmentService:
) -> DateRange:
"""Determine the base date range for training."""
# ✅ FIX: Ensure all datetimes are timezone-aware for comparison
def ensure_timezone_aware(dt: datetime) -> datetime:
    """Return *dt* as a timezone-aware datetime, assuming UTC when naive.

    A datetime counts as naive when it has no ``tzinfo`` at all, or when its
    ``tzinfo`` reports no UTC offset (``utcoffset()`` returns ``None``). Both
    forms raise ``TypeError`` when compared with aware datetimes, so both are
    stamped with UTC. Aware datetimes are returned unchanged (same object).

    Args:
        dt: The datetime to normalise.

    Returns:
        A timezone-aware datetime; wall-clock fields are never shifted.
    """
    # `replace` only attaches the zone; it does not convert the clock time.
    if dt.tzinfo is None or dt.utcoffset() is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt
# Use explicit dates if provided
if requested_start and requested_end:
requested_start = ensure_timezone_aware(requested_start)
requested_end = ensure_timezone_aware(requested_end)
if requested_end <= requested_start:
raise ValueError("End date must be after start date")
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
# Otherwise, use the user's sales data range as the foundation
start_date = requested_start or user_sales_range.start
end_date = requested_end or user_sales_range.end
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
# Ensure we don't exceed maximum training range
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
@@ -104,7 +114,7 @@ class DateAlignmentService:
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
"""Apply constraints from each data source and determine final aligned range."""
current_month = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
current_month = datetime.now(timezone.utc).replace(day=1, hour=0, minute=0, second=0, microsecond=0)
available_sources = [DataSourceType.BAKERY_SALES] # Always have sales data
constraints = {}
@@ -121,7 +131,7 @@ class DateAlignmentService:
# Weather Forecast Constraint
# Weather data available from yesterday backward
weather_end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
weather_end_date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
if base_range.end > weather_end_date:
if new_end > weather_end_date:
new_end = weather_end_date
@@ -150,7 +160,7 @@ class DateAlignmentService:
Get the latest available date for Madrid traffic data.
Data for current month is not available until the following month.
"""
now = datetime.now()
now = datetime.now(timezone.utc)
if now.day == 1:
# If it's the first day of the month, data up to previous month should be available
last_available_month = now.replace(day=1) - timedelta(days=1)
@@ -234,7 +244,7 @@ class DateAlignmentService:
Returns:
True if the constraint is violated (end date is in current month)
"""
now = datetime.now()
now = datetime.now(timezone.utc)
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return end_date >= current_month_start

View File

@@ -10,6 +10,8 @@ from dataclasses import dataclass
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timezone
import pandas as pd
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
@@ -35,8 +37,6 @@ class TrainingDataOrchestrator:
madrid_client=None,
weather_client=None,
date_alignment_service: DateAlignmentService = None):
self.madrid_client = madrid_client
self.weather_client = weather_client
self.data_client = DataServiceClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.max_concurrent_requests = 3
@@ -67,7 +67,7 @@ class TrainingDataOrchestrator:
try:
sales_data = self.data_client.fetch_sales_data(tenant_id)
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Step 1: Extract and validate sales data date range
sales_date_range = self._extract_sales_date_range(sales_data)
@@ -90,7 +90,7 @@ class TrainingDataOrchestrator:
# Step 4: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data = await self._collect_external_data(
aligned_range, bakery_location
aligned_range, bakery_location, tenant_id
)
# Step 5: Validate data quality
@@ -136,44 +136,33 @@ class TrainingDataOrchestrator:
raise ValueError(f"Failed to prepare training data: {str(e)}")
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
    """Extract the covered date range from sales data with timezone handling.

    Each record's ``'date'`` entry may be a string (parsed with
    ``pd.to_datetime``) or a ``datetime``. Naive values are assumed to be UTC
    so that every collected date is timezone-aware and mutually comparable.
    Records with a missing, empty, unparseable or unsupported date are
    skipped instead of aborting the whole extraction.

    Args:
        sales_data: Sales records, each expected to carry a ``'date'`` key.

    Returns:
        A ``DateRange`` spanning the earliest to the latest valid date,
        tagged with ``DataSourceType.BAKERY_SALES``.

    Raises:
        ValueError: If ``sales_data`` is empty or contains no valid date.
    """
    if not sales_data:
        raise ValueError("No sales data provided")

    dates = []
    for record in sales_data:
        date_value = record.get('date')
        if not date_value:
            continue

        if isinstance(date_value, str):
            try:
                parsed = pd.to_datetime(date_value)
            except (ValueError, TypeError, OverflowError) as e:
                # One malformed record must not abort the whole training run.
                logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
                continue
            if pd.isna(parsed):
                # Strings such as "NaT" parse to a missing value, not a date.
                continue
            parsed = parsed.to_pydatetime()
        elif isinstance(date_value, datetime):
            parsed = date_value
        else:
            # Unsupported type: ignore the record.
            continue

        # Naive datetimes are assumed to be UTC so min()/max() never mix
        # naive and aware values.
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        dates.append(parsed)

    if not dates:
        raise ValueError("No valid dates found in sales data")

    return DateRange(min(dates), max(dates), DataSourceType.BAKERY_SALES)
def _filter_sales_data(
self,
sales_data: List[Dict[str, Any]],
@@ -187,20 +176,43 @@ class TrainingDataOrchestrator:
try:
if 'date' in record:
record_date = record['date']
# ✅ FIX: Proper timezone handling for date parsing
if isinstance(record_date, str):
if 'T' in record_date:
record_date = record_date.replace('Z', '+00:00')
record_date = datetime.fromisoformat(record_date.split('T')[0])
# Parse with timezone info intact
parsed_date = datetime.fromisoformat(record_date.split('T')[0])
# Ensure timezone-aware
if parsed_date.tzinfo is None:
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
record_date = parsed_date
elif isinstance(record_date, datetime):
# Ensure timezone-aware
if record_date.tzinfo is None:
record_date = record_date.replace(tzinfo=timezone.utc)
# Normalize to start of day
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
# Check if date falls within aligned range
if aligned_range.start <= record_date <= aligned_range.end:
# ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
aligned_start = aligned_range.start
aligned_end = aligned_range.end
if aligned_start.tzinfo is None:
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
if aligned_end.tzinfo is None:
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
# Check if date falls within aligned range (now both are timezone-aware)
if aligned_start <= record_date <= aligned_end:
# Validate that record has required fields
if self._validate_sales_record(record):
filtered_data.append(record)
else:
filtered_count += 1
else:
# Record outside date range
filtered_count += 1
except Exception as e:
logger.warning(f"Error processing sales record: {str(e)}")
filtered_count += 1
@@ -243,7 +255,8 @@ class TrainingDataOrchestrator:
async def _collect_external_data(
self,
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float]
bakery_location: Tuple[float, float],
tenant_id: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Collect weather and traffic data concurrently with enhanced error handling"""
@@ -255,14 +268,14 @@ class TrainingDataOrchestrator:
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("weather", weather_task))
# Traffic data collection
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("traffic", traffic_task))
@@ -297,18 +310,21 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect weather data with timeout and fallback"""
try:
if not self.weather_client:
logger.info("Weather client not configured, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
weather_data = await asyncio.wait_for(
self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
)
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate weather data
if self._validate_weather_data(weather_data):
@@ -319,7 +335,7 @@ class TrainingDataOrchestrator:
return self._generate_synthetic_weather_data(aligned_range)
except asyncio.TimeoutError:
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
logger.warning(f"Weather data collection timed out, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except Exception as e:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
@@ -329,24 +345,27 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect traffic data with timeout and Madrid constraint validation"""
try:
if not self.madrid_client:
logger.info("Madrid client not configured, no traffic data available")
return []
# Double-check Madrid constraint before making request
if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
logger.warning("Madrid current month constraint violation, no traffic data available")
return []
traffic_data = await asyncio.wait_for(
self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate traffic data
if self._validate_traffic_data(traffic_data):
logger.info(f"Collected {len(traffic_data)} valid traffic records")
@@ -356,7 +375,7 @@ class TrainingDataOrchestrator:
return []
except asyncio.TimeoutError:
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
logger.warning(f"Traffic data collection timed out")
return []
except Exception as e:
logger.warning(f"Traffic data collection failed: {e}")

View File

@@ -81,36 +81,26 @@ class TrainingService:
)
# Step 3: Compile final results
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"data_summary": {
"sales_records": len(training_dataset.sales_data),
"weather_records": len(training_dataset.weather_data),
"traffic_records": len(training_dataset.traffic_data),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training job {job_id} completed successfully")
return final_result
return {
"job_id": job_id,
"status": "completed", # or "running" if async
"message": "Training job completed successfully",
"tenant_id": tenant_id,
"created_at": datetime.now(),
"estimated_duration_minutes": 5 # reasonable estimate
}
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
# Return error response that still matches schema
return {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
"message": f"Training job failed: {str(e)}",
"tenant_id": tenant_id,
"created_at": datetime.now(),
"estimated_duration_minutes": 0
}
async def start_single_product_training(