Improve training code 2

This commit is contained in:
Urtzi Alfaro
2025-07-28 20:20:54 +02:00
parent 98f546af12
commit 7cd595df81
6 changed files with 229 additions and 153 deletions

View File

@@ -10,6 +10,8 @@ from dataclasses import dataclass
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timezone
import pandas as pd
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
@@ -35,8 +37,6 @@ class TrainingDataOrchestrator:
madrid_client=None,
weather_client=None,
date_alignment_service: DateAlignmentService = None):
self.madrid_client = madrid_client
self.weather_client = weather_client
self.data_client = DataServiceClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.max_concurrent_requests = 3
@@ -67,7 +67,7 @@ class TrainingDataOrchestrator:
try:
sales_data = self.data_client.fetch_sales_data(tenant_id)
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Step 1: Extract and validate sales data date range
sales_date_range = self._extract_sales_date_range(sales_data)
@@ -90,7 +90,7 @@ class TrainingDataOrchestrator:
# Step 4: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data = await self._collect_external_data(
aligned_range, bakery_location
aligned_range, bakery_location, tenant_id
)
# Step 5: Validate data quality
@@ -136,44 +136,33 @@ class TrainingDataOrchestrator:
raise ValueError(f"Failed to prepare training data: {str(e)}")
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
"""Extract and validate the date range from sales data"""
"""Extract date range from sales data with timezone handling"""
if not sales_data:
raise ValueError("No sales data provided")
dates = []
valid_records = 0
for record in sales_data:
try:
if 'date' in record:
date_val = record['date']
if isinstance(date_val, str):
# Handle various date formats
if 'T' in date_val:
date_val = date_val.replace('Z', '+00:00')
parsed_date = datetime.fromisoformat(date_val.split('T')[0])
elif isinstance(date_val, datetime):
parsed_date = date_val
else:
continue
dates.append(parsed_date)
valid_records += 1
except (ValueError, TypeError) as e:
logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
continue
date_value = record.get('date')
if date_value:
# ✅ FIX: Ensure timezone-aware datetime
if isinstance(date_value, str):
dt = pd.to_datetime(date_value)
if dt.tz is None:
dt = dt.replace(tzinfo=timezone.utc)
dates.append(dt.to_pydatetime())
elif isinstance(date_value, datetime):
if date_value.tzinfo is None:
date_value = date_value.replace(tzinfo=timezone.utc)
dates.append(date_value)
if not dates:
raise ValueError("No valid dates found in sales data")
logger.info(f"Processed {valid_records} valid date records from {len(sales_data)} total records")
start_date = min(dates)
end_date = max(dates)
return DateRange(
start=min(dates),
end=max(dates),
source=DataSourceType.BAKERY_SALES
)
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
def _filter_sales_data(
self,
sales_data: List[Dict[str, Any]],
@@ -187,20 +176,43 @@ class TrainingDataOrchestrator:
try:
if 'date' in record:
record_date = record['date']
# ✅ FIX: Proper timezone handling for date parsing
if isinstance(record_date, str):
if 'T' in record_date:
record_date = record_date.replace('Z', '+00:00')
record_date = datetime.fromisoformat(record_date.split('T')[0])
# Parse with timezone info intact
parsed_date = datetime.fromisoformat(record_date.split('T')[0])
# Ensure timezone-aware
if parsed_date.tzinfo is None:
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
record_date = parsed_date
elif isinstance(record_date, datetime):
# Ensure timezone-aware
if record_date.tzinfo is None:
record_date = record_date.replace(tzinfo=timezone.utc)
# Normalize to start of day
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
# Check if date falls within aligned range
if aligned_range.start <= record_date <= aligned_range.end:
# ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
aligned_start = aligned_range.start
aligned_end = aligned_range.end
if aligned_start.tzinfo is None:
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
if aligned_end.tzinfo is None:
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
# Check if date falls within aligned range (now both are timezone-aware)
if aligned_start <= record_date <= aligned_end:
# Validate that record has required fields
if self._validate_sales_record(record):
filtered_data.append(record)
else:
filtered_count += 1
else:
# Record outside date range
filtered_count += 1
except Exception as e:
logger.warning(f"Error processing sales record: {str(e)}")
filtered_count += 1
@@ -243,7 +255,8 @@ class TrainingDataOrchestrator:
async def _collect_external_data(
self,
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float]
bakery_location: Tuple[float, float],
tenant_id: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Collect weather and traffic data concurrently with enhanced error handling"""
@@ -255,14 +268,14 @@ class TrainingDataOrchestrator:
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("weather", weather_task))
# Traffic data collection
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("traffic", traffic_task))
@@ -297,18 +310,21 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect weather data with timeout and fallback"""
try:
if not self.weather_client:
logger.info("Weather client not configured, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
weather_data = await asyncio.wait_for(
self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
)
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate weather data
if self._validate_weather_data(weather_data):
@@ -319,7 +335,7 @@ class TrainingDataOrchestrator:
return self._generate_synthetic_weather_data(aligned_range)
except asyncio.TimeoutError:
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
logger.warning(f"Weather data collection timed out, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except Exception as e:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
@@ -329,24 +345,27 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect traffic data with timeout and Madrid constraint validation"""
try:
if not self.madrid_client:
logger.info("Madrid client not configured, no traffic data available")
return []
# Double-check Madrid constraint before making request
if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
logger.warning("Madrid current month constraint violation, no traffic data available")
return []
traffic_data = await asyncio.wait_for(
self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate traffic data
if self._validate_traffic_data(traffic_data):
logger.info(f"Collected {len(traffic_data)} valid traffic records")
@@ -356,7 +375,7 @@ class TrainingDataOrchestrator:
return []
except asyncio.TimeoutError:
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
logger.warning(f"Traffic data collection timed out")
return []
except Exception as e:
logger.warning(f"Traffic data collection failed: {e}")