Improve training code 2
This commit is contained in:
@@ -10,6 +10,8 @@ from dataclasses import dataclass
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timezone
|
||||
import pandas as pd
|
||||
|
||||
from app.services.data_client import DataServiceClient
|
||||
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
|
||||
@@ -35,8 +37,6 @@ class TrainingDataOrchestrator:
|
||||
madrid_client=None,
|
||||
weather_client=None,
|
||||
date_alignment_service: DateAlignmentService = None):
|
||||
self.madrid_client = madrid_client
|
||||
self.weather_client = weather_client
|
||||
self.data_client = DataServiceClient()
|
||||
self.date_alignment_service = date_alignment_service or DateAlignmentService()
|
||||
self.max_concurrent_requests = 3
|
||||
@@ -67,7 +67,7 @@ class TrainingDataOrchestrator:
|
||||
|
||||
try:
|
||||
|
||||
sales_data = self.data_client.fetch_sales_data(tenant_id)
|
||||
sales_data = await self.data_client.fetch_sales_data(tenant_id)
|
||||
|
||||
# Step 1: Extract and validate sales data date range
|
||||
sales_date_range = self._extract_sales_date_range(sales_data)
|
||||
@@ -90,7 +90,7 @@ class TrainingDataOrchestrator:
|
||||
# Step 4: Collect external data sources concurrently
|
||||
logger.info("Collecting external data sources...")
|
||||
weather_data, traffic_data = await self._collect_external_data(
|
||||
aligned_range, bakery_location
|
||||
aligned_range, bakery_location, tenant_id
|
||||
)
|
||||
|
||||
# Step 5: Validate data quality
|
||||
@@ -136,44 +136,33 @@ class TrainingDataOrchestrator:
|
||||
raise ValueError(f"Failed to prepare training data: {str(e)}")
|
||||
|
||||
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
|
||||
"""Extract and validate the date range from sales data"""
|
||||
"""Extract date range from sales data with timezone handling"""
|
||||
if not sales_data:
|
||||
raise ValueError("No sales data provided")
|
||||
|
||||
dates = []
|
||||
valid_records = 0
|
||||
|
||||
for record in sales_data:
|
||||
try:
|
||||
if 'date' in record:
|
||||
date_val = record['date']
|
||||
if isinstance(date_val, str):
|
||||
# Handle various date formats
|
||||
if 'T' in date_val:
|
||||
date_val = date_val.replace('Z', '+00:00')
|
||||
parsed_date = datetime.fromisoformat(date_val.split('T')[0])
|
||||
elif isinstance(date_val, datetime):
|
||||
parsed_date = date_val
|
||||
else:
|
||||
continue
|
||||
|
||||
dates.append(parsed_date)
|
||||
valid_records += 1
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
|
||||
continue
|
||||
date_value = record.get('date')
|
||||
if date_value:
|
||||
# ✅ FIX: Ensure timezone-aware datetime
|
||||
if isinstance(date_value, str):
|
||||
dt = pd.to_datetime(date_value)
|
||||
if dt.tz is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
dates.append(dt.to_pydatetime())
|
||||
elif isinstance(date_value, datetime):
|
||||
if date_value.tzinfo is None:
|
||||
date_value = date_value.replace(tzinfo=timezone.utc)
|
||||
dates.append(date_value)
|
||||
|
||||
if not dates:
|
||||
raise ValueError("No valid dates found in sales data")
|
||||
|
||||
logger.info(f"Processed {valid_records} valid date records from {len(sales_data)} total records")
|
||||
start_date = min(dates)
|
||||
end_date = max(dates)
|
||||
|
||||
return DateRange(
|
||||
start=min(dates),
|
||||
end=max(dates),
|
||||
source=DataSourceType.BAKERY_SALES
|
||||
)
|
||||
|
||||
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
|
||||
|
||||
def _filter_sales_data(
|
||||
self,
|
||||
sales_data: List[Dict[str, Any]],
|
||||
@@ -187,20 +176,43 @@ class TrainingDataOrchestrator:
|
||||
try:
|
||||
if 'date' in record:
|
||||
record_date = record['date']
|
||||
|
||||
# ✅ FIX: Proper timezone handling for date parsing
|
||||
if isinstance(record_date, str):
|
||||
if 'T' in record_date:
|
||||
record_date = record_date.replace('Z', '+00:00')
|
||||
record_date = datetime.fromisoformat(record_date.split('T')[0])
|
||||
# Parse with timezone info intact
|
||||
parsed_date = datetime.fromisoformat(record_date.split('T')[0])
|
||||
# Ensure timezone-aware
|
||||
if parsed_date.tzinfo is None:
|
||||
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
|
||||
record_date = parsed_date
|
||||
elif isinstance(record_date, datetime):
|
||||
# Ensure timezone-aware
|
||||
if record_date.tzinfo is None:
|
||||
record_date = record_date.replace(tzinfo=timezone.utc)
|
||||
# Normalize to start of day
|
||||
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Check if date falls within aligned range
|
||||
if aligned_range.start <= record_date <= aligned_range.end:
|
||||
# ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
|
||||
aligned_start = aligned_range.start
|
||||
aligned_end = aligned_range.end
|
||||
|
||||
if aligned_start.tzinfo is None:
|
||||
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
|
||||
if aligned_end.tzinfo is None:
|
||||
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Check if date falls within aligned range (now both are timezone-aware)
|
||||
if aligned_start <= record_date <= aligned_end:
|
||||
# Validate that record has required fields
|
||||
if self._validate_sales_record(record):
|
||||
filtered_data.append(record)
|
||||
else:
|
||||
filtered_count += 1
|
||||
else:
|
||||
# Record outside date range
|
||||
filtered_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing sales record: {str(e)}")
|
||||
filtered_count += 1
|
||||
@@ -243,7 +255,8 @@ class TrainingDataOrchestrator:
|
||||
async def _collect_external_data(
|
||||
self,
|
||||
aligned_range: AlignedDateRange,
|
||||
bakery_location: Tuple[float, float]
|
||||
bakery_location: Tuple[float, float],
|
||||
tenant_id: str
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""Collect weather and traffic data concurrently with enhanced error handling"""
|
||||
|
||||
@@ -255,14 +268,14 @@ class TrainingDataOrchestrator:
|
||||
# Weather data collection
|
||||
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
|
||||
weather_task = asyncio.create_task(
|
||||
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
|
||||
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
|
||||
)
|
||||
tasks.append(("weather", weather_task))
|
||||
|
||||
# Traffic data collection
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
traffic_task = asyncio.create_task(
|
||||
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
|
||||
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
|
||||
)
|
||||
tasks.append(("traffic", traffic_task))
|
||||
|
||||
@@ -297,18 +310,21 @@ class TrainingDataOrchestrator:
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect weather data with timeout and fallback"""
|
||||
try:
|
||||
|
||||
if not self.weather_client:
|
||||
logger.info("Weather client not configured, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
start_date_str = aligned_range.start.isoformat()
|
||||
end_date_str = aligned_range.end.isoformat()
|
||||
|
||||
weather_data = await asyncio.wait_for(
|
||||
self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
|
||||
)
|
||||
weather_data = await self.data_client.fetch_weather_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=lat,
|
||||
longitude=lon)
|
||||
|
||||
# Validate weather data
|
||||
if self._validate_weather_data(weather_data):
|
||||
@@ -319,7 +335,7 @@ class TrainingDataOrchestrator:
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
|
||||
logger.warning(f"Weather data collection timed out, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
except Exception as e:
|
||||
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
|
||||
@@ -329,24 +345,27 @@ class TrainingDataOrchestrator:
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect traffic data with timeout and Madrid constraint validation"""
|
||||
try:
|
||||
|
||||
if not self.madrid_client:
|
||||
logger.info("Madrid client not configured, no traffic data available")
|
||||
return []
|
||||
|
||||
|
||||
# Double-check Madrid constraint before making request
|
||||
if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
|
||||
logger.warning("Madrid current month constraint violation, no traffic data available")
|
||||
return []
|
||||
|
||||
traffic_data = await asyncio.wait_for(
|
||||
self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
|
||||
)
|
||||
|
||||
start_date_str = aligned_range.start.isoformat()
|
||||
end_date_str = aligned_range.end.isoformat()
|
||||
|
||||
traffic_data = await self.data_client.fetch_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
end_date=end_date_str,
|
||||
latitude=lat,
|
||||
longitude=lon)
|
||||
|
||||
# Validate traffic data
|
||||
if self._validate_traffic_data(traffic_data):
|
||||
logger.info(f"Collected {len(traffic_data)} valid traffic records")
|
||||
@@ -356,7 +375,7 @@ class TrainingDataOrchestrator:
|
||||
return []
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
|
||||
logger.warning(f"Traffic data collection timed out")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Traffic data collection failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user