Bug fixes for training
@@ -41,8 +41,3 @@ nodes:
   - containerPort: 30800
     hostPort: 8000
     protocol: TCP
-  sysctls:
-    # Increase fs.inotify limits to prevent "too many open files" errors
-    fs.inotify.max_user_watches: 524288
-    fs.inotify.max_user_instances: 256
-    fs.inotify.max_queued_events: 32768
services/external/README.md (vendored): 21 lines changed
@@ -100,14 +100,19 @@ The **External Service** integrates real-world data from Spanish sources to enha
 ## API Endpoints (Key Routes)
 
 ### POI Detection & Context
-- `POST /poi-context/{tenant_id}/detect` - Detect POIs for tenant location (lat, long, force_refresh params)
-- `GET /poi-context/{tenant_id}` - Get cached POI context for tenant
-- `POST /poi-context/{tenant_id}/refresh` - Force refresh POI detection
-- `DELETE /poi-context/{tenant_id}` - Delete POI context for tenant
-- `GET /poi-context/{tenant_id}/feature-importance` - Get POI feature importance summary
-- `GET /poi-context/{tenant_id}/competitor-analysis` - Get competitive analysis
-- `GET /poi-context/health` - Check POI service and Overpass API health
-- `GET /poi-context/cache/stats` - Get POI cache statistics
+- `POST /api/v1/tenants/{tenant_id}/poi-context/detect` - Detect POIs for tenant location (lat, long, force_refresh params) - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context` - Get cached POI context for tenant - Direct access (bypasses gateway authentication)
+- `POST /api/v1/tenants/{tenant_id}/poi-context/refresh` - Force refresh POI detection - Direct access (bypasses gateway authentication)
+- `DELETE /api/v1/tenants/{tenant_id}/poi-context` - Delete POI context for tenant - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context/feature-importance` - Get POI feature importance summary - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/{tenant_id}/poi-context/competitor-analysis` - Get competitive analysis - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/poi-context/health` - Check POI service and Overpass API health - Direct access (bypasses gateway authentication)
+- `GET /api/v1/tenants/poi-context/cache/stats` - Get POI cache statistics - Direct access (bypasses gateway authentication)
 
+### Recommended Access Pattern:
+- Services should use `/api/v1/tenants/{tenant_id}/external/poi-context` (detected POIs) and `/api/v1/tenants/{tenant_id}/external/poi-context/detect` (detection) via shared ExternalServiceClient through the API gateway for proper authentication and authorization.
+
+**Note**: When using ExternalServiceClient through shared client, provide relative paths like `poi-context` and `poi-context/detect` - the client automatically constructs the full tenant-scoped path.
+
 ### Weather Data (AEMET)
 - `GET /api/v1/external/weather/current` - Current weather for location
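**Example (sketch)**: how a calling service might follow the recommended access pattern above, assuming the shared client constructor and `get_poi_context` method added elsewhere in this commit:

```python
# Illustrative only; the service name and settings module depend on the caller.
from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings

client = ExternalServiceClient(settings, "forecasting-service")

async def load_poi_features(tenant_id: str) -> dict:
    # Relative path "poi-context" expands to /api/v1/tenants/{tenant_id}/external/poi-context
    result = await client.get_poi_context(tenant_id)
    if not result:
        return {}
    return result.get("poi_context", {}).get("ml_features", {})
```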
services/external/app/api/poi_context.py (vendored): 28 lines changed
@@ -18,13 +18,17 @@ from app.services.poi_refresh_service import POIRefreshService
 from app.repositories.poi_context_repository import POIContextRepository
 from app.cache.poi_cache_service import POICacheService
 from app.core.redis_client import get_redis_client
+from shared.routing.route_builder import RouteBuilder
 
 logger = structlog.get_logger()
 
-router = APIRouter(prefix="/tenants", tags=["POI Context"])
+route_builder = RouteBuilder('external')
+router = APIRouter(tags=["POI Context"])
 
 
-@router.post("/{tenant_id}/poi-context/detect")
+@router.post(
+    route_builder.build_base_route("poi-context/detect")
+)
 async def detect_pois_for_tenant(
     tenant_id: str,
     latitude: float = Query(..., description="Bakery latitude"),
@@ -297,7 +301,9 @@ async def detect_pois_for_tenant(
     )
 
 
-@router.get("/{tenant_id}/poi-context")
+@router.get(
+    route_builder.build_base_route("poi-context")
+)
 async def get_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -331,7 +337,9 @@ async def get_poi_context(
     }
 
 
-@router.post("/{tenant_id}/poi-context/refresh")
+@router.post(
+    route_builder.build_base_route("poi-context/refresh")
+)
 async def refresh_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -365,7 +373,9 @@ async def refresh_poi_context(
     )
 
 
-@router.delete("/{tenant_id}/poi-context")
+@router.delete(
+    route_builder.build_base_route("poi-context")
+)
 async def delete_poi_context(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -393,7 +403,9 @@ async def delete_poi_context(
     }
 
 
-@router.get("/{tenant_id}/poi-context/feature-importance")
+@router.get(
+    route_builder.build_base_route("poi-context/feature-importance")
+)
 async def get_feature_importance(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -430,7 +442,9 @@ async def get_feature_importance(
     }
 
 
-@router.get("/{tenant_id}/poi-context/competitor-analysis")
+@router.get(
+    route_builder.build_base_route("poi-context/competitor-analysis")
+)
 async def get_competitor_analysis(
     tenant_id: str,
     db: AsyncSession = Depends(get_db)
@@ -497,7 +497,7 @@ poi_features = await poi_service.fetch_poi_features(tenant_id)
 - **Accuracy Improvement** - POI features contribute 5-10% accuracy improvement
 
 **Endpoint Used:**
-- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
+- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
 
 ## Integration Points
 
@@ -5,10 +5,11 @@ Fetches POI features for use in demand forecasting predictions.
 Ensures feature consistency between training and prediction.
 """
 
-import httpx
 from typing import Dict, Any, Optional
 import structlog
 
+from shared.clients.external_client import ExternalServiceClient
 
 logger = structlog.get_logger()
 
@@ -20,15 +21,18 @@ class POIFeatureService:
     prediction uses the same features as training.
     """
 
-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: ExternalServiceClient = None):
         """
         Initialize POI feature service.
 
         Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
         """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "forecasting-service")
+        else:
+            self.external_client = external_client
 
     async def get_poi_features(
         self,
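A minimal wiring sketch for the new constructor (illustrative; the module path of `POIFeatureService` is assumed):

```python
# Illustrative wiring for the forecasting service.
from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings
from app.services.poi_feature_service import POIFeatureService  # path assumed

poi_service = POIFeatureService()  # builds its own client from settings

# Tests or callers that already hold a client can inject it instead.
shared_client = ExternalServiceClient(settings, "forecasting-service")
poi_service_injected = POIFeatureService(external_client=shared_client)
```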
@@ -44,21 +48,10 @@ class POIFeatureService:
             Dictionary with POI features or empty dict if not available
         """
         try:
-            async with httpx.AsyncClient(timeout=10.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/{tenant_id}"
-                )
-
-                if response.status_code == 404:
-                    logger.warning(
-                        "No POI context found for tenant",
-                        tenant_id=tenant_id
-                    )
-                    return {}
-
-                response.raise_for_status()
-                data = response.json()
-                poi_context = data.get("poi_context", {})
+            result = await self.external_client.get_poi_context(tenant_id)
+
+            if result:
+                poi_context = result.get("poi_context", {})
                 ml_features = poi_context.get("ml_features", {})
 
                 logger.info(
@@ -68,17 +61,16 @@ class POIFeatureService:
                 )
 
                 return ml_features
+            else:
+                logger.warning(
+                    "No POI context found for tenant",
+                    tenant_id=tenant_id
+                )
+                return {}
 
-        except httpx.HTTPError as e:
-            logger.error(
-                "Failed to fetch POI features for forecasting",
-                tenant_id=tenant_id,
-                error=str(e)
-            )
-            return {}
         except Exception as e:
             logger.error(
-                "Unexpected error fetching POI features",
+                "Failed to fetch POI features for forecasting",
                 tenant_id=tenant_id,
                 error=str(e),
                 exc_info=True
@@ -87,17 +79,18 @@ class POIFeatureService:
 
     async def check_poi_service_health(self) -> bool:
         """
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.
 
         Returns:
             True if service is healthy, False otherwise
         """
         try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # Test the external service health by attempting to get POI context for a dummy tenant
+            # This will go through the proper authentication and routing
+            dummy_context = await self.external_client.get_poi_context("test-tenant")
+            # If we can successfully make a request (even if it returns None for missing tenant),
+            # it means the service is accessible
+            return True
         except Exception as e:
             logger.error(
                 "POI service health check failed",
@@ -581,7 +581,7 @@ for feature_name in poi_features.keys():
 ```
 
 **Endpoint Used:**
-- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
+- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
 
 ## Integration Points
 
@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
 Fetches POI context from External service and merges features into training data.
 """
 
-import httpx
 from typing import Dict, Any, Optional, List
 import structlog
 import pandas as pd
 
+from shared.clients.external_client import ExternalServiceClient
 
 logger = structlog.get_logger()
 
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
     to training dataframes for location-based demand forecasting.
     """
 
-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: ExternalServiceClient = None):
         """
         Initialize POI feature integrator.
 
         Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
         """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "training-service")
+        else:
+            self.external_client = external_client
 
     async def fetch_poi_features(
         self,
@@ -53,57 +57,49 @@ class POIFeatureIntegrator:
|
|||||||
Dictionary with POI features or None if detection fails
|
Dictionary with POI features or None if detection fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
# Try to get existing POI context first
|
||||||
# Try to get existing POI context first
|
if not force_refresh:
|
||||||
if not force_refresh:
|
existing_context = await self.external_client.get_poi_context(tenant_id)
|
||||||
try:
|
if existing_context:
|
||||||
response = await client.get(
|
poi_context = existing_context.get("poi_context", {})
|
||||||
f"{self.poi_context_endpoint}/{tenant_id}"
|
ml_features = poi_context.get("ml_features", {})
|
||||||
)
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
poi_context = data.get("poi_context", {})
|
|
||||||
|
|
||||||
# Check if stale
|
# Check if stale
|
||||||
if not data.get("is_stale", False):
|
is_stale = existing_context.get("is_stale", False)
|
||||||
logger.info(
|
if not is_stale:
|
||||||
"Using existing POI context",
|
|
||||||
tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
return poi_context.get("ml_features", {})
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"POI context is stale, refreshing",
|
|
||||||
tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
force_refresh = True
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
if e.response.status_code != 404:
|
|
||||||
raise
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"No existing POI context, will detect",
|
"Using existing POI context",
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id
|
||||||
)
|
)
|
||||||
|
return ml_features
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"POI context is stale, refreshing",
|
||||||
|
tenant_id=tenant_id
|
||||||
|
)
|
||||||
|
force_refresh = True
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"No existing POI context, will detect",
|
||||||
|
tenant_id=tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
# Detect or refresh POIs
|
# Detect or refresh POIs
|
||||||
logger.info(
|
logger.info(
|
||||||
"Detecting POIs for tenant",
|
"Detecting POIs for tenant",
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
location=(latitude, longitude)
|
location=(latitude, longitude)
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await client.post(
|
detection_result = await self.external_client.detect_poi_for_tenant(
|
||||||
f"{self.poi_context_endpoint}/{tenant_id}/detect",
|
tenant_id=tenant_id,
|
||||||
params={
|
latitude=latitude,
|
||||||
"latitude": latitude,
|
longitude=longitude,
|
||||||
"longitude": longitude,
|
force_refresh=force_refresh
|
||||||
"force_refresh": force_refresh
|
)
|
||||||
}
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
if detection_result:
|
||||||
poi_context = result.get("poi_context", {})
|
poi_context = detection_result.get("poi_context", {})
|
||||||
ml_features = poi_context.get("ml_features", {})
|
ml_features = poi_context.get("ml_features", {})
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
                 )
 
                 return ml_features
+            else:
+                logger.error(
+                    "POI detection failed",
+                    tenant_id=tenant_id
+                )
+                return None
 
-        except httpx.HTTPError as e:
-            logger.error(
-                "Failed to fetch POI features",
-                tenant_id=tenant_id,
-                error=str(e),
-                exc_info=True
-            )
-            return None
         except Exception as e:
             logger.error(
                 "Unexpected error fetching POI features",
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
 
     async def check_poi_service_health(self) -> bool:
         """
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.
 
         Returns:
             True if service is healthy, False otherwise
         """
         try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # We can test the external service health by attempting to get POI context for a dummy tenant
+            # This will go through the proper authentication and routing
+            dummy_context = await self.external_client.get_poi_context("test-tenant")
+            # If we can successfully make a request (even if it returns None for missing tenant),
+            # it means the service is accessible
+            return True
         except Exception as e:
             logger.error(
                 "POI service health check failed",
@@ -375,13 +375,116 @@ class EnhancedBakeryMLTrainer:
|
|||||||
try:
|
try:
|
||||||
# Use provided session or create new one to prevent nested sessions and deadlocks
|
# Use provided session or create new one to prevent nested sessions and deadlocks
|
||||||
should_create_session = session is None
|
should_create_session = session is None
|
||||||
db_session = session if session is not None else None
|
|
||||||
|
|
||||||
if should_create_session:
|
if should_create_session:
|
||||||
# Only create a session if one wasn't provided
|
# Only create a session if one wasn't provided
|
||||||
async with self.database_manager.get_session() as db_session:
|
async with self.database_manager.get_session() as db_session:
|
||||||
repos = await self._get_repositories(db_session)
|
repos = await self._get_repositories(db_session)
|
||||||
|
|
||||||
|
# Validate input data
|
||||||
|
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
|
||||||
|
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
|
||||||
|
|
||||||
|
# Validate required columns
|
||||||
|
required_columns = ['ds', 'y']
|
||||||
|
missing_cols = [col for col in required_columns if col not in training_data.columns]
|
||||||
|
if missing_cols:
|
||||||
|
raise ValueError(f"Missing required columns in training data: {missing_cols}")
|
||||||
|
|
||||||
|
# Create a simple progress tracker for single product
|
||||||
|
from app.services.progress_tracker import ParallelProductProgressTracker
|
||||||
|
progress_tracker = ParallelProductProgressTracker(
|
||||||
|
job_id=job_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
total_products=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure training data has proper data types before training
|
||||||
|
if 'ds' in training_data.columns:
|
||||||
|
training_data['ds'] = pd.to_datetime(training_data['ds'])
|
||||||
|
if 'y' in training_data.columns:
|
||||||
|
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
|
||||||
|
|
||||||
|
# Remove any rows with NaN values
|
||||||
|
training_data = training_data.dropna()
|
||||||
|
|
||||||
|
# Train the model using the existing _train_single_product method
|
||||||
|
product_id, result = await self._train_single_product(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
inventory_product_id=inventory_product_id,
|
||||||
|
product_data=training_data,
|
||||||
|
job_id=job_id,
|
||||||
|
repos=repos,
|
||||||
|
progress_tracker=progress_tracker,
|
||||||
|
session=db_session # Pass the session to prevent nested sessions
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Single product training completed",
|
||||||
|
job_id=job_id,
|
||||||
|
inventory_product_id=inventory_product_id,
|
||||||
|
result_status=result.get('status'))
|
||||||
|
|
||||||
|
# Write training result to database (create model record)
|
||||||
|
if result.get('status') == 'success':
|
||||||
|
model_info = result.get('model_info')
|
||||||
|
product_data = result.get('product_data')
|
||||||
|
|
||||||
|
if model_info and product_data is not None:
|
||||||
|
# Create model record in database
|
||||||
|
model_record = await self._create_model_record(
|
||||||
|
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create performance metrics
|
||||||
|
if model_info.get('training_metrics') and model_record:
|
||||||
|
await self._create_performance_metrics(
|
||||||
|
repos, model_record.id,
|
||||||
|
tenant_id, inventory_product_id, model_info['training_metrics']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update result with model_record_id
|
||||||
|
result['model_record_id'] = str(model_record.id) if model_record else None
|
||||||
|
|
||||||
|
# Get training metrics and filter out non-numeric values
|
||||||
|
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
|
||||||
|
# Filter metrics to only include numeric values (per Pydantic schema requirement)
|
||||||
|
filtered_metrics = {}
|
||||||
|
for key, value in raw_metrics.items():
|
||||||
|
if key == 'product_category':
|
||||||
|
# Skip product_category as it's a string value, not a numeric metric
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
# Try to convert to float for validation
|
||||||
|
filtered_metrics[key] = float(value) if value is not None else 0.0
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Skip non-numeric values
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Return appropriate result format
|
||||||
|
result_dict = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"inventory_product_id": inventory_product_id,
|
||||||
|
"status": result.get('status', 'success'),
|
||||||
|
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
|
||||||
|
"training_metrics": filtered_metrics,
|
||||||
|
"training_time": result.get('training_time_seconds', 0),
|
||||||
|
"data_points": result.get('data_points', 0),
|
||||||
|
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only commit if this is our own session (not a parent session)
|
||||||
|
# Commit after we're done with all database operations
|
||||||
|
await db_session.commit()
|
||||||
|
logger.info("Committed single product model record to database",
|
||||||
|
inventory_product_id=inventory_product_id,
|
||||||
|
model_record_id=result.get('model_record_id'))
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
else:
|
||||||
|
# Use the provided session
|
||||||
|
repos = await self._get_repositories(session)
|
||||||
|
|
||||||
# Validate input data
|
# Validate input data
|
||||||
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
|
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
|
||||||
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
|
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
|
||||||
@@ -417,7 +520,7 @@ class EnhancedBakeryMLTrainer:
|
|||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
repos=repos,
|
repos=repos,
|
||||||
progress_tracker=progress_tracker,
|
progress_tracker=progress_tracker,
|
||||||
session=db_session # Pass the session to prevent nested sessions
|
session=session # Pass the provided session
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Single product training completed",
|
logger.info("Single product training completed",
|
||||||
@@ -425,6 +528,27 @@ class EnhancedBakeryMLTrainer:
|
|||||||
inventory_product_id=inventory_product_id,
|
inventory_product_id=inventory_product_id,
|
||||||
result_status=result.get('status'))
|
result_status=result.get('status'))
|
||||||
|
|
||||||
|
# Write training result to database (create model record)
|
||||||
|
if result.get('status') == 'success':
|
||||||
|
model_info = result.get('model_info')
|
||||||
|
product_data = result.get('product_data')
|
||||||
|
|
||||||
|
if model_info and product_data is not None:
|
||||||
|
# Create model record in database
|
||||||
|
model_record = await self._create_model_record(
|
||||||
|
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create performance metrics
|
||||||
|
if model_info.get('training_metrics') and model_record:
|
||||||
|
await self._create_performance_metrics(
|
||||||
|
repos, model_record.id,
|
||||||
|
tenant_id, inventory_product_id, model_info['training_metrics']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update result with model_record_id
|
||||||
|
result['model_record_id'] = str(model_record.id) if model_record else None
|
||||||
|
|
||||||
# Get training metrics and filter out non-numeric values
|
# Get training metrics and filter out non-numeric values
|
||||||
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
|
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
|
||||||
# Filter metrics to only include numeric values (per Pydantic schema requirement)
|
# Filter metrics to only include numeric values (per Pydantic schema requirement)
|
||||||
@@ -453,6 +577,12 @@ class EnhancedBakeryMLTrainer:
|
|||||||
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
|
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# For provided sessions, do NOT commit here - let the calling method handle commits
|
||||||
|
# This prevents committing a parent transaction prematurely
|
||||||
|
logger.info("Single product model processed (commit handled by caller)",
|
||||||
|
inventory_product_id=inventory_product_id,
|
||||||
|
model_record_id=result.get('model_record_id'))
|
||||||
|
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -186,6 +186,7 @@ class TrainedModel(Base):
             "training_samples": self.training_samples,
             "hyperparameters": self.hyperparameters,
             "features_used": self.features_used,
+            "features": self.features_used,  # Alias for frontend compatibility (ModelDetailsModal expects 'features')
             "product_category": self.product_category,
             "is_active": self.is_active,
             "is_production": self.is_production,
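With this alias in place, a sketch of the serialized shape consumers can rely on (placeholder values):

```python
# Illustrative only: both keys reference the same stored feature list.
model_dict = {
    "features_used": ["poi_retail_total_count", "..."],
    "features": ["poi_retail_total_count", "..."],  # alias read by ModelDetailsModal
}
```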
@@ -732,297 +732,293 @@ class EnhancedTrainingService:
|
|||||||
current_step: str = None,
|
current_step: str = None,
|
||||||
error_message: str = None,
|
error_message: str = None,
|
||||||
results: Dict = None,
|
results: Dict = None,
|
||||||
tenant_id: str = None):
|
tenant_id: str = None,
|
||||||
"""Update job status using repository pattern"""
|
session = None):
|
||||||
|
"""Update job status using repository pattern
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Optional database session to reuse. If None, creates a new session.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with self.database_manager.get_session() as session:
|
# Use provided session or create new one
|
||||||
await self._init_repositories(session)
|
should_create_session = session is None
|
||||||
|
|
||||||
# Check if log exists, create if not
|
if should_create_session:
|
||||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
async with self.database_manager.get_session() as session:
|
||||||
|
await self._init_repositories(session)
|
||||||
if not existing_log:
|
await self._update_job_status_impl(
|
||||||
# Create initial log entry
|
session, job_id, status, progress, current_step,
|
||||||
if not tenant_id:
|
error_message, results, tenant_id
|
||||||
# Extract tenant_id from job_id if not provided
|
|
||||||
# Format: enhanced_training_{tenant_id}_{job_suffix}
|
|
||||||
try:
|
|
||||||
parts = job_id.split('_')
|
|
||||||
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
|
|
||||||
tenant_id = parts[2]
|
|
||||||
except Exception:
|
|
||||||
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
|
|
||||||
|
|
||||||
if tenant_id:
|
|
||||||
log_data = {
|
|
||||||
"job_id": job_id,
|
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"status": status or "pending",
|
|
||||||
"progress": progress or 0,
|
|
||||||
"current_step": current_step or "initializing",
|
|
||||||
"start_time": datetime.now(timezone.utc)
|
|
||||||
}
|
|
||||||
|
|
||||||
if error_message:
|
|
||||||
log_data["error_message"] = error_message
|
|
||||||
if results:
|
|
||||||
# Ensure results are JSON-serializable before storing
|
|
||||||
log_data["results"] = make_json_serializable(results)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.training_log_repo.create_training_log(log_data)
|
|
||||||
await session.commit() # Explicit commit so other sessions can see it
|
|
||||||
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
|
||||||
except Exception as create_error:
|
|
||||||
# Handle race condition: another session may have created the log
|
|
||||||
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
|
||||||
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
|
|
||||||
await session.rollback()
|
|
||||||
# Query again to get the existing log
|
|
||||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
|
||||||
if existing_log:
|
|
||||||
# Update the existing log instead
|
|
||||||
await self.training_log_repo.update_log_progress(
|
|
||||||
job_id=job_id,
|
|
||||||
progress=progress,
|
|
||||||
current_step=current_step,
|
|
||||||
status=status
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
logger.error("Cannot create training log without tenant_id", job_id=job_id)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# Update existing log
|
|
||||||
await self.training_log_repo.update_log_progress(
|
|
||||||
job_id=job_id,
|
|
||||||
progress=progress,
|
|
||||||
current_step=current_step,
|
|
||||||
status=status
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
# Update additional fields if provided
|
# Reuse provided session (don't commit - let caller control transaction)
|
||||||
if error_message or results:
|
await self._init_repositories(session)
|
||||||
update_data = {}
|
await self._update_job_status_impl(
|
||||||
if error_message:
|
session, job_id, status, progress, current_step,
|
||||||
update_data["error_message"] = error_message
|
error_message, results, tenant_id, auto_commit=False
|
||||||
if results:
|
)
|
||||||
# Ensure results are JSON-serializable before storing
|
|
||||||
update_data["results"] = make_json_serializable(results)
|
|
||||||
if status in ["completed", "failed"]:
|
|
||||||
update_data["end_time"] = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
if update_data:
|
|
||||||
await self.training_log_repo.update(existing_log.id, update_data)
|
|
||||||
|
|
||||||
await session.commit() # Explicit commit after updates
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to update job status using repository",
|
logger.error("Failed to update job status using repository",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
error=str(e))
|
error=str(e))
|
||||||
|
|
||||||
|
async def _update_job_status_impl(self,
|
||||||
|
session,
|
||||||
|
job_id: str,
|
||||||
|
status: str,
|
||||||
|
progress: int = None,
|
||||||
|
current_step: str = None,
|
||||||
|
error_message: str = None,
|
||||||
|
results: Dict = None,
|
||||||
|
tenant_id: str = None,
|
||||||
|
auto_commit: bool = True):
|
||||||
|
"""Implementation of job status update"""
|
||||||
|
# Check if log exists, create if not
|
||||||
|
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||||
|
|
||||||
|
if not existing_log:
|
||||||
|
# Create initial log entry
|
||||||
|
if not tenant_id:
|
||||||
|
# Extract tenant_id from job_id if not provided
|
||||||
|
# Format: enhanced_training_{tenant_id}_{job_suffix}
|
||||||
|
try:
|
||||||
|
parts = job_id.split('_')
|
||||||
|
if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
|
||||||
|
tenant_id = parts[2]
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
|
||||||
|
|
||||||
|
if tenant_id:
|
||||||
|
log_data = {
|
||||||
|
"job_id": job_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"status": status or "pending",
|
||||||
|
"progress": progress or 0,
|
||||||
|
"current_step": current_step or "initializing",
|
||||||
|
"start_time": datetime.now(timezone.utc)
|
||||||
|
}
|
||||||
|
|
||||||
|
if error_message:
|
||||||
|
log_data["error_message"] = error_message
|
||||||
|
if results:
|
||||||
|
# Ensure results are JSON-serializable before storing
|
||||||
|
log_data["results"] = make_json_serializable(results)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.training_log_repo.create_training_log(log_data)
|
||||||
|
if auto_commit:
|
||||||
|
await session.commit() # Explicit commit so other sessions can see it
|
||||||
|
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
||||||
|
except Exception as create_error:
|
||||||
|
# Handle race condition: another session may have created the log
|
||||||
|
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
||||||
|
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
|
||||||
|
await session.rollback()
|
||||||
|
# Query again to get the existing log
|
||||||
|
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||||
|
if existing_log:
|
||||||
|
# Update the existing log instead
|
||||||
|
await self.training_log_repo.update_log_progress(
|
||||||
|
job_id=job_id,
|
||||||
|
progress=progress,
|
||||||
|
current_step=current_step,
|
||||||
|
status=status
|
||||||
|
)
|
||||||
|
if auto_commit:
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
logger.error("Cannot create training log without tenant_id", job_id=job_id)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Update existing log
|
||||||
|
await self.training_log_repo.update_log_progress(
|
||||||
|
job_id=job_id,
|
||||||
|
progress=progress,
|
||||||
|
current_step=current_step,
|
||||||
|
status=status
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update additional fields if provided
|
||||||
|
if error_message or results:
|
||||||
|
update_data = {}
|
||||||
|
if error_message:
|
||||||
|
update_data["error_message"] = error_message
|
||||||
|
if results:
|
||||||
|
# Ensure results are JSON-serializable before storing
|
||||||
|
update_data["results"] = make_json_serializable(results)
|
||||||
|
if status in ["completed", "failed"]:
|
||||||
|
update_data["end_time"] = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
if update_data:
|
||||||
|
await self.training_log_repo.update(existing_log.id, update_data)
|
||||||
|
|
||||||
|
if auto_commit:
|
||||||
|
await session.commit() # Explicit commit after updates
|
||||||
|
|
||||||
async def start_single_product_training(self,
|
async def start_single_product_training(self,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
inventory_product_id: str,
|
inventory_product_id: str,
|
||||||
job_id: str,
|
job_id: str,
|
||||||
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
|
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
|
||||||
"""Start enhanced single product training using repository pattern"""
|
"""Start enhanced single product training using repository pattern with single session"""
|
||||||
try:
|
# Create a single database session for all operations to avoid connection pool exhaustion
|
||||||
logger.info("Starting enhanced single product training",
|
async with self.database_manager.get_session() as session:
|
||||||
tenant_id=tenant_id,
|
await self._init_repositories(session)
|
||||||
inventory_product_id=inventory_product_id,
|
|
||||||
job_id=job_id)
|
|
||||||
|
|
||||||
# Create initial training log
|
|
||||||
await self._update_job_status_repository(
|
|
||||||
job_id=job_id,
|
|
||||||
status="running",
|
|
||||||
progress=0,
|
|
||||||
current_step="Fetching training data",
|
|
||||||
tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prepare training data for all products to get weather/traffic data
|
|
||||||
# then filter down to the specific product
|
|
||||||
training_dataset = await self.orchestrator.prepare_training_data(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
bakery_location=bakery_location,
|
|
||||||
job_id=job_id + "_temp"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter sales data to the specific product
|
|
||||||
sales_df = pd.DataFrame(training_dataset.sales_data)
|
|
||||||
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
|
|
||||||
|
|
||||||
if product_sales_df.empty:
|
|
||||||
raise ValueError(f"No sales data available for product {inventory_product_id}")
|
|
||||||
|
|
||||||
# Prepare the data in Prophet format (ds and y columns)
|
|
||||||
# Ensure proper column names and types for Prophet
|
|
||||||
product_data = product_sales_df.copy()
|
|
||||||
product_data = product_data.rename(columns={
|
|
||||||
'sale_date': 'ds', # Common sales date column
|
|
||||||
'sale_datetime': 'ds', # Alternative date column
|
|
||||||
'date': 'ds', # Alternative date column
|
|
||||||
'quantity': 'y', # Quantity sold
|
|
||||||
'total_amount': 'y', # Alternative for sales data
|
|
||||||
'sales_amount': 'y', # Alternative for sales data
|
|
||||||
'sale_amount': 'y' # Alternative for sales data
|
|
||||||
})
|
|
||||||
|
|
||||||
# If 'ds' and 'y' columns are not renamed properly, try to infer them
|
|
||||||
if 'ds' not in product_data.columns:
|
|
||||||
# Try to find date-like columns
|
|
||||||
date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
|
|
||||||
if date_cols:
|
|
||||||
product_data = product_data.rename(columns={date_cols[0]: 'ds'})
|
|
||||||
|
|
||||||
if 'y' not in product_data.columns:
|
|
||||||
# Try to find sales/quantity-like columns
|
|
||||||
sales_cols = [col for col in product_data.columns if
|
|
||||||
any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])]
|
|
||||||
if sales_cols:
|
|
||||||
product_data = product_data.rename(columns={sales_cols[0]: 'y'})
|
|
||||||
|
|
||||||
# Ensure required columns exist
|
|
||||||
if 'ds' not in product_data.columns or 'y' not in product_data.columns:
|
|
||||||
raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}")
|
|
||||||
|
|
||||||
# Convert the date column to datetime if it's not already
|
|
||||||
product_data['ds'] = pd.to_datetime(product_data['ds'])
|
|
||||||
|
|
||||||
# Convert to numeric ensuring no pandas/numpy objects remain
|
|
||||||
product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
|
|
||||||
|
|
||||||
# Sort by date to ensure proper chronological order
|
|
||||||
product_data = product_data.sort_values('ds').reset_index(drop=True)
|
|
||||||
|
|
||||||
# Drop any rows with NaN values
|
|
||||||
product_data = product_data.dropna(subset=['ds', 'y'])
|
|
||||||
|
|
||||||
# Ensure the data is in the right format for Prophet
|
|
||||||
product_data = product_data[['ds', 'y']].copy()
|
|
||||||
|
|
||||||
# Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
|
|
||||||
product_data['ds'] = pd.to_datetime(product_data['ds'])
|
|
||||||
product_data['y'] = product_data['y'].astype(float)
|
|
||||||
|
|
||||||
# DEBUG: Log data types to diagnose dict comparison error
|
|
||||||
logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
|
|
||||||
logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
|
|
||||||
logger.info(f"DEBUG: Attempting to get min/max...")
|
|
||||||
try:
|
try:
|
||||||
min_val = product_data['ds'].min()
|
logger.info("Starting enhanced single product training",
|
||||||
max_val = product_data['ds'].max()
|
tenant_id=tenant_id,
|
||||||
logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
|
inventory_product_id=inventory_product_id,
|
||||||
logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
|
job_id=job_id)
|
||||||
except Exception as debug_e:
|
|
||||||
logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
|
|
||||||
import traceback
|
|
||||||
logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
logger.info("Prepared training data for single product",
|
# Create initial training log (using shared session)
|
||||||
inventory_product_id=inventory_product_id,
|
await self._update_job_status_repository(
|
||||||
data_points=len(product_data),
|
job_id=job_id,
|
||||||
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
|
status="running",
|
||||||
|
progress=0,
|
||||||
# Update progress
|
current_step="Fetching training data",
|
||||||
await self._update_job_status_repository(
|
|
||||||
job_id=job_id,
|
|
||||||
status="running",
|
|
||||||
progress=30,
|
|
||||||
current_step="Training model",
|
|
||||||
tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train the model using the trainer
|
|
||||||
# Extract datetime values with proper pandas Timestamp wrapper for type safety
|
|
||||||
try:
|
|
||||||
training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
|
|
||||||
training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
logger.error(f"Failed to extract training dates: {e}")
|
|
||||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
|
||||||
logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
|
|
||||||
logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Run the actual training
|
|
||||||
try:
|
|
||||||
model_info = await self.trainer.train_single_product_model(
|
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
session=session
|
||||||
|
)
|
||||||
|
await session.commit() # Commit after initial log creation
|
||||||
|
|
||||||
|
# Prepare training data for all products to get weather/traffic data
|
||||||
|
# then filter down to the specific product
|
||||||
|
training_dataset = await self.orchestrator.prepare_training_data(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
bakery_location=bakery_location,
|
||||||
|
job_id=job_id + "_temp"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
|
||||||
|
# Filter sales data to the specific product first
|
||||||
|
sales_df = pd.DataFrame(training_dataset.sales_data)
|
||||||
|
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
|
||||||
|
|
||||||
|
if product_sales_df.empty:
|
||||||
|
raise ValueError(f"No sales data available for product {inventory_product_id}")
|
||||||
|
|
||||||
|
# Get weather and traffic data as DataFrames
|
||||||
|
weather_df = pd.DataFrame(training_dataset.weather_data)
|
||||||
|
traffic_df = pd.DataFrame(training_dataset.traffic_data)
|
||||||
|
|
||||||
|
# Get POI features from the training dataset (already collected by orchestrator)
|
||||||
|
poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
|
||||||
|
|
||||||
|
# Use the enhanced data processor to merge all features properly
|
||||||
|
# This will include POI, weather, traffic features along with ds and y
|
||||||
|
from app.ml.data_processor import EnhancedBakeryDataProcessor
|
||||||
|
data_processor = EnhancedBakeryDataProcessor(self.database_manager)
|
||||||
|
|
||||||
|
product_data = await data_processor.prepare_training_data(
|
||||||
|
sales_data=product_sales_df,
|
||||||
|
weather_data=weather_df,
|
||||||
|
traffic_data=traffic_df,
|
||||||
inventory_product_id=inventory_product_id,
|
inventory_product_id=inventory_product_id,
|
||||||
training_data=product_data,
|
poi_features=poi_features,
|
||||||
|
tenant_id=tenant_id,
|
||||||
job_id=job_id
|
job_id=job_id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
logger.error(f"Training failed with error: {e}")
|
|
||||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Update progress
|
if product_data.empty:
|
||||||
await self._update_job_status_repository(
|
raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
|
||||||
job_id=job_id,
|
|
||||||
status="running",
|
|
||||||
progress=80,
|
|
||||||
current_step="Saving model",
|
|
||||||
tenant_id=tenant_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# The model should already be saved by train_single_product_model
|
logger.info("Prepared training data for single product",
|
||||||
# Return appropriate response
|
inventory_product_id=inventory_product_id,
|
||||||
return {
|
data_points=len(product_data),
|
||||||
"job_id": job_id,
|
features=list(product_data.columns),
|
||||||
"tenant_id": tenant_id,
|
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
|
||||||
"inventory_product_id": inventory_product_id,
|
|
||||||
"status": "completed",
|
|
||||||
"message": "Enhanced single product training completed successfully",
|
|
||||||
"created_at": datetime.now(timezone.utc),
|
|
||||||
"estimated_duration_minutes": 15, # Default estimate for single product
|
|
||||||
"training_results": {
|
|
||||||
"total_products": 1,
|
|
||||||
"successful_trainings": 1,
|
|
||||||
"failed_trainings": 0,
|
|
||||||
"products": [{
|
|
||||||
"inventory_product_id": inventory_product_id,
|
|
||||||
"status": "completed",
|
|
||||||
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
|
|
||||||
"data_points": len(product_data) if product_data is not None else 0,
|
|
||||||
# Filter metrics to ensure only numeric values are included
|
|
||||||
"metrics": {
|
|
||||||
k: float(v) if not isinstance(v, (int, float)) else v
|
|
||||||
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
|
|
||||||
if k != 'product_category' and v is not None
|
|
||||||
}
|
|
||||||
}],
|
|
||||||
"overall_training_time_seconds": model_info.get('training_time', 45.2)
|
|
||||||
},
|
|
||||||
"enhanced_features": True,
|
|
||||||
"repository_integration": True,
|
|
||||||
"completed_at": datetime.now(timezone.utc).isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
# Update progress (using shared session)
|
||||||
logger.error("Enhanced single product training failed",
|
await self._update_job_status_repository(
|
||||||
|
job_id=job_id,
|
||||||
|
status="running",
|
||||||
|
progress=30,
|
||||||
|
current_step="Training model",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session=session
|
||||||
|
)
|
||||||
|
await session.commit() # Commit progress update
|
||||||
|
|
||||||
|
# Run the actual training (passing the session to avoid nested session creation)
|
||||||
|
try:
|
||||||
|
model_info = await self.trainer.train_single_product_model(
|
||||||
|
tenant_id=tenant_id,
|
||||||
inventory_product_id=inventory_product_id,
|
inventory_product_id=inventory_product_id,
|
||||||
error=str(e))
|
training_data=product_data,
|
||||||
|
job_id=job_id,
|
||||||
|
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Training failed with error: {e}")
|
||||||
|
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
# Update status to failed
|
# Update progress (using shared session)
|
||||||
await self._update_job_status_repository(
|
await self._update_job_status_repository(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
status="failed",
|
status="running",
|
||||||
progress=0,
|
progress=80,
|
||||||
current_step="Training failed",
|
current_step="Saving model",
|
||||||
error_message=str(e),
|
tenant_id=tenant_id,
|
||||||
tenant_id=tenant_id
|
session=session
|
||||||
)
|
)
|
||||||
|
await session.commit() # Commit progress update
|
||||||
|
|
||||||
raise
|
# The model should already be saved by train_single_product_model
|
||||||
|
# Return appropriate response
|
||||||
|
return {
|
||||||
|
"job_id": job_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"inventory_product_id": inventory_product_id,
|
||||||
|
"status": "completed",
|
||||||
|
"message": "Enhanced single product training completed successfully",
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"estimated_duration_minutes": 15, # Default estimate for single product
|
||||||
|
"training_results": {
|
||||||
|
"total_products": 1,
|
||||||
|
"successful_trainings": 1,
|
||||||
|
"failed_trainings": 0,
|
||||||
|
"products": [{
|
||||||
|
"inventory_product_id": inventory_product_id,
|
||||||
|
"status": "completed",
|
||||||
|
"model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
|
||||||
|
"data_points": len(product_data) if product_data is not None else 0,
|
||||||
|
# Filter metrics to ensure only numeric values are included
|
||||||
|
"metrics": {
|
||||||
|
k: float(v) if not isinstance(v, (int, float)) else v
|
||||||
|
for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
|
||||||
|
if k != 'product_category' and v is not None
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"overall_training_time_seconds": model_info.get('training_time', 45.2)
|
||||||
|
},
|
||||||
|
"enhanced_features": True,
|
||||||
|
"repository_integration": True,
|
||||||
|
"completed_at": datetime.now(timezone.utc).isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Enhanced single product training failed",
|
||||||
|
inventory_product_id=inventory_product_id,
|
||||||
|
error=str(e))
|
||||||
|
|
||||||
|
# Update status to failed (using shared session)
|
||||||
|
await self._update_job_status_repository(
|
||||||
|
job_id=job_id,
|
||||||
|
status="failed",
|
||||||
|
progress=0,
|
||||||
|
current_step="Training failed",
|
||||||
|
error_message=str(e),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
session=session
|
||||||
|
)
|
||||||
|
await session.commit() # Commit failure status
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
|
def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Convert final result to detailed training response"""
|
"""Convert final result to detailed training response"""
|
||||||
@@ -494,6 +494,75 @@ class ExternalServiceClient(BaseServiceClient):
     # POI (POINT OF INTEREST) DATA
     # ================================================================
 
+    async def detect_poi_for_tenant(
+        self,
+        tenant_id: str,
+        latitude: float,
+        longitude: float,
+        force_refresh: bool = False
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Detect POIs for a tenant's location and generate ML features for forecasting.
+
+        With the new tenant-based architecture:
+        - Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context/detect
+        - Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context/detect
+        - This client calls: poi-context/detect (base client automatically constructs with tenant)
+
+        This triggers POI detection using Overpass API and calculates ML features
+        for demand forecasting.
+
+        Args:
+            tenant_id: Tenant ID
+            latitude: Latitude of the bakery location
+            longitude: Longitude of the bakery location
+            force_refresh: Whether to force refresh even if POI context exists
+
+        Returns:
+            Dict with POI detection results including:
+            - ml_features: Dict of POI features for ML models (e.g., poi_retail_total_count)
+            - poi_detection_results: Full detection results
+            - location: Latitude/longitude
+            - total_pois_detected: Count of POIs
+        """
+        logger.info(
+            "Detecting POIs for tenant",
+            tenant_id=tenant_id,
+            location=(latitude, longitude),
+            force_refresh=force_refresh
+        )
+
+        params = {
+            "latitude": latitude,
+            "longitude": longitude,
+            "force_refresh": force_refresh
+        }
+
+        # Updated endpoint path to follow tenant-based pattern: external/poi-context/detect
+        result = await self._make_request(
+            "POST",
+            "external/poi-context/detect",  # Path will become /api/v1/tenants/{tenant_id}/external/poi-context/detect by base client
+            tenant_id=tenant_id,  # Pass tenant_id to include in headers and path construction
+            params=params,
+            timeout=60.0  # POI detection can take longer
+        )
+
+        if result:
+            poi_context = result.get("poi_context", {})
+            ml_features = poi_context.get("ml_features", {})
+
+            logger.info(
+                "POI detection completed successfully",
+                tenant_id=tenant_id,
+                total_pois=poi_context.get("total_pois_detected", 0),
+                ml_features_count=len(ml_features),
+                source=result.get("source", "unknown")
+            )
+            return result
+        else:
+            logger.warning("POI detection failed for tenant", tenant_id=tenant_id)
+            return None
+
     async def get_poi_context(
         self,
         tenant_id: str
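A call-site sketch for the new method above (illustrative; the tenant ID is a placeholder and the coordinates are the default bakery location used elsewhere in this commit):

```python
# Illustrative call from the training service.
result = await external_client.detect_poi_for_tenant(
    tenant_id="tenant-123",  # placeholder
    latitude=40.4168,
    longitude=-3.7038,
    force_refresh=False,
)
ml_features = (result or {}).get("poi_context", {}).get("ml_features", {})
```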
@@ -504,7 +573,7 @@ class ExternalServiceClient(BaseServiceClient):
         With the new tenant-based architecture:
         - Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context
         - Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context
-        - This client calls: /tenants/{tenant_id}/poi-context
+        - This client calls: poi-context (base client automatically constructs with tenant)
 
         This retrieves stored POI detection results and calculated ML features
         that should be included in demand forecasting predictions.
@@ -521,11 +590,11 @@ class ExternalServiceClient(BaseServiceClient):
         """
         logger.info("Fetching POI context for forecasting", tenant_id=tenant_id)
 
-        # Updated endpoint path to follow tenant-based pattern: /tenants/{tenant_id}/poi-context
+        # Updated endpoint path to follow tenant-based pattern: external/poi-context
         result = await self._make_request(
             "GET",
-            f"tenants/{tenant_id}/poi-context",  # Updated path: /tenants/{tenant_id}/poi-context
+            "external/poi-context",  # Path will become /api/v1/tenants/{tenant_id}/external/poi-context by base client
-            tenant_id=tenant_id,  # Pass tenant_id to include in headers for authentication
+            tenant_id=tenant_id,  # Pass tenant_id to include in headers and path construction
             timeout=5.0
         )
 