Improve training code 2

Urtzi Alfaro
2025-07-28 20:20:54 +02:00
parent 98f546af12
commit 7cd595df81
6 changed files with 229 additions and 153 deletions

View File

@@ -59,10 +59,10 @@ async def start_training_job(
# Delegate to training service (Step 1 of the flow)
result = await training_service.start_training_job(
tenant_id=tenant_id,
bakery_location=request.bakery_location or (40.4168, -3.7038), # Default Madrid
bakery_location=(40.4168, -3.7038), # Default Madrid coordinates
requested_start=request.start_date if request.start_date else None,
requested_end=request.end_date if request.end_date else None,
job_id=request.job_id
job_id=None # Let the service generate it
)
return TrainingJobResponse(**result)
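The endpoint now hardcodes the Madrid coordinates and passes job_id=None so that the service owns ID generation. A minimal sketch of what the service side might do with that None, assuming a helper like resolve_job_id (hypothetical, not part of this diff):

    import uuid
    from typing import Optional

    def resolve_job_id(job_id: Optional[str], tenant_id: str) -> str:
        # Hypothetical helper: keep a caller-supplied ID, otherwise mint a fresh one.
        return job_id or f"training-{tenant_id}-{uuid.uuid4().hex[:8]}"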

View File

@@ -7,7 +7,7 @@ Handles data preparation, date alignment, cleaning, and feature engineering for
import pandas as pd
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
import logging
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
@@ -280,14 +280,21 @@ class BakeryDataProcessor:
def _merge_weather_features(self,
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with enhanced handling"""
"""Merge weather features with enhanced Madrid-specific handling"""
# ✅ FIX: Define weather_defaults OUTSIDE try block to fix scope error
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'pressure': 1013.0
}
if weather_data.empty:
# Add default weather columns with Madrid-appropriate values
daily_sales['temperature'] = 15.0 # Average Madrid temperature
daily_sales['precipitation'] = 0.0 # Default no rain
daily_sales['humidity'] = 60.0 # Moderate humidity
daily_sales['wind_speed'] = 5.0 # Light wind
# Add default weather columns
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
try:
@@ -297,14 +304,22 @@ class BakeryDataProcessor:
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
weather_clean = weather_clean.rename(columns={'ds': 'date'})
# ✅ FIX: Ensure timezone consistency
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
daily_sales['date'] = pd.to_datetime(daily_sales['date'])
# Remove timezone info from both to make them compatible
if weather_clean['date'].dt.tz is not None:
weather_clean['date'] = weather_clean['date'].dt.tz_localize(None)
if daily_sales['date'].dt.tz is not None:
daily_sales['date'] = daily_sales['date'].dt.tz_localize(None)
# Map weather columns to standard names
weather_mapping = {
'temperature': ['temperature', 'temp', 'temperatura', 'temp_avg', 'temperature_avg'],
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion', 'rainfall'],
'temperature': ['temperature', 'temp', 'temperatura'],
'precipitation': ['precipitation', 'precip', 'rain', 'lluvia'],
'humidity': ['humidity', 'humedad', 'relative_humidity'],
'wind_speed': ['wind_speed', 'viento', 'wind', 'wind_avg'],
'wind_speed': ['wind_speed', 'viento', 'wind'],
'pressure': ['pressure', 'presion', 'atmospheric_pressure']
}
@@ -324,14 +339,6 @@ class BakeryDataProcessor:
merged = daily_sales.merge(weather_clean, on='date', how='left')
# Fill missing weather values with Madrid-appropriate defaults
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'pressure': 1013.0
}
for feature, default_value in weather_defaults.items():
if feature in merged.columns:
merged[feature] = merged[feature].fillna(default_value)
@@ -340,11 +347,12 @@ class BakeryDataProcessor:
except Exception as e:
logger.warning(f"Error merging weather data: {e}")
# Add default weather columns if merge fails
# Add default weather columns if merge fails (weather_defaults now in scope)
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
def _merge_traffic_features(self,
daily_sales: pd.DataFrame,
traffic_data: pd.DataFrame) -> pd.DataFrame:
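The scope fix above amounts to hoisting the fallback values out of the try block so that the empty-data early return, the merge path, and the except handler all see the same dict. A stripped-down sketch of that control flow, with stand-alone names rather than the class's own:

    import pandas as pd

    DEFAULTS = {"temperature": 15.0, "precipitation": 0.0, "humidity": 60.0}

    def merge_with_fallback(sales: pd.DataFrame, weather: pd.DataFrame) -> pd.DataFrame:
        if weather.empty:
            for col, val in DEFAULTS.items():  # early-return path
                sales[col] = val
            return sales
        try:
            return sales.merge(weather, on="date", how="left")
        except Exception:
            for col, val in DEFAULTS.items():  # except path: DEFAULTS is still in scope
                sales[col] = val
            return sales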
@@ -430,7 +438,7 @@ class BakeryDataProcessor:
bins=[-0.1, 0, 2, 10, np.inf],
labels=[0, 1, 2, 3]).astype(int)
# Traffic-based features
# ✅ FIX: Traffic-based features with NaN protection
if 'traffic_volume' in df.columns:
# Calculate traffic quantiles for relative measures
q75 = df['traffic_volume'].quantile(0.75)
@@ -438,7 +446,21 @@ class BakeryDataProcessor:
df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)
df['traffic_normalized'] = (df['traffic_volume'] - df['traffic_volume'].mean()) / df['traffic_volume'].std()
# ✅ FIX: Safe normalization with NaN protection
traffic_std = df['traffic_volume'].std()
traffic_mean = df['traffic_volume'].mean()
if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
# Normal case: valid standard deviation
df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
else:
# Edge case: all values are the same or contain NaN
logger.warning("Traffic volume has zero standard deviation or contains NaN, using zeros for normalized values")
df['traffic_normalized'] = 0.0
# ✅ ADDITIONAL SAFETY: Fill any remaining NaN values
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)
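The guard matters because a constant traffic series has zero standard deviation, so the naive z-score divides 0.0 by 0.0 and yields NaN for every row. A two-line reproduction:

    import pandas as pd

    s = pd.Series([100.0, 100.0, 100.0])        # zero variance
    print(((s - s.mean()) / s.std()).tolist())  # [nan, nan, nan]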
# Interaction features - bakery specific
if 'is_weekend' in df.columns and 'temperature' in df.columns:
@@ -465,30 +487,20 @@ class BakeryDataProcessor:
# Month-specific features for bakery seasonality
if 'month' in df.columns:
# Tourist season in Madrid (spring/summer)
df['is_tourist_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# High-demand months (holidays, summer)
df['is_high_demand_month'] = df['month'].isin([6, 7, 8, 12]).astype(int)
# Christmas season (affects bakery sales significantly)
df['is_christmas_season'] = df['month'].isin([11, 12]).astype(int)
# Spring/summer months
df['is_warm_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# Back-to-school/work season
df['is_back_to_work_season'] = df['month'].isin([1, 9]).astype(int)
# Lagged features (if we have enough data)
if len(df) > 7 and 'quantity' in df.columns:
# Rolling averages for trend detection
df['sales_7day_avg'] = df['quantity'].rolling(window=7, min_periods=3).mean()
df['sales_14day_avg'] = df['quantity'].rolling(window=14, min_periods=7).mean()
# Day-over-day changes
df['sales_change_1day'] = df['quantity'].diff()
df['sales_change_7day'] = df['quantity'].diff(7) # Week-over-week
# Fill NaN values for lagged features
df['sales_7day_avg'] = df['sales_7day_avg'].fillna(df['quantity'])
df['sales_14day_avg'] = df['sales_14day_avg'].fillna(df['quantity'])
df['sales_change_1day'] = df['sales_change_1day'].fillna(0)
df['sales_change_7day'] = df['sales_change_7day'].fillna(0)
# ✅ FINAL SAFETY CHECK: Remove any remaining NaN values
# Check for NaN values in all numeric columns and fill them
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if df[col].isna().any():
nan_count = df[col].isna().sum()
logger.warning(f"Found {nan_count} NaN values in column '{col}', filling with 0")
df[col] = df[col].fillna(0.0)
return df
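For the lagged features, min_periods controls how many leading rows the rolling mean leaves as NaN before the fillna step: with window=7 and min_periods=3, only the first two rows are undefined. A quick illustration:

    import pandas as pd

    qty = pd.Series(range(10), dtype=float)
    avg = qty.rolling(window=7, min_periods=3).mean()
    print(int(avg.isna().sum()))   # 2, since rows 0 and 1 have fewer than 3 observations
    print(avg.fillna(qty).head())  # the early NaNs fall back to the raw quantity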

View File

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
from datetime import datetime, timedelta, timezone
logger = logging.getLogger(__name__)
@@ -84,15 +85,24 @@ class DateAlignmentService:
) -> DateRange:
"""Determine the base date range for training."""
# ✅ FIX: Ensure all datetimes are timezone-aware for comparison
def ensure_timezone_aware(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
# Use explicit dates if provided
if requested_start and requested_end:
requested_start = ensure_timezone_aware(requested_start)
requested_end = ensure_timezone_aware(requested_end)
if requested_end <= requested_start:
raise ValueError("End date must be after start date")
return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)
# Otherwise, use the user's sales data range as the foundation
start_date = requested_start or user_sales_range.start
end_date = requested_end or user_sales_range.end
start_date = ensure_timezone_aware(requested_start or user_sales_range.start)
end_date = ensure_timezone_aware(requested_end or user_sales_range.end)
# Ensure we don't exceed maximum training range
if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
@@ -104,7 +114,7 @@ class DateAlignmentService:
def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
"""Apply constraints from each data source and determine final aligned range."""
current_month = datetime.now().replace(day=1, hour=0, minute=0, second=0, microsecond=0)
current_month = datetime.now(timezone.utc).replace(day=1, hour=0, minute=0, second=0, microsecond=0)
available_sources = [DataSourceType.BAKERY_SALES] # Always have sales data
constraints = {}
@@ -121,7 +131,7 @@ class DateAlignmentService:
# Weather Forecast Constraint
# Weather data available from yesterday backward
weather_end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
weather_end_date = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
if base_range.end > weather_end_date:
if new_end > weather_end_date:
new_end = weather_end_date
@@ -150,7 +160,7 @@ class DateAlignmentService:
Get the latest available date for Madrid traffic data.
Data for current month is not available until the following month.
"""
now = datetime.now()
now = datetime.now(timezone.utc)
if now.day == 1:
# If it's the first day of the month, data up to previous month should be available
last_available_month = now.replace(day=1) - timedelta(days=1)
@@ -234,7 +244,7 @@ class DateAlignmentService:
Returns:
True if the constraint is violated (end date is in current month)
"""
now = datetime.now()
now = datetime.now(timezone.utc)
current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return end_date >= current_month_start
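The motivation for ensure_timezone_aware is that Python refuses to compare naive and aware datetimes at all. A minimal reproduction of the failure alongside the fix used in this commit:

    from datetime import datetime, timezone

    def ensure_timezone_aware(dt: datetime) -> datetime:
        return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

    naive = datetime(2025, 7, 1)
    aware = datetime.now(timezone.utc)
    # naive < aware would raise TypeError: can't compare offset-naive and offset-aware datetimes
    print(ensure_timezone_aware(naive) < aware)  # True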

View File

@@ -10,6 +10,8 @@ from dataclasses import dataclass
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timezone
import pandas as pd
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
@@ -35,8 +37,6 @@ class TrainingDataOrchestrator:
madrid_client=None,
weather_client=None,
date_alignment_service: DateAlignmentService = None):
self.madrid_client = madrid_client
self.weather_client = weather_client
self.data_client = DataServiceClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.max_concurrent_requests = 3
@@ -67,7 +67,7 @@ class TrainingDataOrchestrator:
try:
sales_data = self.data_client.fetch_sales_data(tenant_id)
sales_data = await self.data_client.fetch_sales_data(tenant_id)
# Step 1: Extract and validate sales data date range
sales_date_range = self._extract_sales_date_range(sales_data)
@@ -90,7 +90,7 @@ class TrainingDataOrchestrator:
# Step 4: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data = await self._collect_external_data(
aligned_range, bakery_location
aligned_range, bakery_location, tenant_id
)
# Step 5: Validate data quality
@@ -136,43 +136,32 @@ class TrainingDataOrchestrator:
raise ValueError(f"Failed to prepare training data: {str(e)}")
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
"""Extract and validate the date range from sales data"""
"""Extract date range from sales data with timezone handling"""
if not sales_data:
raise ValueError("No sales data provided")
dates = []
valid_records = 0
for record in sales_data:
try:
if 'date' in record:
date_val = record['date']
if isinstance(date_val, str):
# Handle various date formats
if 'T' in date_val:
date_val = date_val.replace('Z', '+00:00')
parsed_date = datetime.fromisoformat(date_val.split('T')[0])
elif isinstance(date_val, datetime):
parsed_date = date_val
else:
continue
dates.append(parsed_date)
valid_records += 1
except (ValueError, TypeError) as e:
logger.warning(f"Invalid date in sales record: {record.get('date', 'N/A')} - {str(e)}")
continue
date_value = record.get('date')
if date_value:
# ✅ FIX: Ensure timezone-aware datetime
if isinstance(date_value, str):
dt = pd.to_datetime(date_value)
if dt.tz is None:
dt = dt.replace(tzinfo=timezone.utc)
dates.append(dt.to_pydatetime())
elif isinstance(date_value, datetime):
if date_value.tzinfo is None:
date_value = date_value.replace(tzinfo=timezone.utc)
dates.append(date_value)
if not dates:
raise ValueError("No valid dates found in sales data")
logger.info(f"Processed {valid_records} valid date records from {len(sales_data)} total records")
start_date = min(dates)
end_date = max(dates)
return DateRange(
start=min(dates),
end=max(dates),
source=DataSourceType.BAKERY_SALES
)
return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)
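pd.to_datetime accepts both the bare date strings and the ISO-8601 'Z'-suffixed variants that the old hand-rolled parser special-cased, so only offset-free strings need the UTC fallback. A small check of both branches:

    from datetime import timezone
    import pandas as pd

    aware = pd.to_datetime("2025-07-28T10:00:00Z")  # already tz-aware (UTC)
    naive = pd.to_datetime("2025-07-28")            # naive: .tz is None
    if naive.tz is None:
        naive = naive.replace(tzinfo=timezone.utc)  # the same fallback as above
    print(aware.tz, naive.to_pydatetime().tzinfo)   # UTC UTC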
def _filter_sales_data(
self,
@@ -187,20 +176,43 @@ class TrainingDataOrchestrator:
try:
if 'date' in record:
record_date = record['date']
# ✅ FIX: Proper timezone handling for date parsing
if isinstance(record_date, str):
if 'T' in record_date:
record_date = record_date.replace('Z', '+00:00')
record_date = datetime.fromisoformat(record_date.split('T')[0])
# Parse with timezone info intact
parsed_date = datetime.fromisoformat(record_date.split('T')[0])
# Ensure timezone-aware
if parsed_date.tzinfo is None:
parsed_date = parsed_date.replace(tzinfo=timezone.utc)
record_date = parsed_date
elif isinstance(record_date, datetime):
# Ensure timezone-aware
if record_date.tzinfo is None:
record_date = record_date.replace(tzinfo=timezone.utc)
# Normalize to start of day
record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
# Check if date falls within aligned range
if aligned_range.start <= record_date <= aligned_range.end:
# ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
aligned_start = aligned_range.start
aligned_end = aligned_range.end
if aligned_start.tzinfo is None:
aligned_start = aligned_start.replace(tzinfo=timezone.utc)
if aligned_end.tzinfo is None:
aligned_end = aligned_end.replace(tzinfo=timezone.utc)
# Check if date falls within aligned range (now both are timezone-aware)
if aligned_start <= record_date <= aligned_end:
# Validate that record has required fields
if self._validate_sales_record(record):
filtered_data.append(record)
else:
filtered_count += 1
else:
# Record outside date range
filtered_count += 1
except Exception as e:
logger.warning(f"Error processing sales record: {str(e)}")
filtered_count += 1
@@ -243,7 +255,8 @@ class TrainingDataOrchestrator:
async def _collect_external_data(
self,
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float]
bakery_location: Tuple[float, float],
tenant_id: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Collect weather and traffic data concurrently with enhanced error handling"""
@@ -255,14 +268,14 @@ class TrainingDataOrchestrator:
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range)
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("weather", weather_task))
# Traffic data collection
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("traffic", traffic_task))
@@ -297,18 +310,21 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect weather data with timeout and fallback"""
try:
if not self.weather_client:
logger.info("Weather client not configured, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
weather_data = await asyncio.wait_for(
self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
)
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate weather data
if self._validate_weather_data(weather_data):
@@ -319,7 +335,7 @@ class TrainingDataOrchestrator:
return self._generate_synthetic_weather_data(aligned_range)
except asyncio.TimeoutError:
logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
logger.warning(f"Weather data collection timed out, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
except Exception as e:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
@@ -329,23 +345,26 @@ class TrainingDataOrchestrator:
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect traffic data with timeout and Madrid constraint validation"""
try:
if not self.madrid_client:
logger.info("Madrid client not configured, no traffic data available")
return []
# Double-check Madrid constraint before making request
if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
logger.warning("Madrid current month constraint violation, no traffic data available")
return []
traffic_data = await asyncio.wait_for(
self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
)
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=lat,
longitude=lon)
# Validate traffic data
if self._validate_traffic_data(traffic_data):
@@ -356,7 +375,7 @@ class TrainingDataOrchestrator:
return []
except asyncio.TimeoutError:
logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
logger.warning(f"Traffic data collection timed out")
return []
except Exception as e:
logger.warning(f"Traffic data collection failed: {e}")

View File

@@ -81,36 +81,26 @@ class TrainingService:
)
# Step 3: Compile final results
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"data_summary": {
"sales_records": len(training_dataset.sales_data),
"weather_records": len(training_dataset.weather_data),
"traffic_records": len(training_dataset.traffic_data),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training job {job_id} completed successfully")
return final_result
return {
"job_id": job_id,
"status": "completed", # or "running" if async
"message": "Training job completed successfully",
"tenant_id": tenant_id,
"created_at": datetime.now(),
"estimated_duration_minutes": 5 # reasonable estimate
}
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
# Return error response that still matches schema
return {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
"message": f"Training job failed: {str(e)}",
"tenant_id": tenant_id,
"created_at": datetime.now(),
"estimated_duration_minutes": 0
}
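Success and failure now return the same shape, which suggests a response model along these lines. This is a reconstruction from the keys used above; the real TrainingJobResponse definition is not part of this diff:

    from datetime import datetime
    from pydantic import BaseModel

    class TrainingJobResponse(BaseModel):  # hypothetical reconstruction
        job_id: str
        status: str                        # "completed", "failed", or "running"
        message: str
        tenant_id: str
        created_at: datetime
        estimated_duration_minutes: int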
async def start_single_product_training(

View File

@@ -622,7 +622,7 @@ fi
echo ""
# =================================================================
# STEP 4: MODEL TRAINING (ONBOARDING PAGE STEP 4)
# STEP 4: MODEL TRAINING (ONBOARDING PAGE STEP 4) - FIXED
# =================================================================
echo -e "${STEP_ICONS[3]} ${PURPLE}STEP 4: AI MODEL TRAINING${NC}"
@@ -633,22 +633,28 @@ log_step "4.1. Starting model training process with real data products"
# Get unique products from the imported data for training
# Extract some real product names from the CSV for training
REAL_PRODUCTS=$(tail -n +2 "$PREPARED_CSV" | cut -d',' -f2 | sort | uniq | head -3 | tr '\n' ',' | sed 's/,$//')
REAL_PRODUCTS_RAW=$(tail -n +2 "$PREPARED_CSV" | cut -d',' -f2 | sort | uniq | head -3 | tr '\n' ',' | sed 's/,$//')
if [ -z "$REAL_PRODUCTS" ]; then
if [ -z "$REAL_PRODUCTS_RAW" ]; then
# Fallback to default products if extraction fails
REAL_PRODUCTS='"Pan de molde","Croissants","Magdalenas"'
REAL_PRODUCTS_ARRAY='["Pan de molde","Croissants","Magdalenas"]'
log_warning "Could not extract real product names, using defaults"
else
# Format for JSON array
REAL_PRODUCTS=$(echo "$REAL_PRODUCTS" | sed 's/,/","/g' | sed 's/^/"/' | sed 's/$/"/')
log_success "Extracted real products for training: $REAL_PRODUCTS"
# Format for JSON array properly
REAL_PRODUCTS_ARRAY='['$(echo "$REAL_PRODUCTS_RAW" | sed 's/,/","/g' | sed 's/^/"/' | sed 's/$/"/')']'
log_success "Extracted real products for training: $REAL_PRODUCTS_ARRAY"
fi
# Training request with real products
# ✅ FIXED: Training request with correct data types matching TrainingJobRequest schema
TRAINING_DATA="{
\"tenant_id\": \"$TENANT_ID\"
}
\"products\": $REAL_PRODUCTS_ARRAY,
\"max_workers\": 4,
\"seasonality_mode\": \"additive\",
\"daily_seasonality\": true,
\"weekly_seasonality\": true,
\"yearly_seasonality\": true,
\"force_retrain\": false,
\"parallel_training\": true
}"
echo "Training Request:"
@@ -668,15 +674,54 @@ echo "Training HTTP Status Code: $HTTP_CODE"
echo "Training Response:"
echo "$TRAINING_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$TRAINING_RESPONSE"
TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "job_id")
if [ -z "$TRAINING_TASK_ID" ]; then
TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "id")
fi
# ✅ FIXED: Better error handling for 422 responses
if [ "$HTTP_CODE" = "422" ]; then
log_error "Training request failed with validation error (HTTP 422)"
echo "This usually means the request doesn't match the expected schema."
echo "Common causes:"
echo " - Wrong data types (string instead of integer)"
echo " - Invalid field values (seasonality_mode must be 'additive' or 'multiplicative')"
echo " - Missing required headers"
echo ""
echo "Response details:"
echo "$TRAINING_RESPONSE"
if [ -n "$TRAINING_TASK_ID" ]; then
log_success "Training started successfully - Task ID: $TRAINING_TASK_ID"
# Try a minimal request that should work
log_step "4.2. Attempting minimal training request as fallback"
MINIMAL_TRAINING_DATA='{"seasonality_mode": "additive"}'
FALLBACK_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $ACCESS_TOKEN" \
-H "X-Tenant-ID: $TENANT_ID" \
-d "$MINIMAL_TRAINING_DATA")
FALLBACK_HTTP_CODE=$(echo "$FALLBACK_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2)
FALLBACK_RESPONSE=$(echo "$FALLBACK_RESPONSE" | sed '/HTTP_CODE:/d')
echo "Fallback HTTP Status Code: $FALLBACK_HTTP_CODE"
echo "Fallback Response:"
echo "$FALLBACK_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$FALLBACK_RESPONSE"
if [ "$FALLBACK_HTTP_CODE" = "200" ] || [ "$FALLBACK_HTTP_CODE" = "201" ]; then
log_success "Minimal training request succeeded"
TRAINING_TASK_ID=$(extract_json_field "$FALLBACK_RESPONSE" "job_id")
else
log_error "Both training requests failed"
fi
else
# Original success handling
TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "job_id")
if [ -z "$TRAINING_TASK_ID" ]; then
TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "id")
fi
if [ -n "$TRAINING_TASK_ID" ]; then
log_success "Training started successfully - Task ID: $TRAINING_TASK_ID"
else
log_warning "Could not start training - task ID not found"
fi
fi
echo ""