Fix issues
@@ -123,6 +123,86 @@ class DataClient:
            logger.error(f"Error fetching weather data: {e}", tenant_id=tenant_id)
            return []

    async def fetch_traffic_data_unified(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None,
        force_refresh: bool = False
    ) -> List[Dict[str, Any]]:
        """
        Unified traffic data fetching with intelligent cache-first strategy

        Strategy:
        1. Check if stored/cached traffic data exists for the date range
        2. If exists and not force_refresh, return cached data
        3. If not exists or force_refresh, fetch fresh data
        4. Always return data without duplicate fetching

        Args:
            tenant_id: Tenant identifier
            start_date: Start date string (ISO format)
            end_date: End date string (ISO format)
            latitude: Optional latitude for location-based data
            longitude: Optional longitude for location-based data
            force_refresh: If True, bypass cache and fetch fresh data
        """
        cache_key = f"{tenant_id}_{start_date}_{end_date}_{latitude}_{longitude}"

        try:
            # Step 1: Try to get stored/cached data first (unless force_refresh)
            if not force_refresh and self.supports_stored_traffic_data:
                logger.info("Attempting to fetch cached traffic data",
                            tenant_id=tenant_id, cache_key=cache_key)

                try:
                    cached_data = await self.external_client.get_stored_traffic_data_for_training(
                        tenant_id=tenant_id,
                        start_date=start_date,
                        end_date=end_date,
                        latitude=latitude,
                        longitude=longitude
                    )

                    if cached_data and len(cached_data) > 0:
                        logger.info(f"✅ Using cached traffic data: {len(cached_data)} records",
                                    tenant_id=tenant_id)
                        return cached_data
                    else:
                        logger.info("No cached traffic data found, fetching fresh data",
                                    tenant_id=tenant_id)
                except Exception as cache_error:
                    logger.warning(f"Cache fetch failed, falling back to fresh data: {cache_error}",
                                   tenant_id=tenant_id)

            # Step 2: Fetch fresh data if no cache or force_refresh
            logger.info("Fetching fresh traffic data" + (" (force refresh)" if force_refresh else ""),
                        tenant_id=tenant_id)

            fresh_data = await self.external_client.get_traffic_data(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                latitude=latitude,
                longitude=longitude
            )

            if fresh_data and len(fresh_data) > 0:
                logger.info(f"✅ Fetched fresh traffic data: {len(fresh_data)} records",
                            tenant_id=tenant_id)
                return fresh_data
            else:
                logger.warning("No fresh traffic data available", tenant_id=tenant_id)
                return []

        except Exception as e:
            logger.error(f"Error in unified traffic data fetch: {e}",
                         tenant_id=tenant_id, cache_key=cache_key)
            return []

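For context, a minimal usage sketch of the new cache-first entry point (not part of the commit): it assumes a DataClient() can be constructed with no arguments, as EnhancedTrainingService does further down, and uses placeholder tenant and date values.

import asyncio

from app.services.data_client import DataClient  # import path as used in EnhancedTrainingService below


async def demo_unified_fetch() -> None:
    client = DataClient()  # assumption: no constructor arguments needed

    # Cache-first: returns stored traffic data when available, otherwise fetches fresh data
    traffic = await client.fetch_traffic_data_unified(
        tenant_id="tenant-123",          # placeholder tenant
        start_date="2024-01-01",         # ISO-format date strings per the docstring
        end_date="2024-01-31",
        latitude=40.4168,                # optional location, Madrid as an example
        longitude=-3.7038,
        force_refresh=False,
    )

    # Force refresh: bypasses the cache and always calls the external client
    fresh = await client.fetch_traffic_data_unified(
        tenant_id="tenant-123",
        start_date="2024-01-01",
        end_date="2024-01-31",
        force_refresh=True,
    )
    print(f"cache-first: {len(traffic)} records, forced refresh: {len(fresh)} records")


if __name__ == "__main__":
    asyncio.run(demo_unified_fetch())
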
    # Legacy methods for backward compatibility - now delegate to unified method
    async def fetch_traffic_data(
        self,
        tenant_id: str,
@@ -131,29 +211,16 @@ class DataClient:
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch traffic data for training
        """
        try:
            traffic_data = await self.external_client.get_traffic_data(
                tenant_id=tenant_id,
                start_date=start_date,
                end_date=end_date,
                latitude=latitude,
                longitude=longitude
            )

            if traffic_data:
                logger.info(f"Fetched {len(traffic_data)} traffic records",
                            tenant_id=tenant_id)
                return traffic_data
            else:
                logger.warning("No traffic data returned", tenant_id=tenant_id)
                return []

        except Exception as e:
            logger.error(f"Error fetching traffic data: {e}", tenant_id=tenant_id)
            return []
        """Legacy method - delegates to unified fetcher with cache-first strategy"""
        logger.info("Legacy fetch_traffic_data called - delegating to unified method", tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=False  # Use cache-first for legacy calls
        )

    async def fetch_stored_traffic_data_for_training(
        self,
@@ -163,42 +230,35 @@ class DataClient:
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """
        Fetch stored traffic data specifically for training/re-training
        This method accesses previously stored traffic data without making new API calls
        """
        try:
            if self.supports_stored_traffic_data:
                # Use the dedicated stored traffic data method
                stored_traffic_data = await self.external_client.get_stored_traffic_data_for_training(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    latitude=latitude,
                    longitude=longitude
                )

                if stored_traffic_data:
                    logger.info(f"Retrieved {len(stored_traffic_data)} stored traffic records for training",
                                tenant_id=tenant_id)
                    return stored_traffic_data
                else:
                    logger.warning("No stored traffic data available for training", tenant_id=tenant_id)
                    return []
            else:
                # Fallback to regular traffic data method
                logger.info("Using fallback traffic data method for training")
                return await self.fetch_traffic_data(
                    tenant_id=tenant_id,
                    start_date=start_date,
                    end_date=end_date,
                    latitude=latitude,
                    longitude=longitude
                )

        except Exception as e:
            logger.error(f"Error fetching stored traffic data for training: {e}", tenant_id=tenant_id)
            return []
        """Legacy method - delegates to unified fetcher with cache-first strategy"""
        logger.info("Legacy fetch_stored_traffic_data_for_training called - delegating to unified method", tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=False  # Use cache-first for training calls
        )

    async def refresh_traffic_data(
        self,
        tenant_id: str,
        start_date: str,
        end_date: str,
        latitude: Optional[float] = None,
        longitude: Optional[float] = None
    ) -> List[Dict[str, Any]]:
        """Convenience method to force refresh traffic data"""
        logger.info("Force refreshing traffic data (bypassing cache)", tenant_id=tenant_id)
        return await self.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            latitude=latitude,
            longitude=longitude,
            force_refresh=True  # Force fresh data
        )

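The wrappers above keep the public surface backward compatible while routing every call through fetch_traffic_data_unified. A minimal sketch of how a retraining caller might combine them (placeholder arguments, not part of the commit):

# Sketch only: the legacy/training wrappers use the cache-first path (force_refresh=False),
# while refresh_traffic_data forces a fresh fetch (force_refresh=True).
async def load_traffic_for_retraining(client: "DataClient") -> list:
    stored = await client.fetch_stored_traffic_data_for_training(
        tenant_id="tenant-123", start_date="2024-01-01", end_date="2024-01-31"
    )
    if stored:
        return stored
    # Fall back to an explicit refresh when nothing is cached
    return await client.refresh_traffic_data(
        tenant_id="tenant-123", start_date="2024-01-01", end_date="2024-01-31"
    )
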
    async def validate_data_quality(
        self,

@@ -73,14 +73,47 @@ class TrainingDataOrchestrator:
        logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")

        try:
            # Step 1: Fetch and validate sales data (unified approach)
            sales_data = await self.data_client.fetch_sales_data(tenant_id, fetch_all=True)

            sales_data = await self.data_client.fetch_sales_data(tenant_id)
            # Pre-flight validation moved here to eliminate duplicate fetching
            if not sales_data or len(sales_data) == 0:
                error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
                logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
                raise ValueError(error_msg)

            # Step 1: Extract and validate sales data date range
            # Debug: Analyze the sales data structure to understand product distribution
            sales_df_debug = pd.DataFrame(sales_data)
            if 'inventory_product_id' in sales_df_debug.columns:
                unique_products_found = sales_df_debug['inventory_product_id'].unique()
                product_counts = sales_df_debug['inventory_product_id'].value_counts().to_dict()

                logger.info("Sales data analysis (moved from pre-flight)",
                            tenant_id=tenant_id,
                            job_id=job_id,
                            total_sales_records=len(sales_data),
                            unique_products_count=len(unique_products_found),
                            unique_products=unique_products_found.tolist(),
                            records_per_product=product_counts)

                if len(unique_products_found) == 1:
                    logger.warning("POTENTIAL ISSUE: Only ONE unique product found in all sales data",
                                   tenant_id=tenant_id,
                                   single_product=unique_products_found[0],
                                   record_count=len(sales_data))
            else:
                logger.warning("No 'inventory_product_id' column found in sales data",
                               tenant_id=tenant_id,
                               columns=list(sales_df_debug.columns))

            logger.info(f"Sales data validation passed: {len(sales_data)} sales records found",
                        tenant_id=tenant_id, job_id=job_id)

            # Step 2: Extract and validate sales data date range
            sales_date_range = self._extract_sales_date_range(sales_data)
            logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")

            # Step 2: Apply date alignment across all data sources
            # Step 3: Apply date alignment across all data sources
            aligned_range = self.date_alignment_service.validate_and_align_dates(
                user_sales_range=sales_date_range,
                requested_start=requested_start,

@@ -91,21 +124,21 @@ class TrainingDataOrchestrator:
            if aligned_range.constraints:
                logger.info(f"Applied constraints: {aligned_range.constraints}")

            # Step 3: Filter sales data to aligned date range
            # Step 4: Filter sales data to aligned date range
            filtered_sales = self._filter_sales_data(sales_data, aligned_range)

            # Step 4: Collect external data sources concurrently
            # Step 5: Collect external data sources concurrently
            logger.info("Collecting external data sources...")
            weather_data, traffic_data = await self._collect_external_data(
                aligned_range, bakery_location, tenant_id
            )

            # Step 5: Validate data quality
            # Step 6: Validate data quality
            data_quality_results = self._validate_data_sources(
                filtered_sales, weather_data, traffic_data, aligned_range
            )

            # Step 6: Create comprehensive training dataset
            # Step 7: Create comprehensive training dataset
            training_dataset = TrainingDataSet(
                sales_data=filtered_sales,
                weather_data=weather_data,

@@ -126,7 +159,7 @@ class TrainingDataOrchestrator:
                }
            )

            # Step 7: Final validation
            # Step 8: Final validation
            final_validation = self.validate_training_data_quality(training_dataset)
            training_dataset.metadata["final_validation"] = final_validation

@@ -375,14 +408,16 @@ class TrainingDataOrchestrator:
        start_date_str = aligned_range.start.isoformat()
        end_date_str = aligned_range.end.isoformat()

        # Enhanced: Fetch traffic data using new abstracted service
        # Enhanced: Fetch traffic data using unified cache-first method
        # This automatically detects the appropriate city and uses the right client
        traffic_data = await self.data_client.fetch_traffic_data(
        traffic_data = await self.data_client.fetch_traffic_data_unified(
            tenant_id=tenant_id,
            start_date=start_date_str,
            end_date=end_date_str,
            latitude=lat,
            longitude=lon)
            longitude=lon,
            force_refresh=False  # Use cache-first strategy
        )

        # Enhanced validation including pedestrian inference data
        if self._validate_traffic_data_enhanced(traffic_data):

@@ -461,54 +496,6 @@ class TrainingDataOrchestrator:
            minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
            self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)

    async def retrieve_stored_traffic_for_retraining(
        self,
        bakery_location: Tuple[float, float],
        start_date: datetime,
        end_date: datetime,
        tenant_id: str
    ) -> List[Dict[str, Any]]:
        """
        Retrieve previously stored traffic data for model re-training
        This method specifically accesses the stored traffic data without making new API calls
        """
        lat, lon = bakery_location

        try:
            # Use the dedicated stored traffic data method for training
            stored_traffic_data = await self.data_client.fetch_stored_traffic_data_for_training(
                tenant_id=tenant_id,
                start_date=start_date.isoformat(),
                end_date=end_date.isoformat(),
                latitude=lat,
                longitude=lon
            )

            if stored_traffic_data:
                logger.info(
                    f"Retrieved {len(stored_traffic_data)} stored traffic records for re-training",
                    location=f"{lat:.4f},{lon:.4f}",
                    date_range=f"{start_date.isoformat()} to {end_date.isoformat()}",
                    tenant_id=tenant_id
                )

                return stored_traffic_data
            else:
                logger.warning(
                    "No stored traffic data found for re-training",
                    location=f"{lat:.4f},{lon:.4f}",
                    date_range=f"{start_date.isoformat()} to {end_date.isoformat()}"
                )
                return []

        except Exception as e:
            logger.error(
                f"Failed to retrieve stored traffic data for re-training: {e}",
                location=f"{lat:.4f},{lon:.4f}",
                tenant_id=tenant_id
            )
            return []

    def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
        """Validate weather data quality"""
        if not weather_data:

@@ -137,42 +137,7 @@ class EnhancedTrainingService:
        await self._init_repositories(session)

        try:
            # Pre-flight check: Verify sales data exists before starting training
            from app.services.data_client import DataClient
            data_client = DataClient()
            sales_data = await data_client.fetch_sales_data(tenant_id, fetch_all=True)

            if not sales_data or len(sales_data) == 0:
                error_msg = f"No sales data available for tenant {tenant_id}. Please import sales data before starting training."
                logger.error("Training aborted - no sales data", tenant_id=tenant_id, job_id=job_id)
                raise ValueError(error_msg)

            # Debug: Analyze the sales data structure to understand product distribution
            sales_df_debug = pd.DataFrame(sales_data)
            if 'inventory_product_id' in sales_df_debug.columns:
                unique_products_found = sales_df_debug['inventory_product_id'].unique()
                product_counts = sales_df_debug['inventory_product_id'].value_counts().to_dict()

                logger.info("Pre-flight sales data analysis",
                            tenant_id=tenant_id,
                            job_id=job_id,
                            total_sales_records=len(sales_data),
                            unique_products_count=len(unique_products_found),
                            unique_products=unique_products_found.tolist(),
                            records_per_product=product_counts)

                if len(unique_products_found) == 1:
                    logger.warning("POTENTIAL ISSUE: Only ONE unique product found in all sales data",
                                   tenant_id=tenant_id,
                                   single_product=unique_products_found[0],
                                   record_count=len(sales_data))
            else:
                logger.warning("No 'inventory_product_id' column found in sales data",
                               tenant_id=tenant_id,
                               columns=list(sales_df_debug.columns))

            logger.info(f"Pre-flight check passed: {len(sales_data)} sales records found",
                        tenant_id=tenant_id, job_id=job_id)
            # Pre-flight check moved to orchestrator to eliminate duplicate sales data fetching

            # Check if training log already exists, create if not
            existing_log = await self.training_log_repo.get_log_by_job_id(job_id)

@@ -202,12 +167,13 @@ class EnhancedTrainingService:
                step_details="Data"
            )

            # Step 1: Prepare training dataset
            logger.info("Step 1: Preparing and aligning training data")
            # Step 1: Prepare training dataset (includes sales data validation)
            logger.info("Step 1: Preparing and aligning training data (with validation)")
            await self.training_log_repo.update_log_progress(
                job_id, 10, "data_validation", "running"
            )

            # Orchestrator now handles sales data validation to eliminate duplicate fetching
            training_dataset = await self.orchestrator.prepare_training_data(
                tenant_id=tenant_id,
                bakery_location=bakery_location,

@@ -216,6 +182,10 @@ class EnhancedTrainingService:
                job_id=job_id
            )

            # Log the results from orchestrator's unified sales data fetch
            logger.info(f"Sales data validation completed: {len(training_dataset.sales_data)} records",
                        tenant_id=tenant_id, job_id=job_id)

            await self.training_log_repo.update_log_progress(
                job_id, 30, "data_preparation_complete", "running"
            )

@@ -285,6 +255,27 @@ class EnhancedTrainingService:
            # Make sure all data is JSON-serializable before saving to database
            json_safe_result = make_json_serializable(final_result)

            # Ensure results is a proper dict for database storage
            if not isinstance(json_safe_result, dict):
                logger.warning("JSON safe result is not a dict, wrapping it", result_type=type(json_safe_result))
                json_safe_result = {"training_data": json_safe_result}

            # Double-check JSON serialization by attempting to serialize
            import json
            try:
                json.dumps(json_safe_result)
                logger.debug("Results successfully JSON-serializable", job_id=job_id)
            except (TypeError, ValueError) as e:
                logger.error("Results still not JSON-serializable after cleaning",
                             job_id=job_id, error=str(e))
                # Create a minimal safe result
                json_safe_result = {
                    "status": "completed",
                    "job_id": job_id,
                    "models_created": final_result.get("products_trained", 0),
                    "error": "Result serialization failed"
                }

            await self.training_log_repo.complete_training_log(
                job_id, results=json_safe_result
            )

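make_json_serializable is imported from elsewhere in the service and its implementation is not part of this diff; as a rough, hypothetical illustration of the kind of normalisation such a helper performs before results reach the database, it might look something like this:

# Hypothetical sketch only - the project's real make_json_serializable is not shown in this diff.
from datetime import date, datetime
from typing import Any


def make_json_serializable_sketch(value: Any) -> Any:
    """Recursively convert common non-JSON types into JSON-friendly equivalents."""
    if isinstance(value, dict):
        return {str(k): make_json_serializable_sketch(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [make_json_serializable_sketch(v) for v in value]
    if isinstance(value, (datetime, date)):
        return value.isoformat()
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    # Fall back to a string representation for anything else (e.g. numpy scalars, Decimals)
    return str(value)
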
@@ -313,6 +304,9 @@ class EnhancedTrainingService:
                "completed_at": datetime.now().isoformat()
            }

            # Ensure error result is JSON serializable
            error_result = make_json_serializable(error_result)

            return self._create_detailed_training_response(error_result)

    async def _store_trained_models(
