Improve training code

This commit is contained in:
Urtzi Alfaro
2025-07-28 19:28:39 +02:00
parent 946015b80c
commit 98f546af12
15 changed files with 2534 additions and 2812 deletions

View File

@@ -2,7 +2,7 @@
Models API endpoints
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
import structlog
@@ -10,6 +10,7 @@ import structlog
from app.core.database import get_db
from app.schemas.training import TrainedModelResponse
from app.services.training_service import TrainingService
from datetime import datetime
from shared.auth.decorators import (
get_current_tenant_id_dep
@@ -20,17 +21,73 @@ router = APIRouter()
training_service = TrainingService()
@router.get("/tenants/{tenant_id}/", response_model=List[TrainedModelResponse])
async def get_trained_models(
tenant_id: str = Depends(get_current_tenant_id_dep),
@router.get("/tenants/{tenant_id}/models/{product_name}/active")
async def get_active_model(
tenant_id: str = Path(..., description="Tenant ID"),
product_name: str = Path(..., description="Product name"),
db: AsyncSession = Depends(get_db)
):
"""Get trained models"""
"""
Get the active model for a product - used by forecasting service
"""
try:
return await training_service.get_trained_models(tenant_id, db)
query = """
SELECT * FROM trained_models
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND is_active = true
AND is_production = true
ORDER BY created_at DESC
LIMIT 1
"""
result = await db.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
model_record = result.fetchone()
if not model_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"No active model found for product {product_name}"
)
# Update last_used_at
update_query = """
UPDATE trained_models
SET last_used_at = :now
WHERE id = :model_id
"""
await db.execute(update_query, {
"now": datetime.utcnow(),
"model_id": model_record.id
})
await db.commit()
return {
"model_id": model_record.id,
"model_path": model_record.model_path,
"features_used": model_record.features_used,
"hyperparameters": model_record.hyperparameters,
"training_metrics": {
"mape": model_record.mape,
"mae": model_record.mae,
"rmse": model_record.rmse,
"r2_score": model_record.r2_score
},
"created_at": model_record.created_at.isoformat(),
"training_period": {
"start_date": model_record.training_start_date.isoformat(),
"end_date": model_record.training_end_date.isoformat()
}
}
except Exception as e:
logger.error(f"Get trained models error: {e}")
logger.error(f"Failed to get active model: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get trained models"
detail="Failed to retrieve model"
)

View File

@@ -1,539 +1,209 @@
# ================================================================
# services/training/app/api/training.py - UPDATED WITH UNIFIED AUTH
# ================================================================
"""Training API endpoints with unified authentication"""
# services/training/app/api/training.py
"""
Training API Endpoints - Entry point for training requests
Handles HTTP requests and delegates to Training Service
"""
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query, Path
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi import Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional, Dict, Any
from datetime import datetime
import structlog
from uuid import UUID, uuid4
from datetime import datetime
from app.core.database import get_db
from app.services.training_service import TrainingService
from app.schemas.training import (
TrainingJobRequest,
TrainingJobResponse,
TrainingStatus,
SingleProductTrainingRequest,
TrainingJobProgress,
DataValidationRequest,
DataValidationResponse
SingleProductTrainingRequest
)
from app.services.training_service import TrainingService
from app.services.messaging import (
publish_job_started,
publish_job_completed,
publish_job_failed,
publish_job_progress,
publish_product_training_started,
publish_product_training_completed
from app.schemas.training import (
TrainingJobResponse
)
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_db_session
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
)
# Import shared auth decorators (assuming they exist in your microservices)
from shared.auth.decorators import get_current_tenant_id_dep
logger = structlog.get_logger()
router = APIRouter(tags=["training"])
router = APIRouter()
def get_training_service() -> TrainingService:
    """Build and return a new TrainingService; referenced via Depends() by the route handlers."""
    service = TrainingService()
    return service
# Initialize training service
training_service = TrainingService()
@router.post("/tenants/{tenant_id}/training/jobs", response_model=TrainingJobResponse)
async def start_training_job(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session) # Ensure db is available
tenant_id: str = Path(..., description="Tenant ID"),
background_tasks: BackgroundTasks = BackgroundTasks(),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Start a new training job for all products"""
"""
Start a new training job for all tenant products.
This is the main entry point for the training pipeline:
API → Training Service → Trainer → Data Processor → Prophet Manager
"""
try:
tenant_id_str = str(tenant_id)
new_job_id = str(uuid4())
logger.info("Starting training job",
tenant_id=tenant_id_str,
job_id=uuid4(),
config=request.dict())
# Create training job
job = await training_service.create_training_job(
db, # Pass db here
tenant_id=tenant_id_str,
job_id=new_job_id,
config=request.dict()
)
# Publish job started event
try:
await publish_job_started(
job_id=new_job_id,
tenant_id=tenant_id_str,
config=request.dict()
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
except Exception as e:
logger.warning("Failed to publish job started event", error=str(e))
background_tasks.add_task(
training_service.execute_training_job_simple,
new_job_id,
tenant_id_str,
request
)
logger.info("Training job created",
job_id=job.job_id,
tenant_id=tenant_id)
return TrainingJobResponse(
job_id=job.job_id,
status=TrainingStatus.PENDING,
message="Training job created successfully",
tenant_id=tenant_id_str,
created_at=job.created_at,
estimated_duration_minutes=30
)
except Exception as e:
logger.error("Failed to start training job",
error=str(e),
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to start training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs(
status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0),
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Get training jobs for tenant"""
try:
tenant_id_str = str(tenant_id)
logger.info(f"Starting training job for tenant {tenant_id}")
logger.debug("Getting training jobs",
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset)
training_service = TrainingService(db_session=db)
jobs = await training_service.get_training_jobs(
tenant_id=tenant_id_str,
status=status,
limit=limit,
offset=offset
# Delegate to training service (Step 1 of the flow)
result = await training_service.start_training_job(
tenant_id=tenant_id,
bakery_location=request.bakery_location or (40.4168, -3.7038), # Default Madrid
requested_start=request.start_date if request.start_date else None,
requested_end=request.end_date if request.end_date else None,
job_id=request.job_id
)
logger.debug("Retrieved training jobs",
count=len(jobs),
tenant_id=tenant_id_str)
return TrainingJobResponse(**result)
return jobs
except Exception as e:
logger.error("Failed to get training jobs",
error=str(e),
tenant_id=str(tenant_id))
raise HTTPException(status_code=500, detail=f"Failed to get training jobs: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(
job_id: str,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session)
):
"""Get specific training job details"""
try:
tenant_id_str = str(tenant_id)
logger.debug("Getting training job",
job_id=job_id,
tenant_id=tenant_id_str)
job_log = await training_service.get_job_status(db, job_id, tenant_id_str)
# Verify tenant access
if job_log.tenant_id != tenant_id:
logger.warning("Unauthorized job access attempt",
job_id=job_id,
tenant_id=str(tenant_id),
job_tenant_id=job.tenant_id)
raise HTTPException(status_code=404, detail="Job not found")
return TrainingJobResponse(
job_id=job_log.job_id,
status=TrainingStatus(job_log.status),
message=_generate_status_message(job_log.status, job_log.current_step),
tenant_id=str(job_log.tenant_id),
created_at=job_log.start_time,
estimated_duration_minutes=_estimate_duration(job_log.status, job_log.progress)
except ValueError as e:
logger.error(f"Training job validation error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except HTTPException:
raise
except Exception as e:
logger.error("Failed to get training job",
error=str(e),
job_id=job_id)
raise HTTPException(status_code=500, detail=f"Failed to get training job: {str(e)}")
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/progress", response_model=TrainingJobProgress)
async def get_training_progress(
    job_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """
    Return the progress record of one training job.

    The job is looked up first and its tenant_id is compared with the
    stringified path tenant; a mismatch is answered with 404 so that the
    existence of another tenant's job is not revealed. Any other failure
    is logged and reported as a 500.
    """
    try:
        requested_tenant = str(tenant_id)
        logger.debug(
            "Getting training progress",
            job_id=job_id,
            tenant_id=requested_tenant,
        )

        # Ownership check happens before any progress data is fetched.
        owning_job = await training_service.get_training_job(job_id)
        if owning_job.tenant_id != requested_tenant:
            raise HTTPException(status_code=404, detail="Job not found")

        return await training_service.get_job_progress(job_id)

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        reason = str(e)
        logger.error(
            "Failed to get training progress",
            error=reason,
            job_id=job_id,
        )
        raise HTTPException(status_code=500, detail=f"Failed to get training progress: {reason}")
@router.post("/tenants/{tenant_id}/jobs/{job_id}/cancel")
async def cancel_training_job(
    job_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """
    Cancel a running training job.

    Args:
        job_id: Identifier of the job to cancel.
        tenant_id: Tenant taken from the URL path (parsed as a UUID).
        current_user: Authenticated user; its "user_id" is logged.
        training_service: Service used to look up and cancel the job.

    Returns:
        A dict with a confirmation message and the job id.

    Raises:
        HTTPException: 404 when the job does not exist or belongs to another
            tenant; 500 for any unexpected failure.
    """
    try:
        # The path parameter is a UUID object; normalise it once so the
        # ownership check, log fields and published event all use the same
        # string form (as the other handlers in this module do).
        tenant_id_str = str(tenant_id)

        logger.info("Cancelling training job",
                    job_id=job_id,
                    tenant_id=tenant_id_str,
                    user_id=current_user["user_id"])

        job = await training_service.get_training_job(job_id)

        # Verify tenant access. Both sides are stringified so the comparison
        # is correct whether the stored tenant id is a str or a UUID.
        if job is None or str(job.tenant_id) != tenant_id_str:
            raise HTTPException(status_code=404, detail="Job not found")

        await training_service.cancel_training_job(job_id)

        # Publish cancellation event; a messaging failure is best-effort and
        # must not undo a cancellation that already succeeded.
        try:
            await publish_job_failed(
                job_id=job_id,
                tenant_id=tenant_id_str,
                error="Job cancelled by user",
                failed_at="cancellation"
            )
        except Exception as e:
            logger.warning("Failed to publish cancellation event", error=str(e))

        logger.info("Training job cancelled", job_id=job_id)
        return {"message": "Job cancelled successfully", "job_id": job_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to cancel training job",
                     error=str(e),
                     job_id=job_id)
        raise HTTPException(status_code=500, detail=f"Failed to cancel training job: {str(e)}")
logger.error(f"Training job failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Training job failed"
)
@router.post("/tenants/{tenant_id}/training/products/{product_name}", response_model=TrainingJobResponse)
async def train_single_product(
product_name: str,
async def start_single_product_training(
request: SingleProductTrainingRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service),
db: AsyncSession = Depends(get_db_session)
tenant_id: str = Path(..., description="Tenant ID"),
product_name: str = Path(..., description="Product name"),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Train model for a single product"""
"""
Start training for a single product.
Uses the same pipeline but filters for specific product.
"""
try:
logger.info("Training single product",
product_name=product_name,
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Create training job for single product
job = await training_service.create_single_product_job(
db,
logger.info(f"Starting single product training for {product_name} (tenant {tenant_id})")
# Delegate to training service
result = await training_service.start_single_product_training(
tenant_id=tenant_id,
product_name=product_name,
config=request.dict()
sales_data=request.sales_data,
bakery_location=request.bakery_location or (40.4168, -3.7038),
weather_data=request.weather_data,
traffic_data=request.traffic_data,
job_id=request.job_id
)
# Publish event
try:
await publish_product_training_started(
job_id=job.job_id,
tenant_id=tenant_id,
product_name=product_name
)
except Exception as e:
logger.warning("Failed to publish product training event", error=str(e))
return TrainingJobResponse(**result)
# Start training in background
background_tasks.add_task(
training_service.execute_single_product_training,
job.job_id,
product_name
except ValueError as e:
logger.error(f"Single product training validation error: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
logger.info("Single product training started",
job_id=job.job_id,
product_name=product_name)
return job
except Exception as e:
logger.error("Failed to train single product",
error=str(e),
product_name=product_name,
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to train product: {str(e)}")
@router.post("/tenants/{tenant_id}/training/validate", response_model=DataValidationResponse)
async def validate_training_data(
request: DataValidationRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Validate data before training"""
try:
logger.debug("Validating training data",
tenant_id=tenant_id,
products=request.products)
validation_result = await training_service.validate_training_data(
tenant_id=tenant_id,
products=request.products,
min_data_points=request.min_data_points
logger.error(f"Single product training failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Single product training failed"
)
logger.debug("Data validation completed",
is_valid=validation_result.is_valid,
tenant_id=tenant_id)
return validation_result
except Exception as e:
logger.error("Failed to validate training data",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
@router.get("/tenants/{tenant_id}/models")
async def get_trained_models(
product_name: Optional[str] = Query(None),
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
async def cancel_training_job(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""Get list of trained models"""
"""
Cancel a running training job.
"""
try:
logger.debug("Getting trained models",
tenant_id=tenant_id,
product_name=product_name)
models = await training_service.get_trained_models(
tenant_id=tenant_id,
product_name=product_name
)
logger.debug("Retrieved trained models",
count=len(models),
tenant_id=tenant_id)
return models
except Exception as e:
logger.error("Failed to get trained models",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to get models: {str(e)}")
@router.delete("/tenants/{tenant_id}/models/{model_id}")
@require_role("admin")  # Only admins can delete models
async def delete_model(
    model_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """
    Delete a trained model (admin only).

    Args:
        model_id: Identifier of the model to delete.
        tenant_id: Tenant taken from the URL path (parsed as a UUID).
        current_user: Authenticated user; its "user_id" is logged as admin_id.
        training_service: Service used to look up and delete the model.

    Returns:
        A dict with a confirmation message and the model id.

    Raises:
        HTTPException: 404 when the model does not exist, belongs to another
            tenant, or the service reports that nothing was deleted; 500 for
            any unexpected failure.
    """
    try:
        # Normalise the UUID path parameter once; used for the ownership
        # check and for logging.
        tenant_id_str = str(tenant_id)

        logger.info("Deleting model",
                    model_id=model_id,
                    tenant_id=tenant_id_str,
                    admin_id=current_user["user_id"])

        # Verify model belongs to tenant. Both sides are stringified so the
        # comparison is correct whether the stored id is a str or a UUID.
        model = await training_service.get_model(model_id)
        if model is None or str(model.tenant_id) != tenant_id_str:
            raise HTTPException(status_code=404, detail="Model not found")

        success = await training_service.delete_model(model_id)
        if not success:
            raise HTTPException(status_code=404, detail="Model not found")

        logger.info("Model deleted successfully", model_id=model_id)
        return {"message": "Model deleted successfully", "model_id": model_id}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete model",
                     error=str(e),
                     model_id=model_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
@router.get("/tenants/{tenant_id}/stats")
async def get_training_stats(
    start_date: Optional[datetime] = Query(None),
    end_date: Optional[datetime] = Query(None),
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    training_service: TrainingService = Depends(get_training_service)
):
    """
    Return training statistics for a tenant, optionally bounded by dates.

    The tenant id and the optional start/end dates are forwarded unchanged
    to the training service; whatever it returns is the response. Any
    failure is logged and reported as a 500.
    """
    # The same date window is used for logging and for the service call.
    window = {"start_date": start_date, "end_date": end_date}
    try:
        logger.debug("Getting training stats", tenant_id=tenant_id, **window)

        result = await training_service.get_training_stats(tenant_id=tenant_id, **window)

        logger.debug("Training stats retrieved", tenant_id=tenant_id)
        return result

    except Exception as e:
        reason = str(e)
        logger.error("Failed to get training stats", error=reason, tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get stats: {reason}")
@router.post("/tenants/{tenant_id}/retrain/all")
async def retrain_all_products(
request: TrainingJobRequest,
background_tasks: BackgroundTasks,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
training_service: TrainingService = Depends(get_training_service)
):
"""Retrain all products with existing models"""
try:
logger.info("Retraining all products",
tenant_id=tenant_id,
user_id=current_user["user_id"])
# Check if models exist
existing_models = await training_service.get_trained_models(tenant_id)
if not existing_models:
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=400,
detail="No existing models found. Please run initial training first."
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# Create retraining job
job = await training_service.create_training_job(
tenant_id=tenant_id,
user_id=current_user["user_id"],
config={**request.dict(), "is_retrain": True}
)
# TODO: Implement job cancellation
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
# Publish event
try:
await publish_job_started(
job_id=job.job_id,
tenant_id=tenant_id,
config={**request.dict(), "is_retrain": True}
)
except Exception as e:
logger.warning("Failed to publish retrain event", error=str(e))
return {"message": "Training job cancelled successfully"}
# Start retraining in background
background_tasks.add_task(
training_service.execute_training_job,
job.job_id
)
logger.info("Retraining job created", job_id=job.job_id)
return job
except HTTPException:
raise
except Exception as e:
logger.error("Failed to start retraining",
error=str(e),
tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Failed to start retraining: {str(e)}")
def _generate_status_message(status: str, current_step: str) -> str:
"""Generate appropriate status message"""
status_messages = {
"pending": "Training job is queued",
"running": f"Training in progress: {current_step}",
"completed": "Training completed successfully",
"failed": "Training failed",
"cancelled": "Training was cancelled"
}
return status_messages.get(status, f"Status: {status}")
logger.error(f"Failed to cancel training job: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel training job"
)
def _estimate_duration(status: str, progress: int) -> int:
"""Estimate remaining duration in minutes"""
if status == "completed":
return 0
elif status == "failed" or status == "cancelled":
return 0
elif status == "pending":
return 30 # Default estimate
else: # running
if progress > 0:
# Rough estimate based on progress
remaining_progress = 100 - progress
return max(1, int((remaining_progress / max(progress, 1)) * 10))
else:
return 25 # Default for running jobs
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    limit: int = Query(100, description="Number of log entries to return"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Get training job logs.

    Log retrieval is not implemented yet: the response is a fixed list of
    placeholder entries, and neither `limit` nor `db` is read.

    Args:
        tenant_id: Tenant taken from the URL path.
        job_id: Job whose logs are requested.
        limit: Requested number of log entries (currently not applied).
        current_tenant: Tenant resolved from the caller's credentials.
        db: Database session (currently unused).

    Returns:
        A dict with the job id and a list of log lines.

    Raises:
        HTTPException: 403 when the path tenant differs from the caller's
            tenant; 500 for any unexpected failure.
    """
    try:
        # Validate tenant access
        if tenant_id != current_tenant:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Access denied to tenant resources"
            )

        # TODO: Implement log retrieval
        return {
            "job_id": job_id,
            "logs": [
                f"Training job {job_id} started",
                "Data preprocessing completed",
                "Model training completed",
                "Training job finished successfully"
            ]
        }

    except HTTPException:
        # Without this passthrough the 403 above would be caught by the
        # generic handler below and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get training logs"
        )
@router.get("/health")
async def health_check():
    """
    Report liveness of the training service.

    Returns a static payload plus the current local time in ISO-8601 form.
    """
    checked_at = datetime.now().isoformat()
    return dict(
        status="healthy",
        service="training",
        version="1.0.0",
        timestamp=checked_at,
    )

View File

@@ -1,7 +1,7 @@
# services/training/app/ml/data_processor.py
"""
Data Processor for Training Service
Handles data preparation and feature engineering for ML training
Enhanced Data Processor for Training Service
Handles data preparation, date alignment, cleaning, and feature engineering for ML training
"""
import pandas as pd
@@ -12,17 +12,20 @@ import logging
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
logger = logging.getLogger(__name__)
class BakeryDataProcessor:
"""
Enhanced data processor for bakery forecasting training service.
Handles data cleaning, feature engineering, and preparation for ML models.
Integrates date alignment, data cleaning, feature engineering, and preparation for ML models.
"""
def __init__(self):
self.scalers = {} # Store scalers for each feature
self.imputers = {} # Store imputers for missing value handling
self.date_alignment_service = DateAlignmentService()
async def prepare_training_data(self,
sales_data: pd.DataFrame,
@@ -30,7 +33,7 @@ class BakeryDataProcessor:
traffic_data: pd.DataFrame,
product_name: str) -> pd.DataFrame:
"""
Prepare comprehensive training data for a specific product.
Prepare comprehensive training data for a specific product with date alignment.
Args:
sales_data: Historical sales data for the product
@@ -44,26 +47,29 @@ class BakeryDataProcessor:
try:
logger.info(f"Preparing training data for product: {product_name}")
# Convert and validate sales data
# Step 1: Convert and validate sales data
sales_clean = await self._process_sales_data(sales_data, product_name)
# Aggregate to daily level
# Step 2: Apply date alignment if we have date constraints
sales_clean = await self._apply_date_alignment(sales_clean, weather_data, traffic_data)
# Step 3: Aggregate to daily level
daily_sales = await self._aggregate_daily_sales(sales_clean)
# Add temporal features
# Step 4: Add temporal features
daily_sales = self._add_temporal_features(daily_sales)
# Merge external data sources
# Step 5: Merge external data sources
daily_sales = self._merge_weather_features(daily_sales, weather_data)
daily_sales = self._merge_traffic_features(daily_sales, traffic_data)
# Engineer additional features
# Step 6: Engineer additional features
daily_sales = self._engineer_features(daily_sales)
# Handle missing values
# Step 7: Handle missing values
daily_sales = self._handle_missing_values(daily_sales)
# Prepare for Prophet (rename columns and validate)
# Step 8: Prepare for Prophet (rename columns and validate)
prophet_data = self._prepare_prophet_format(daily_sales)
logger.info(f"Prepared {len(prophet_data)} data points for {product_name}")
@@ -78,7 +84,7 @@ class BakeryDataProcessor:
weather_forecast: pd.DataFrame = None,
traffic_forecast: pd.DataFrame = None) -> pd.DataFrame:
"""
Create features for future predictions.
Create features for future predictions with proper date handling.
Args:
future_dates: Future dates to predict
@@ -118,20 +124,7 @@ class BakeryDataProcessor:
future_df = future_df.rename(columns={'date': 'ds'})
# Handle missing values in future data
numeric_columns = future_df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if future_df[col].isna().any():
# Use reasonable defaults for Madrid
if col == 'temperature':
future_df[col] = future_df[col].fillna(15.0) # Default Madrid temp
elif col == 'precipitation':
future_df[col] = future_df[col].fillna(0.0) # Default no rain
elif col == 'humidity':
future_df[col] = future_df[col].fillna(60.0) # Default humidity
elif col == 'traffic_volume':
future_df[col] = future_df[col].fillna(100.0) # Default traffic
else:
future_df[col] = future_df[col].fillna(future_df[col].median())
future_df = self._handle_missing_values_future(future_df)
return future_df
@@ -140,8 +133,48 @@ class BakeryDataProcessor:
# Return minimal features if error
return pd.DataFrame({'ds': future_dates})
async def _apply_date_alignment(self,
sales_data: pd.DataFrame,
weather_data: pd.DataFrame,
traffic_data: pd.DataFrame) -> pd.DataFrame:
"""
Apply date alignment constraints to ensure data consistency across sources.
"""
try:
if sales_data.empty:
return sales_data
# Create date range from sales data
sales_dates = pd.to_datetime(sales_data['date'])
sales_date_range = DateRange(
start=sales_dates.min(),
end=sales_dates.max(),
source=DataSourceType.BAKERY_SALES
)
# Get aligned date range considering all constraints
aligned_range = self.date_alignment_service.validate_and_align_dates(
user_sales_range=sales_date_range
)
# Filter sales data to aligned range
mask = (sales_dates >= aligned_range.start) & (sales_dates <= aligned_range.end)
filtered_sales = sales_data[mask].copy()
logger.info(f"Date alignment: {len(sales_data)}{len(filtered_sales)} records")
logger.info(f"Aligned date range: {aligned_range.start.date()} to {aligned_range.end.date()}")
if aligned_range.constraints:
logger.info(f"Applied constraints: {aligned_range.constraints}")
return filtered_sales
except Exception as e:
logger.warning(f"Date alignment failed, using original data: {str(e)}")
return sales_data
async def _process_sales_data(self, sales_data: pd.DataFrame, product_name: str) -> pd.DataFrame:
"""Process and clean sales data"""
"""Process and clean sales data with enhanced validation"""
sales_clean = sales_data.copy()
# Ensure date column exists and is datetime
@@ -150,9 +183,22 @@ class BakeryDataProcessor:
sales_clean['date'] = pd.to_datetime(sales_clean['date'])
# Ensure quantity column exists and is numeric
if 'quantity' not in sales_clean.columns:
raise ValueError("Sales data must have a 'quantity' column")
# Handle different quantity column names
quantity_columns = ['quantity', 'quantity_sold', 'sales', 'units_sold']
quantity_col = None
for col in quantity_columns:
if col in sales_clean.columns:
quantity_col = col
break
if quantity_col is None:
raise ValueError(f"Sales data must have one of these columns: {quantity_columns}")
# Standardize to 'quantity'
if quantity_col != 'quantity':
sales_clean['quantity'] = sales_clean[quantity_col]
logger.info(f"Mapped '{quantity_col}' to 'quantity' column")
sales_clean['quantity'] = pd.to_numeric(sales_clean['quantity'], errors='coerce')
@@ -164,15 +210,23 @@ class BakeryDataProcessor:
if 'product_name' in sales_clean.columns:
sales_clean = sales_clean[sales_clean['product_name'] == product_name]
# Remove duplicate dates (keep the one with highest quantity)
sales_clean = sales_clean.sort_values(['date', 'quantity'], ascending=[True, False])
sales_clean = sales_clean.drop_duplicates(subset=['date'], keep='first')
return sales_clean
async def _aggregate_daily_sales(self, sales_data: pd.DataFrame) -> pd.DataFrame:
"""Aggregate sales to daily level"""
"""Aggregate sales to daily level with improved date handling"""
if sales_data.empty:
return pd.DataFrame(columns=['date', 'quantity'])
# Group by date and sum quantities
daily_sales = sales_data.groupby('date').agg({
'quantity': 'sum'
}).reset_index()
# Ensure we have data for all dates in the range
# Ensure we have data for all dates in the range (fill gaps with 0)
date_range = pd.date_range(
start=daily_sales['date'].min(),
end=daily_sales['date'].max(),
@@ -186,7 +240,7 @@ class BakeryDataProcessor:
return daily_sales
def _add_temporal_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add temporal features like day of week, month, etc."""
"""Add comprehensive temporal features for bakery demand patterns"""
df = df.copy()
# Ensure we have a date column
@@ -195,37 +249,43 @@ class BakeryDataProcessor:
df['date'] = pd.to_datetime(df['date'])
# Day of week (0=Monday, 6=Sunday)
df['day_of_week'] = df['date'].dt.dayofweek
df['is_weekend'] = df['day_of_week'].isin([5, 6]).astype(int)
# Month and season
# Basic temporal features
df['day_of_week'] = df['date'].dt.dayofweek # 0=Monday, 6=Sunday
df['day_of_month'] = df['date'].dt.day
df['month'] = df['date'].dt.month
df['season'] = df['month'].apply(self._get_season)
# Week of year
df['quarter'] = df['date'].dt.quarter
df['week_of_year'] = df['date'].dt.isocalendar().week
# Quarter
df['quarter'] = df['date'].dt.quarter
# Bakery-specific features
df['is_weekend'] = df['day_of_week'].isin([5, 6]).astype(int)
df['is_monday'] = (df['day_of_week'] == 0).astype(int) # Monday often has different patterns
df['is_friday'] = (df['day_of_week'] == 4).astype(int) # Friday often busy
# Holiday indicators (basic Spanish holidays)
# Season mapping for Madrid
df['season'] = df['month'].apply(self._get_season)
df['is_summer'] = (df['season'] == 3).astype(int) # Summer seasonality
df['is_winter'] = (df['season'] == 1).astype(int) # Winter seasonality
# Holiday and special day indicators
df['is_holiday'] = df['date'].apply(self._is_spanish_holiday).astype(int)
# School calendar effects (approximate)
df['is_school_holiday'] = df['date'].apply(self._is_school_holiday).astype(int)
df['is_month_start'] = (df['day_of_month'] <= 3).astype(int)
df['is_month_end'] = (df['day_of_month'] >= 28).astype(int)
# Payday patterns (common in Spain: end/beginning of month)
df['is_payday_period'] = ((df['day_of_month'] <= 5) | (df['day_of_month'] >= 25)).astype(int)
return df
def _merge_weather_features(self,
daily_sales: pd.DataFrame,
weather_data: pd.DataFrame) -> pd.DataFrame:
"""Merge weather features with sales data"""
"""Merge weather features with enhanced handling"""
if weather_data.empty:
# Add default weather columns with neutral values
daily_sales['temperature'] = 15.0 # Mild temperature
daily_sales['precipitation'] = 0.0 # No rain
# Add default weather columns with Madrid-appropriate values
daily_sales['temperature'] = 15.0 # Average Madrid temperature
daily_sales['precipitation'] = 0.0 # Default no rain
daily_sales['humidity'] = 60.0 # Moderate humidity
daily_sales['wind_speed'] = 5.0 # Light wind
return daily_sales
@@ -233,27 +293,27 @@ class BakeryDataProcessor:
try:
weather_clean = weather_data.copy()
# Ensure weather data has date column
# Standardize date column
if 'date' not in weather_clean.columns and 'ds' in weather_clean.columns:
weather_clean = weather_clean.rename(columns={'ds': 'date'})
weather_clean['date'] = pd.to_datetime(weather_clean['date'])
# Select relevant weather features
weather_features = ['date']
# Add available weather columns with default names
# Map weather columns to standard names
weather_mapping = {
'temperature': ['temperature', 'temp', 'temperatura'],
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion'],
'humidity': ['humidity', 'humedad'],
'wind_speed': ['wind_speed', 'viento', 'wind']
'temperature': ['temperature', 'temp', 'temperatura', 'temp_avg', 'temperature_avg'],
'precipitation': ['precipitation', 'rain', 'lluvia', 'precipitacion', 'rainfall'],
'humidity': ['humidity', 'humedad', 'relative_humidity'],
'wind_speed': ['wind_speed', 'viento', 'wind', 'wind_avg'],
'pressure': ['pressure', 'presion', 'atmospheric_pressure']
}
weather_features = ['date']
for standard_name, possible_names in weather_mapping.items():
for possible_name in possible_names:
if possible_name in weather_clean.columns:
weather_clean[standard_name] = weather_clean[possible_name]
weather_clean[standard_name] = pd.to_numeric(weather_clean[possible_name], errors='coerce')
weather_features.append(standard_name)
break
@@ -263,31 +323,32 @@ class BakeryDataProcessor:
# Merge with sales data
merged = daily_sales.merge(weather_clean, on='date', how='left')
# Fill missing weather values with reasonable defaults
if 'temperature' in merged.columns:
merged['temperature'] = merged['temperature'].fillna(15.0)
if 'precipitation' in merged.columns:
merged['precipitation'] = merged['precipitation'].fillna(0.0)
if 'humidity' in merged.columns:
merged['humidity'] = merged['humidity'].fillna(60.0)
if 'wind_speed' in merged.columns:
merged['wind_speed'] = merged['wind_speed'].fillna(5.0)
# Fill missing weather values with Madrid-appropriate defaults
weather_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'pressure': 1013.0
}
for feature, default_value in weather_defaults.items():
if feature in merged.columns:
merged[feature] = merged[feature].fillna(default_value)
return merged
except Exception as e:
logger.warning(f"Error merging weather data: {e}")
# Add default weather columns if merge fails
daily_sales['temperature'] = 15.0
daily_sales['precipitation'] = 0.0
daily_sales['humidity'] = 60.0
daily_sales['wind_speed'] = 5.0
for feature, default_value in weather_defaults.items():
daily_sales[feature] = default_value
return daily_sales
def _merge_traffic_features(self,
daily_sales: pd.DataFrame,
traffic_data: pd.DataFrame) -> pd.DataFrame:
"""Merge traffic features with sales data"""
"""Merge traffic features with enhanced Madrid-specific handling"""
if traffic_data.empty:
# Add default traffic column
@@ -297,26 +358,26 @@ class BakeryDataProcessor:
try:
traffic_clean = traffic_data.copy()
# Ensure traffic data has date column
# Standardize date column
if 'date' not in traffic_clean.columns and 'ds' in traffic_clean.columns:
traffic_clean = traffic_clean.rename(columns={'ds': 'date'})
traffic_clean['date'] = pd.to_datetime(traffic_clean['date'])
# Select relevant traffic features
traffic_features = ['date']
# Map traffic column names
# Map traffic columns to standard names
traffic_mapping = {
'traffic_volume': ['traffic_volume', 'traffic_intensity', 'trafico', 'intensidad'],
'pedestrian_count': ['pedestrian_count', 'peatones'],
'occupancy_rate': ['occupancy_rate', 'ocupacion']
'traffic_volume': ['traffic_volume', 'traffic_intensity', 'trafico', 'intensidad', 'volume'],
'pedestrian_count': ['pedestrian_count', 'peatones', 'pedestrians'],
'congestion_level': ['congestion_level', 'congestion', 'nivel_congestion'],
'average_speed': ['average_speed', 'speed', 'velocidad_media', 'avg_speed']
}
traffic_features = ['date']
for standard_name, possible_names in traffic_mapping.items():
for possible_name in possible_names:
if possible_name in traffic_clean.columns:
traffic_clean[standard_name] = traffic_clean[possible_name]
traffic_clean[standard_name] = pd.to_numeric(traffic_clean[possible_name], errors='coerce')
traffic_features.append(standard_name)
break
@@ -326,13 +387,17 @@ class BakeryDataProcessor:
# Merge with sales data
merged = daily_sales.merge(traffic_clean, on='date', how='left')
# Fill missing traffic values
if 'traffic_volume' in merged.columns:
merged['traffic_volume'] = merged['traffic_volume'].fillna(100.0)
if 'pedestrian_count' in merged.columns:
merged['pedestrian_count'] = merged['pedestrian_count'].fillna(50.0)
if 'occupancy_rate' in merged.columns:
merged['occupancy_rate'] = merged['occupancy_rate'].fillna(0.5)
# Fill missing traffic values with reasonable defaults
traffic_defaults = {
'traffic_volume': 100.0,
'pedestrian_count': 50.0,
'congestion_level': 1.0, # Low congestion
'average_speed': 30.0 # km/h typical for Madrid
}
for feature, default_value in traffic_defaults.items():
if feature in merged.columns:
merged[feature] = merged[feature].fillna(default_value)
return merged
@@ -343,49 +408,150 @@ class BakeryDataProcessor:
return daily_sales
def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
"""Engineer additional features from existing data"""
"""Engineer additional features from existing data with bakery-specific insights"""
df = df.copy()
# Weather-based features
if 'temperature' in df.columns:
df['temp_squared'] = df['temperature'] ** 2
df['is_hot_day'] = (df['temperature'] > 25).astype(int)
df['is_cold_day'] = (df['temperature'] < 10).astype(int)
df['is_hot_day'] = (df['temperature'] > 25).astype(int) # Hot days in Madrid
df['is_cold_day'] = (df['temperature'] < 10).astype(int) # Cold days
df['is_pleasant_day'] = ((df['temperature'] >= 18) & (df['temperature'] <= 25)).astype(int)
# Temperature categories for bakery products
df['temp_category'] = pd.cut(df['temperature'],
bins=[-np.inf, 5, 15, 25, np.inf],
labels=[0, 1, 2, 3]).astype(int)
if 'precipitation' in df.columns:
df['is_rainy_day'] = (df['precipitation'] > 0).astype(int)
df['heavy_rain'] = (df['precipitation'] > 10).astype(int)
df['is_rainy_day'] = (df['precipitation'] > 0.1).astype(int)
df['is_heavy_rain'] = (df['precipitation'] > 10).astype(int)
df['rain_intensity'] = pd.cut(df['precipitation'],
bins=[-0.1, 0, 2, 10, np.inf],
labels=[0, 1, 2, 3]).astype(int)
# Traffic-based features
if 'traffic_volume' in df.columns:
df['high_traffic'] = (df['traffic_volume'] > df['traffic_volume'].quantile(0.75)).astype(int)
df['low_traffic'] = (df['traffic_volume'] < df['traffic_volume'].quantile(0.25)).astype(int)
# Calculate traffic quantiles for relative measures
q75 = df['traffic_volume'].quantile(0.75)
q25 = df['traffic_volume'].quantile(0.25)
df['high_traffic'] = (df['traffic_volume'] > q75).astype(int)
df['low_traffic'] = (df['traffic_volume'] < q25).astype(int)
df['traffic_normalized'] = (df['traffic_volume'] - df['traffic_volume'].mean()) / df['traffic_volume'].std()
# Interaction features
# Interaction features - bakery specific
if 'is_weekend' in df.columns and 'temperature' in df.columns:
df['weekend_temp_interaction'] = df['is_weekend'] * df['temperature']
df['weekend_pleasant_weather'] = df['is_weekend'] * df.get('is_pleasant_day', 0)
if 'is_rainy_day' in df.columns and 'traffic_volume' in df.columns:
df['rain_traffic_interaction'] = df['is_rainy_day'] * df['traffic_volume']
if 'is_holiday' in df.columns and 'temperature' in df.columns:
df['holiday_temp_interaction'] = df['is_holiday'] * df['temperature']
# Seasonal interactions
if 'season' in df.columns and 'temperature' in df.columns:
df['season_temp_interaction'] = df['season'] * df['temperature']
# Day-of-week specific features
if 'day_of_week' in df.columns:
# Working days vs weekends
df['is_working_day'] = (~df['day_of_week'].isin([5, 6])).astype(int)
# Peak bakery days (Friday, Saturday, Sunday often busy)
df['is_peak_bakery_day'] = df['day_of_week'].isin([4, 5, 6]).astype(int)
# Month-specific features for bakery seasonality
if 'month' in df.columns:
# Tourist season in Madrid (spring/summer)
df['is_tourist_season'] = df['month'].isin([4, 5, 6, 7, 8, 9]).astype(int)
# Christmas season (affects bakery sales significantly)
df['is_christmas_season'] = df['month'].isin([11, 12]).astype(int)
# Back-to-school/work season
df['is_back_to_work_season'] = df['month'].isin([1, 9]).astype(int)
# Lagged features (if we have enough data)
if len(df) > 7 and 'quantity' in df.columns:
# Rolling averages for trend detection
df['sales_7day_avg'] = df['quantity'].rolling(window=7, min_periods=3).mean()
df['sales_14day_avg'] = df['quantity'].rolling(window=14, min_periods=7).mean()
# Day-over-day changes
df['sales_change_1day'] = df['quantity'].diff()
df['sales_change_7day'] = df['quantity'].diff(7) # Week-over-week
# Fill NaN values for lagged features
df['sales_7day_avg'] = df['sales_7day_avg'].fillna(df['quantity'])
df['sales_14day_avg'] = df['sales_14day_avg'].fillna(df['quantity'])
df['sales_change_1day'] = df['sales_change_1day'].fillna(0)
df['sales_change_7day'] = df['sales_change_7day'].fillna(0)
return df
def _handle_missing_values(self, df: pd.DataFrame) -> pd.DataFrame:
"""Handle missing values in the dataset"""
"""Handle missing values in the dataset with improved strategies"""
df = df.copy()
# For numeric columns, use median imputation
# For numeric columns, use appropriate imputation strategies
numeric_columns = df.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if col != 'quantity' and df[col].isna().any():
median_value = df[col].median()
df[col] = df[col].fillna(median_value)
# Use different strategies based on column type
if 'temperature' in col:
df[col] = df[col].fillna(15.0) # Madrid average
elif 'precipitation' in col or 'rain' in col:
df[col] = df[col].fillna(0.0) # Default no rain
elif 'humidity' in col:
df[col] = df[col].fillna(60.0) # Moderate humidity
elif 'traffic' in col:
df[col] = df[col].fillna(df[col].median()) # Use median for traffic
elif 'wind' in col:
df[col] = df[col].fillna(5.0) # Light wind
elif 'pressure' in col:
df[col] = df[col].fillna(1013.0) # Standard atmospheric pressure
else:
# For other columns, use median or forward fill
if df[col].count() > 0:
df[col] = df[col].fillna(df[col].median())
else:
df[col] = df[col].fillna(0)
return df
def _handle_missing_values_future(self, df: pd.DataFrame) -> pd.DataFrame:
"""Handle missing values in future prediction data"""
numeric_columns = df.select_dtypes(include=[np.number]).columns
madrid_defaults = {
'temperature': 15.0,
'precipitation': 0.0,
'humidity': 60.0,
'wind_speed': 5.0,
'traffic_volume': 100.0,
'pedestrian_count': 50.0,
'pressure': 1013.0
}
for col in numeric_columns:
if df[col].isna().any():
# Find appropriate default value
default_value = 0
for key, value in madrid_defaults.items():
if key in col.lower():
default_value = value
break
df[col] = df[col].fillna(default_value)
return df
def _prepare_prophet_format(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data in Prophet format with 'ds' and 'y' columns"""
"""Prepare data in Prophet format with enhanced validation"""
prophet_df = df.copy()
# Rename columns for Prophet
@@ -395,20 +561,33 @@ class BakeryDataProcessor:
if 'quantity' in prophet_df.columns:
prophet_df = prophet_df.rename(columns={'quantity': 'y'})
# Ensure ds is datetime
# Ensure ds is datetime and remove timezone info
if 'ds' in prophet_df.columns:
prophet_df['ds'] = pd.to_datetime(prophet_df['ds'])
if prophet_df['ds'].dt.tz is not None:
prophet_df['ds'] = prophet_df['ds'].dt.tz_localize(None)
# Validate required columns
if 'ds' not in prophet_df.columns or 'y' not in prophet_df.columns:
raise ValueError("Prophet data must have 'ds' and 'y' columns")
# Remove any rows with missing target values
# Clean target values
prophet_df = prophet_df.dropna(subset=['y'])
prophet_df['y'] = prophet_df['y'].clip(lower=0) # No negative sales
# Remove any duplicate dates (keep last occurrence)
prophet_df = prophet_df.drop_duplicates(subset=['ds'], keep='last')
# Sort by date
prophet_df = prophet_df.sort_values('ds').reset_index(drop=True)
# Final validation
if len(prophet_df) == 0:
raise ValueError("No valid data points after cleaning")
logger.info(f"Prophet data prepared: {len(prophet_df)} rows, "
f"date range: {prophet_df['ds'].min()} to {prophet_df['ds'].max()}")
return prophet_df
def _get_season(self, month: int) -> int:
@@ -429,7 +608,7 @@ class BakeryDataProcessor:
# Major Spanish holidays that affect bakery sales
spanish_holidays = [
(1, 1), # New Year
(1, 6), # Epiphany
(1, 6), # Epiphany (Reyes)
(5, 1), # Labour Day
(8, 15), # Assumption
(10, 12), # National Day
@@ -437,7 +616,7 @@ class BakeryDataProcessor:
(12, 6), # Constitution
(12, 8), # Immaculate Conception
(12, 25), # Christmas
(5, 15), # San Isidro (Madrid)
(5, 15), # San Isidro (Madrid patron saint)
(5, 2), # Madrid Community Day
]
@@ -458,8 +637,8 @@ class BakeryDataProcessor:
if month == 1 and date.day <= 10:
return True
# Easter holidays (approximate - first two weeks of April)
if month == 4 and date.day <= 14:
# Easter holidays (approximate - early April)
if month == 4 and date.day <= 15:
return True
return False
@@ -468,26 +647,89 @@ class BakeryDataProcessor:
model_data: pd.DataFrame,
target_column: str = 'y') -> Dict[str, float]:
"""
Calculate feature importance for the model.
Calculate feature importance for the model using correlation analysis.
"""
try:
# Simple correlation-based importance
# Get numeric features
numeric_features = model_data.select_dtypes(include=[np.number]).columns
numeric_features = [col for col in numeric_features if col != target_column]
importance_scores = {}
if target_column not in model_data.columns:
logger.warning(f"Target column '{target_column}' not found")
return {}
for feature in numeric_features:
if feature in model_data.columns:
correlation = model_data[feature].corr(model_data[target_column])
importance_scores[feature] = abs(correlation) if not pd.isna(correlation) else 0.0
if not pd.isna(correlation) and not np.isinf(correlation):
importance_scores[feature] = abs(correlation)
# Sort by importance
importance_scores = dict(sorted(importance_scores.items(),
key=lambda x: x[1], reverse=True))
logger.info(f"Calculated feature importance for {len(importance_scores)} features")
return importance_scores
except Exception as e:
logger.error(f"Error calculating feature importance: {e}")
return {}
return {}
def get_data_quality_report(self, df: pd.DataFrame) -> Dict[str, Any]:
    """
    Generate a data quality report for a prepared dataset.

    The report only contains built-in Python types (``int``, ``float``,
    ``str``, ``dict``, ``None``) so it can be serialised directly. An empty
    frame produces a report with zeroed figures instead of an error.

    Args:
        df: Dataset to describe. ``ds`` (dates) and ``y`` (target) are
            reported on when present.

    Returns:
        Dict with ``total_records``, ``date_range``, per-column
        ``missing_values``, overall ``data_completeness`` (percent),
        ``target_statistics`` and ``feature_count``. If anything goes wrong
        the dict is ``{"error": <message>}`` instead.
    """
    try:
        total_records = len(df)

        # Date coverage (min/max are only meaningful with at least one row)
        if 'ds' in df.columns and total_records > 0:
            first_date = df['ds'].min()
            last_date = df['ds'].max()
            date_range = {
                "start": first_date.isoformat(),
                "end": last_date.isoformat(),
                "duration_days": (last_date - first_date).days
            }
        else:
            date_range = {"start": None, "end": None, "duration_days": 0}

        # Missing values per column, as a share of the row count
        missing_counts = df.isnull().sum()
        missing_values = {}
        for col in df.columns:
            missing_count = int(missing_counts[col])
            missing_share = (missing_count / total_records) if total_records else 0.0
            missing_values[col] = {
                "count": missing_count,
                "percentage": round(missing_share * 100, 2)
            }

        # Overall completeness across every cell of the frame
        total_missing = int(missing_counts.sum())
        total_possible = total_records * len(df.columns)
        if total_possible:
            data_completeness = round(
                ((total_possible - total_missing) / total_possible) * 100, 2
            )
        else:
            data_completeness = 0.0

        # Target variable statistics; float() strips NumPy scalar types,
        # which round() would otherwise preserve for integer columns.
        target_statistics = {}
        if 'y' in df.columns and total_records > 0:
            y_col = df['y']
            zero_count = int((y_col == 0).sum())
            target_statistics = {
                "mean": round(float(y_col.mean()), 2),
                "median": round(float(y_col.median()), 2),
                "std": round(float(y_col.std()), 2),
                "min": round(float(y_col.min()), 2),
                "max": round(float(y_col.max()), 2),
                "zero_count": zero_count,
                "zero_percentage": round((zero_count / total_records) * 100, 2)
            }

        # Feature count: every numeric column except the target
        numeric_features = df.select_dtypes(include=[np.number]).columns
        feature_count = len([col for col in numeric_features if col not in ['y', 'ds']])

        return {
            "total_records": total_records,
            "date_range": date_range,
            "missing_values": missing_values,
            "data_completeness": data_completeness,
            "target_statistics": target_statistics,
            "feature_count": feature_count
        }

    except Exception as e:
        logger.error(f"Error generating data quality report: {e}")
        return {"error": str(e)}

View File

@@ -1,24 +1,33 @@
# services/training/app/ml/prophet_manager.py
"""
Enhanced Prophet Manager for Training Service
Migrated from the monolithic backend to microservices architecture
Simplified Prophet Manager with Built-in Hyperparameter Optimization
Direct replacement for existing BakeryProphetManager - optimization always enabled.
"""
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
import numpy as np
from prophet import Prophet
import pickle
import logging
from datetime import datetime, timedelta
import uuid
import asyncio
import os
import joblib
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import TimeSeriesSplit
import json
from pathlib import Path
import math
import warnings
warnings.filterwarnings('ignore')
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.training import TrainedModel
from app.core.database import get_db_session
# Simple optimization import
import optuna
optuna.logging.set_verbosity(optuna.logging.WARNING)
from app.core.config import settings
@@ -26,15 +35,15 @@ logger = logging.getLogger(__name__)
class BakeryProphetManager:
"""
Enhanced Prophet model manager for the training service.
Handles training, validation, and model persistence for bakery forecasting.
Simplified Prophet Manager with built-in hyperparameter optimization.
Drop-in replacement for the existing manager - optimization runs automatically.
"""
def __init__(self):
def __init__(self, db_session: AsyncSession = None):
self.models = {} # In-memory model storage
self.model_metadata = {} # Store model metadata
self.feature_scalers = {} # Store feature scalers per model
self.db_session = db_session # Add database session
# Ensure model storage directory exists
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
@@ -44,19 +53,11 @@ class BakeryProphetManager:
df: pd.DataFrame,
job_id: str) -> Dict[str, Any]:
"""
Train a Prophet model for bakery forecasting with enhanced features.
Args:
tenant_id: Tenant identifier
product_name: Product name
df: Training data with 'ds' and 'y' columns plus regressors
job_id: Training job identifier
Returns:
Dictionary with model information and metrics
Train a Prophet model with automatic hyperparameter optimization.
Same interface as before - optimization happens automatically.
"""
try:
logger.info(f"Training bakery model for tenant {tenant_id}, product {product_name}")
logger.info(f"Training optimized bakery model for {product_name}")
# Validate input data
await self._validate_training_data(df, product_name)
@@ -67,8 +68,12 @@ class BakeryProphetManager:
# Get regressor columns
regressor_columns = self._extract_regressor_columns(prophet_data)
# Initialize Prophet model with bakery-specific settings
model = self._create_prophet_model(regressor_columns)
# Automatically optimize hyperparameters (this is the new part)
logger.info(f"Optimizing hyperparameters for {product_name}...")
best_params = await self._optimize_hyperparameters(prophet_data, product_name, regressor_columns)
# Create optimized Prophet model
model = self._create_optimized_prophet_model(best_params, regressor_columns)
# Add regressors to model
for regressor in regressor_columns:
@@ -78,28 +83,23 @@ class BakeryProphetManager:
# Fit the model
model.fit(prophet_data)
# Generate model ID and store model
# Store model and calculate metrics (same as before)
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params
)
# Calculate training metrics
training_metrics = await self._calculate_training_metrics(model, prophet_data)
# Calculate enhanced training metrics
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
# Prepare model information
# Return same format as before, but with optimization info
model_info = {
"model_id": model_id,
"model_path": model_path,
"type": "prophet",
"type": "prophet_optimized", # Changed from "prophet"
"training_samples": len(prophet_data),
"features": regressor_columns,
"hyperparameters": {
"seasonality_mode": settings.PROPHET_SEASONALITY_MODE,
"daily_seasonality": settings.PROPHET_DAILY_SEASONALITY,
"weekly_seasonality": settings.PROPHET_WEEKLY_SEASONALITY,
"yearly_seasonality": settings.PROPHET_YEARLY_SEASONALITY
},
"hyperparameters": best_params, # Now contains optimized params
"training_metrics": training_metrics,
"trained_at": datetime.now().isoformat(),
"data_period": {
@@ -109,41 +109,491 @@ class BakeryProphetManager:
}
}
logger.info(f"Model trained successfully for {product_name}")
logger.info(f"Optimized model trained successfully for {product_name}. "
f"MAPE: {training_metrics.get('optimized_mape', 'N/A')}%")
return model_info
except Exception as e:
logger.error(f"Failed to train bakery model for {product_name}: {str(e)}")
logger.error(f"Failed to train optimized bakery model for {product_name}: {str(e)}")
raise
async def _optimize_hyperparameters(self,
                                    df: pd.DataFrame,
                                    product_name: str,
                                    regressor_columns: List[str]) -> Dict[str, Any]:
    """
    Choose Prophet hyperparameters for one product.

    Sparse series skip the search and get fixed conservative settings.
    Otherwise an Optuna TPE study is run; every trial is scored with a
    2-fold time-series cross-validation and the error (MAPE-like, in
    percent) on the held-out folds is minimised.

    Args:
        df: Prophet-formatted training data (``ds``, ``y`` and regressor
            columns), ordered by date.
        product_name: Product the model is trained for; used for
            classification, logging and the sampler seed.
        regressor_columns: Names of the extra regressor columns in ``df``.

    Returns:
        Dict of Prophet keyword arguments. It always contains
        ``seasonality_mode`` and ``weekly_seasonality`` in addition to the
        searched parameters.

    NOTE(review): ``study.optimize`` is a synchronous call (up to the
    600 s timeout) inside a coroutine - confirm whether it should be moved
    off the event loop.
    """
    import hashlib  # local import: only needed for the reproducible seed

    # Determine product category automatically
    product_category = self._classify_product(product_name, df)

    # Number of trials per category
    n_trials = {
        'high_volume': 30,
        'medium_volume': 25,
        'low_volume': 20,
        'intermittent': 15
    }.get(product_category, 25)

    logger.info(f"Product {product_name} classified as {product_category}, using {n_trials} trials")

    # Check data quality and adjust strategy
    total_sales = df['y'].sum()
    zero_ratio = (df['y'] == 0).sum() / len(df)
    mean_sales = df['y'].mean()
    non_zero_days = len(df[df['y'] > 0])

    logger.info(f"Data analysis for {product_name}: total_sales={total_sales:.1f}, "
                f"zero_ratio={zero_ratio:.2f}, mean_sales={mean_sales:.2f}, non_zero_days={non_zero_days}")

    # Too little signal to tune on: return fixed, strongly regularised settings
    if zero_ratio > 0.8 or non_zero_days < 30:
        logger.warning(f"Very sparse data for {product_name}, using minimal optimization")
        return {
            'changepoint_prior_scale': 0.001,
            'seasonality_prior_scale': 0.01,
            'holidays_prior_scale': 0.01,
            'changepoint_range': 0.8,
            'seasonality_mode': 'additive',
            'daily_seasonality': False,
            'weekly_seasonality': True,
            'yearly_seasonality': False
        }
    elif zero_ratio > 0.6:
        logger.info(f"Moderate sparsity for {product_name}, using conservative optimization")
        return {
            'changepoint_prior_scale': 0.01,
            'seasonality_prior_scale': 0.1,
            'holidays_prior_scale': 0.1,
            'changepoint_range': 0.8,
            'seasonality_mode': 'additive',
            'daily_seasonality': False,
            'weekly_seasonality': True,
            'yearly_seasonality': len(df) > 365  # Only if we have enough data
        }

    # Product-specific sampler seed. The built-in hash() of a str is salted
    # per interpreter process, so a digest is used to keep the seed (and
    # therefore the search) reproducible across runs.
    seed_digest = hashlib.sha256(product_name.encode('utf-8')).digest()
    product_seed = int.from_bytes(seed_digest[:4], 'big') % 10000

    # Search ranges depend only on the category, so compute them once
    if product_category == 'high_volume':
        # More conservative for high volume (less overfitting)
        changepoint_scale_range = (0.001, 0.1)
        seasonality_scale_range = (1.0, 10.0)
    elif product_category == 'intermittent':
        # Very conservative for intermittent
        changepoint_scale_range = (0.001, 0.05)
        seasonality_scale_range = (0.01, 1.0)
    else:
        # Default ranges
        changepoint_scale_range = (0.001, 0.5)
        seasonality_scale_range = (0.01, 10.0)

    def objective(trial):
        try:
            params = {
                'changepoint_prior_scale': trial.suggest_float(
                    'changepoint_prior_scale',
                    changepoint_scale_range[0],
                    changepoint_scale_range[1],
                    log=True
                ),
                'seasonality_prior_scale': trial.suggest_float(
                    'seasonality_prior_scale',
                    seasonality_scale_range[0],
                    seasonality_scale_range[1],
                    log=True
                ),
                'holidays_prior_scale': trial.suggest_float('holidays_prior_scale', 0.01, 10.0, log=True),
                'changepoint_range': trial.suggest_float('changepoint_range', 0.8, 0.95),
                'seasonality_mode': 'additive' if product_category == 'high_volume' else trial.suggest_categorical('seasonality_mode', ['additive', 'multiplicative']),
                'daily_seasonality': trial.suggest_categorical('daily_seasonality', [True, False]),
                'weekly_seasonality': True,  # Always keep weekly
                'yearly_seasonality': trial.suggest_categorical('yearly_seasonality', [True, False])
            }

            # Simple 2-fold cross-validation for speed
            tscv = TimeSeriesSplit(n_splits=2)
            cv_scores = []

            for train_idx, val_idx in tscv.split(df):
                train_data = df.iloc[train_idx].copy()
                val_data = df.iloc[val_idx].copy()

                if len(val_data) < 7:  # Need at least a week
                    continue

                try:
                    # Create and train model
                    model = Prophet(**params, interval_width=0.8, uncertainty_samples=100)

                    for regressor in regressor_columns:
                        if regressor in train_data.columns:
                            model.add_regressor(regressor)

                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore")
                        model.fit(train_data)

                    # Predict the held-out rows directly. A future frame with
                    # periods=0 only holds the training dates, so slicing the
                    # validation window out of it always came back empty and
                    # every trial fell through to the 100.0 fallback score.
                    forecast = model.predict(val_data)
                    val_predictions = forecast['yhat'].values
                    val_actual = val_data['y'].values

                    if len(val_predictions) > 0 and len(val_actual) > 0:
                        # Use MAE for very low sales values to avoid MAPE issues
                        if val_actual.mean() < 1:
                            mae = np.mean(np.abs(val_actual - val_predictions))
                            # Convert MAE to percentage-like metric
                            mape_like = (mae / max(val_actual.mean(), 0.1)) * 100
                        else:
                            non_zero_mask = val_actual > 0.1  # Use threshold instead of zero
                            if np.sum(non_zero_mask) > 0:
                                mape = np.mean(np.abs((val_actual[non_zero_mask] - val_predictions[non_zero_mask]) / val_actual[non_zero_mask])) * 100
                                mape_like = min(mape, 200)  # Cap at 200%
                            else:
                                mape_like = 100

                        if not np.isnan(mape_like) and not np.isinf(mape_like):
                            cv_scores.append(mape_like)

                except Exception as fold_error:
                    logger.debug(f"Fold failed for {product_name} trial {trial.number}: {str(fold_error)}")
                    continue

            # 100.0 is the neutral score for trials without a usable fold
            return np.mean(cv_scores) if len(cv_scores) > 0 else 100.0

        except Exception as trial_error:
            logger.debug(f"Trial {trial.number} failed for {product_name}: {str(trial_error)}")
            return 100.0

    # Run optimization with product-specific seed
    study = optuna.create_study(
        direction='minimize',
        sampler=optuna.samplers.TPESampler(seed=product_seed)
    )
    study.optimize(objective, n_trials=n_trials, timeout=600, show_progress_bar=False)

    # study.best_params only lists values obtained through trial.suggest_*;
    # add the settings that were fixed during the search so the returned
    # dict describes the model completely.
    best_params = dict(study.best_params)
    best_params.setdefault('seasonality_mode', 'additive')
    best_params['weekly_seasonality'] = True
    best_score = study.best_value

    logger.info(f"Optimization completed for {product_name}. Best score: {best_score:.2f}%. "
                f"Parameters: {best_params}")

    return best_params
def _classify_product(self, product_name: str, sales_data: pd.DataFrame) -> str:
    """Assign a product to a demand category used to size the optimization.

    The decision combines the share of zero-sales days, the number of days
    with sales, the mean of ``y`` and a few name patterns for core products.

    Args:
        product_name: Product name; matched case-insensitively.
        sales_data: Frame with the sales target in column ``y``.

    Returns:
        One of ``'high_volume'``, ``'medium_volume'``, ``'low_volume'`` or
        ``'intermittent'``.
    """
    name = product_name.lower()
    sales = sales_data['y']

    # Sales statistics
    total_sales = sales.sum()
    mean_sales = sales.mean()
    zero_ratio = (sales == 0).sum() / len(sales_data)
    non_zero_days = len(sales_data[sales > 0])

    logger.info(f"Product classification for {product_name}: total_sales={total_sales:.1f}, "
                f"mean_sales={mean_sales:.2f}, zero_ratio={zero_ratio:.2f}, non_zero_days={non_zero_days}")

    # Mostly-empty series are intermittent regardless of name or volume
    if zero_ratio > 0.8 or non_zero_days < 30:
        return 'intermittent'

    # Core products are ranked on consistency alone, not absolute volume
    core_patterns = ('cafe', 'pan', 'bread', 'coffee')
    if any(pattern in name for pattern in core_patterns):
        if zero_ratio < 0.3:
            return 'high_volume'
        return 'medium_volume'

    # Everything else: first tier whose (minimum mean, maximum zero ratio)
    # bounds are both met, checked from highest volume downwards
    volume_tiers = (
        (10, 0.4, 'high_volume'),
        (5, 0.6, 'medium_volume'),
        (2, 0.7, 'low_volume'),
    )
    for min_mean, zero_ratio_limit, category in volume_tiers:
        if mean_sales >= min_mean and zero_ratio < zero_ratio_limit:
            return category

    return 'intermittent'
def _create_optimized_prophet_model(self, optimized_params: Dict[str, Any], regressor_columns: List[str]) -> Prophet:
    """Build an unfitted Prophet model from a hyperparameter dict.

    Args:
        optimized_params: Hyperparameters to apply; any key that is missing
            falls back to the default listed below.
        regressor_columns: Not used here - regressors are not added by this
            method.

    Returns:
        A new ``Prophet`` instance configured with the holiday table from
        ``_get_spanish_holidays`` (omitted when that table is empty).
    """
    holiday_table = self._get_spanish_holidays()

    # Value used for each tunable setting that optimized_params lacks
    fallbacks = {
        'daily_seasonality': True,
        'weekly_seasonality': True,
        'yearly_seasonality': True,
        'seasonality_mode': 'additive',
        'changepoint_prior_scale': 0.05,
        'seasonality_prior_scale': 10.0,
        'holidays_prior_scale': 10.0,
        'changepoint_range': 0.8,
    }
    tuned = {
        setting: optimized_params.get(setting, fallback)
        for setting, fallback in fallbacks.items()
    }

    return Prophet(
        holidays=None if holiday_table.empty else holiday_table,
        interval_width=0.8,
        mcmc_samples=0,
        uncertainty_samples=1000,
        **tuned
    )
# All the existing methods remain the same, just with enhanced metrics
async def _calculate_training_metrics(self,
                                      model: Prophet,
                                      training_data: pd.DataFrame,
                                      optimized_params: Dict[str, Any] = None) -> Dict[str, float]:
    """
    Score a fitted model against the data it was trained on.

    In-sample predictions are compared with the observed ``y`` values to
    produce MAE, MSE, RMSE, a volume-aware MAPE and R-squared. The method
    also derives a heuristic demand category with an associated baseline
    MAPE, an improvement estimate relative to that baseline, and a
    data-quality score.

    Args:
        model: Fitted model exposing ``predict``.
        training_data: Frame with ``ds``, ``y`` and any regressor columns.
        optimized_params: Accepted for interface compatibility; not read here.

    Returns:
        The metrics dictionary. If anything raises, the error is logged and
        a reduced fallback dictionary is returned instead.
    """
    try:
        # Everything except the target goes to the model: 'ds' plus regressors.
        feature_columns = ['ds'] + [name for name in training_data.columns if name not in ['ds', 'y']]
        forecast = model.predict(training_data[feature_columns])

        y_true = training_data['y'].values
        y_pred = forecast['yhat'].values

        # Basic error metrics
        mae = mean_absolute_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        rmse = np.sqrt(mse)

        mean_actual = y_true.mean()
        positive_sales = y_true[y_true > 0]
        median_actual = np.median(positive_sales) if positive_sales.size else 1.0

        def relative_errors(mask):
            # Absolute percentage error restricted to the selected rows.
            return np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])

        # The MAPE strategy depends on the average sales volume.
        if mean_actual < 2.0:
            # Very low volume: MAE normalised by the median positive sale, capped at 200%.
            mape = min(mae / max(median_actual, 1.0) * 100, 200)
            logger.info(f"Using normalized MAE for low-volume product (mean={mean_actual:.2f})")
        elif mean_actual < 5.0:
            # Low-medium volume: median percentage error over rows with y >= 1, capped at 150%.
            usable_rows = y_true >= 1.0
            if not usable_rows.any():
                mape = 150.0
            else:
                mape = min(np.median(relative_errors(usable_rows)) * 100, 150)
        else:
            # Higher volume: plain mean percentage error over rows with y > 0.5.
            usable_rows = y_true > 0.5
            if not usable_rows.any():
                mape = 100.0
            else:
                mape = np.mean(relative_errors(usable_rows)) * 100

        # Final guard: replace non-finite or excessive values by a bounded MAE ratio.
        if math.isinf(mape) or math.isnan(mape) or mape > 200:
            mape = min(200.0, (mae / max(mean_actual, 1.0)) * 100)

        # R-squared, defined as 0.0 when the target has no variance.
        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
        if ss_tot != 0:
            r2 = 1 - (ss_res / ss_tot)
        else:
            r2 = 0.0

        # Demand profile of the training series
        target = training_data['y']
        row_count = len(training_data)
        total_sales = target.sum()
        zero_ratio = (target == 0).sum() / row_count
        mean_sales = target.mean()
        non_zero_days = len(training_data[target > 0])

        # Category and the baseline MAPE assumed for it
        if zero_ratio > 0.8 or non_zero_days < 30:
            category, baseline_mape = 'very_sparse', 80.0
        elif zero_ratio > 0.6:
            category, baseline_mape = 'sparse', 60.0
        elif mean_sales >= 10 and zero_ratio < 0.3:
            category, baseline_mape = 'high_volume', 25.0
        elif mean_sales >= 5 and zero_ratio < 0.5:
            category, baseline_mape = 'medium_volume', 35.0
        else:
            category, baseline_mape = 'low_volume', 45.0

        # An improvement is only claimed when MAPE beats the baseline by more than 20%.
        beats_baseline = mape < baseline_mape * 0.8
        improvement_pct = (baseline_mape - mape) / baseline_mape * 100 if beats_baseline else 0

        # Share of days with sales; also feeds the quality score (clamped to [0.1, 1.0]).
        sales_consistency = non_zero_days / row_count
        quality_score = max(0.1, min(1.0, (1 - zero_ratio) * sales_consistency))

        metrics = {
            "mae": round(mae, 2),
            "mse": round(mse, 2),
            "rmse": round(rmse, 2),
            "mape": round(mape, 2),
            "r2": round(r2, 3),
            "optimized": True,
            "optimized_mape": round(mape, 2),
            "baseline_mape_estimate": round(baseline_mape, 2),
            "improvement_estimated": round(improvement_pct, 1),
            "product_category": category,
            "data_quality_score": round(quality_score, 2),
            "mean_sales_volume": round(mean_sales, 2),
            "sales_consistency": round(sales_consistency, 2),
            "total_demand": round(total_sales, 1)
        }

        logger.info(f"Training metrics calculated: MAPE={mape:.1f}%, "
                    f"Category={category}, Improvement={improvement_pct:.1f}%")
        return metrics

    except Exception as e:
        # Metrics are best-effort: log and hand back neutral placeholder values.
        logger.error(f"Error calculating training metrics: {str(e)}")
        return {
            "mae": 0.0, "mse": 0.0, "rmse": 0.0, "mape": 100.0, "r2": 0.0,
            "optimized": False, "improvement_estimated": 0.0
        }
async def _store_model(self,
                       tenant_id: str,
                       product_name: str,
                       model: Prophet,
                       model_id: str,
                       training_data: pd.DataFrame,
                       regressor_columns: List[str],
                       optimized_params: Dict[str, Any] = None,
                       training_metrics: Dict[str, Any] = None) -> str:
    """
    Persist a trained model to disk, to the in-memory cache and to the database.

    The model is dumped with joblib to
    ``<MODEL_STORAGE_PATH>/<tenant_id>/<model_id>.pkl`` and a JSON metadata
    file with the same stem is written next to it. Model and metadata are
    cached under the key ``"<tenant_id>:<product_name>"``. When
    ``self.db_session`` is set, earlier models for the product are
    deactivated and a ``TrainedModel`` row is inserted; a database failure
    is logged and rolled back but not raised.

    Args:
        tenant_id: Tenant owning the model; also the storage sub-directory.
        product_name: Product the model was trained for.
        model: The trained model object to serialise.
        model_id: Identifier used as file stem and database primary key.
        training_data: Training frame; its ``ds`` column gives the data period.
        regressor_columns: Names of the regressors the model was trained with.
        optimized_params: Tuned hyperparameters (stored as ``{}`` when falsy).
        training_metrics: Optional metrics copied onto the database row.

    Returns:
        The model file path as a string.
    """
    # --- file storage -----------------------------------------------------
    tenant_dir = Path(settings.MODEL_STORAGE_PATH) / tenant_id
    tenant_dir.mkdir(parents=True, exist_ok=True)

    model_path = tenant_dir / f"{model_id}.pkl"
    joblib.dump(model, model_path)

    first_day = training_data['ds'].min()
    last_day = training_data['ds'].max()
    sample_count = len(training_data)

    metadata = {
        "model_id": model_id,
        "tenant_id": tenant_id,
        "product_name": product_name,
        "regressor_columns": regressor_columns,
        "training_samples": sample_count,
        "data_period": {
            "start_date": first_day.isoformat(),
            "end_date": last_day.isoformat()
        },
        "optimized": True,
        "optimized_parameters": optimized_params or {},
        "created_at": datetime.now().isoformat(),
        "model_type": "prophet_optimized",
        "file_path": str(model_path)
    }

    metadata_path = model_path.with_suffix('.json')
    with open(metadata_path, 'w') as metadata_file:
        # default=str stringifies any value json cannot encode natively.
        json.dump(metadata, metadata_file, indent=2, default=str)

    # --- in-memory cache --------------------------------------------------
    cache_key = f"{tenant_id}:{product_name}"
    self.models[cache_key] = model
    self.model_metadata[cache_key] = metadata

    # --- database record (only when a session was provided) ---------------
    if self.db_session:
        try:
            await self._deactivate_previous_models(tenant_id, product_name)

            db_model = TrainedModel(
                id=model_id,
                tenant_id=tenant_id,
                product_name=product_name,
                model_type="prophet_optimized",
                # NOTE(review): keeps only the text before the first '_' of
                # model_id; presumably job ids contain no underscore - confirm
                # against the code that builds model_id.
                job_id=model_id.split('_')[0],
                model_path=str(model_path),
                metadata_path=str(metadata_path),
                hyperparameters=optimized_params or {},
                features_used=regressor_columns,
                is_active=True,
                is_production=True,
                training_start_date=first_day,
                training_end_date=last_day,
                training_samples=sample_count
            )

            if training_metrics:
                # (database column, key in training_metrics)
                metric_columns = (
                    ('mape', 'mape'),
                    ('mae', 'mae'),
                    ('rmse', 'rmse'),
                    ('r2_score', 'r2'),
                    ('data_quality_score', 'data_quality_score'),
                )
                for column, metric_key in metric_columns:
                    setattr(db_model, column, training_metrics.get(metric_key))

            self.db_session.add(db_model)
            await self.db_session.commit()
            logger.info(f"Model {model_id} stored in database successfully")
        except Exception as e:
            # File storage already succeeded, so a database failure is not fatal.
            logger.error(f"Failed to store model in database: {str(e)}")
            await self.db_session.rollback()

    logger.info(f"Optimized model stored at: {model_path}")
    return str(model_path)
async def _deactivate_previous_models(self, tenant_id: str, product_name: str):
    """
    Mark every stored model of a tenant/product pair as inactive and non-production.

    Does nothing when no database session is configured. The update is
    best-effort: any failure is logged and swallowed. The statement is
    executed on ``self.db_session`` without committing.

    Args:
        tenant_id: Tenant whose models are deactivated.
        product_name: Product whose models are deactivated.
    """
    if not self.db_session:
        return

    try:
        # Function-scope import keeps this block self-contained.
        # SQLAlchemy 2.x rejects a plain string passed to execute(); textual
        # SQL has to be declared with text(). Without it the update raised
        # and the error was swallowed below, leaving old models active.
        from sqlalchemy import text

        statement = text("""
            UPDATE trained_models
            SET is_active = false, is_production = false
            WHERE tenant_id = :tenant_id AND product_name = :product_name
        """)
        await self.db_session.execute(statement, {
            "tenant_id": tenant_id,
            "product_name": product_name
        })
    except Exception as e:
        logger.error(f"Failed to deactivate previous models: {str(e)}")
# Keep all existing methods unchanged
async def generate_forecast(self,
model_path: str,
future_dates: pd.DataFrame,
regressor_columns: List[str]) -> pd.DataFrame:
"""
Generate forecast using a stored Prophet model.
Args:
model_path: Path to the stored model
future_dates: DataFrame with future dates and regressors
regressor_columns: List of regressor column names
Returns:
DataFrame with forecast results
"""
"""Generate forecast using stored model (unchanged)"""
try:
# Load the model
model = joblib.load(model_path)
# Validate future data has required regressors
for regressor in regressor_columns:
if regressor not in future_dates.columns:
logger.warning(f"Missing regressor {regressor}, filling with median")
future_dates[regressor] = 0 # Default value
future_dates[regressor] = 0
# Generate forecast
forecast = model.predict(future_dates)
return forecast
except Exception as e:
@@ -151,7 +601,7 @@ class BakeryProphetManager:
raise
async def _validate_training_data(self, df: pd.DataFrame, product_name: str):
"""Validate training data quality"""
"""Validate training data quality (unchanged)"""
if df.empty:
raise ValueError(f"No training data available for {product_name}")
@@ -166,65 +616,47 @@ class BakeryProphetManager:
if missing_columns:
raise ValueError(f"Missing required columns: {missing_columns}")
# Check for valid date range
if df['ds'].isna().any():
raise ValueError("Invalid dates found in training data")
# Check for valid target values
if df['y'].isna().all():
raise ValueError("No valid target values found")
async def _prepare_prophet_data(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare data for Prophet training"""
"""Prepare data for Prophet training with timezone handling"""
prophet_data = df.copy()
# Prophet column mapping
if 'date' in prophet_data.columns:
prophet_data['ds'] = prophet_data['date']
if 'quantity' in prophet_data.columns:
prophet_data['y'] = prophet_data['quantity']
# ✅ CRITICAL FIX: Remove timezone from ds column
if 'ds' in prophet_data.columns:
prophet_data['ds'] = pd.to_datetime(prophet_data['ds']).dt.tz_localize(None)
logger.info(f"Removed timezone from ds column")
if 'ds' not in prophet_data.columns:
raise ValueError("Missing 'ds' column in training data")
if 'y' not in prophet_data.columns:
raise ValueError("Missing 'y' column in training data")
# Handle missing values in target
if prophet_data['y'].isna().any():
logger.warning("Filling missing target values with interpolation")
prophet_data['y'] = prophet_data['y'].interpolate(method='linear')
# Convert to datetime and remove timezone information
prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
# Remove extreme outliers (values > 3 standard deviations)
mean_val = prophet_data['y'].mean()
std_val = prophet_data['y'].std()
# Remove timezone if present (Prophet doesn't support timezones)
if prophet_data['ds'].dt.tz is not None:
logger.info("Removing timezone information from 'ds' column for Prophet compatibility")
prophet_data['ds'] = prophet_data['ds'].dt.tz_localize(None)
if std_val > 0: # Avoid division by zero
lower_bound = mean_val - 3 * std_val
upper_bound = mean_val + 3 * std_val
before_count = len(prophet_data)
prophet_data = prophet_data[
(prophet_data['y'] >= lower_bound) &
(prophet_data['y'] <= upper_bound)
]
after_count = len(prophet_data)
if before_count != after_count:
logger.info(f"Removed {before_count - after_count} outliers")
# Ensure chronological order
# Sort by date and clean data
prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
prophet_data['y'] = pd.to_numeric(prophet_data['y'], errors='coerce')
prophet_data = prophet_data.dropna(subset=['y'])
# Fill missing values in regressors
numeric_columns = prophet_data.select_dtypes(include=[np.number]).columns
for col in numeric_columns:
if col != 'y' and prophet_data[col].isna().any():
prophet_data[col] = prophet_data[col].fillna(prophet_data[col].median())
# Additional data cleaning for Prophet
# Remove any duplicate dates (keep last occurrence)
prophet_data = prophet_data.drop_duplicates(subset=['ds'], keep='last')
# Ensure y values are non-negative (Prophet works better with non-negative values)
prophet_data['y'] = prophet_data['y'].clip(lower=0)
logger.info(f"Prepared Prophet data: {len(prophet_data)} rows, date range: {prophet_data['ds'].min()} to {prophet_data['ds'].max()}")
return prophet_data
def _extract_regressor_columns(self, df: pd.DataFrame) -> List[str]:
"""Extract regressor columns from the dataframe"""
"""Extract regressor columns (unchanged)"""
excluded_columns = ['ds', 'y']
regressor_columns = []
@@ -235,190 +667,32 @@ class BakeryProphetManager:
logger.info(f"Identified regressor columns: {regressor_columns}")
return regressor_columns
def _create_prophet_model(self, regressor_columns: List[str]) -> Prophet:
    """
    Build an unfitted Prophet model from the service settings.

    Seasonality switches and the seasonality mode come from ``settings``;
    the prior scales, interval width and sampling options are fixed here.

    Args:
        regressor_columns: Accepted for signature compatibility; this method
            does not read it.

    Returns:
        A configured Prophet instance (not yet fitted).
    """
    holiday_frame = self._get_spanish_holidays()

    prophet_kwargs = dict(
        # Prophet receives None rather than an empty holidays frame.
        holidays=None if holiday_frame.empty else holiday_frame,
        daily_seasonality=settings.PROPHET_DAILY_SEASONALITY,
        weekly_seasonality=settings.PROPHET_WEEKLY_SEASONALITY,
        yearly_seasonality=settings.PROPHET_YEARLY_SEASONALITY,
        seasonality_mode=settings.PROPHET_SEASONALITY_MODE,
        changepoint_prior_scale=0.05,   # conservative changepoint detection
        seasonality_prior_scale=10,     # strong seasonality
        holidays_prior_scale=10,        # strong holiday effects
        interval_width=0.8,             # 80% uncertainty interval
        mcmc_samples=0,                 # MAP estimation instead of MCMC
        uncertainty_samples=1000,
    )
    return Prophet(**prophet_kwargs)
def _get_spanish_holidays(self) -> pd.DataFrame:
"""Get Spanish holidays for Prophet model"""
"""Get Spanish holidays (unchanged)"""
try:
# Define major Spanish holidays that affect bakery sales
holidays_list = []
years = range(2020, 2030) # Cover training and prediction period
years = range(2020, 2030)
for year in years:
holidays_list.extend([
{'holiday': 'new_year', 'ds': f'{year}-01-01'},
{'holiday': 'epiphany', 'ds': f'{year}-01-06'},
{'holiday': 'may_day', 'ds': f'{year}-05-01'},
{'holiday': 'labor_day', 'ds': f'{year}-05-01'},
{'holiday': 'assumption', 'ds': f'{year}-08-15'},
{'holiday': 'national_day', 'ds': f'{year}-10-12'},
{'holiday': 'all_saints', 'ds': f'{year}-11-01'},
{'holiday': 'constitution', 'ds': f'{year}-12-06'},
{'holiday': 'immaculate', 'ds': f'{year}-12-08'},
{'holiday': 'christmas', 'ds': f'{year}-12-25'},
# Madrid specific holidays
{'holiday': 'madrid_patron', 'ds': f'{year}-05-15'}, # San Isidro
{'holiday': 'madrid_community', 'ds': f'{year}-05-02'},
{'holiday': 'constitution_day', 'ds': f'{year}-12-06'},
{'holiday': 'immaculate_conception', 'ds': f'{year}-12-08'},
{'holiday': 'christmas', 'ds': f'{year}-12-25'}
])
holidays_df = pd.DataFrame(holidays_list)
holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])
return holidays_df
except Exception as e:
logger.warning(f"Error creating holidays dataframe: {e}")
return pd.DataFrame()
async def _store_model(self,
                       tenant_id: str,
                       product_name: str,
                       model: Prophet,
                       model_id: str,
                       training_data: pd.DataFrame,
                       regressor_columns: List[str]) -> str:
    """
    Store model and metadata to filesystem.

    The model is dumped with joblib to
    ``<MODEL_STORAGE_PATH>/<model_id>_prophet_model.pkl`` and its metadata
    is written to ``<MODEL_STORAGE_PATH>/<model_id>_prophet_model_metadata.json``.
    Both are also cached in memory under ``"<tenant_id>:<product_name>"``.

    Args:
        tenant_id: Tenant owning the model.
        product_name: Product the model was trained for.
        model: The trained model object to serialise.
        model_id: Identifier used to build the file names.
        training_data: Training frame; its ``ds`` column gives the training period.
        regressor_columns: Names of the regressors the model was trained with.

    Returns:
        Path of the stored model file.
    """
    # joblib.dump does not create missing directories, so make sure the
    # storage root exists before writing into it.
    os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)

    model_filename = f"{model_id}_prophet_model.pkl"
    model_path = os.path.join(settings.MODEL_STORAGE_PATH, model_filename)

    joblib.dump(model, model_path)

    metadata = {
        "tenant_id": tenant_id,
        "product_name": product_name,
        "model_id": model_id,
        "regressor_columns": regressor_columns,
        "training_samples": len(training_data),
        "training_period": {
            "start": training_data['ds'].min().isoformat(),
            "end": training_data['ds'].max().isoformat()
        },
        "created_at": datetime.now().isoformat(),
        "model_type": "prophet",
        "file_path": model_path
    }

    # Swap only the trailing extension. str.replace('.pkl', ...) would also
    # rewrite a '.pkl' occurring earlier in the path (storage directory or
    # model_id) and point the metadata at the wrong location.
    metadata_path = os.path.splitext(model_path)[0] + '_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    # Keep model and metadata in memory for quick access.
    model_key = f"{tenant_id}:{product_name}"
    self.models[model_key] = model
    self.model_metadata[model_key] = metadata

    logger.info(f"Model stored at: {model_path}")
    return model_path
async def _calculate_training_metrics(self,
model: Prophet,
training_data: pd.DataFrame) -> Dict[str, float]:
"""Calculate training metrics for the model"""
try:
# Generate in-sample predictions
forecast = model.predict(training_data[['ds'] + [col for col in training_data.columns if col not in ['ds', 'y']]])
# Calculate metrics
y_true = training_data['y'].values
y_pred = forecast['yhat'].values
# Basic metrics
mae = mean_absolute_error(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)
# MAPE (Mean Absolute Percentage Error)
non_zero_mask = y_true != 0
if np.sum(non_zero_mask) == 0:
mape = 0.0 # Return 0 instead of Infinity
if holidays_list:
holidays_df = pd.DataFrame(holidays_list)
holidays_df['ds'] = pd.to_datetime(holidays_df['ds'])
return holidays_df
else:
mape_values = np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])
mape = np.mean(mape_values) * 100
if math.isinf(mape) or math.isnan(mape):
mape = 0.0
# R-squared
r2 = r2_score(y_true, y_pred)
return {
"mae": round(mae, 2),
"mse": round(mse, 2),
"rmse": round(rmse, 2),
"mape": round(mape, 2),
"r2_score": round(r2, 4),
"mean_actual": round(np.mean(y_true), 2),
"mean_predicted": round(np.mean(y_pred), 2)
}
return pd.DataFrame()
except Exception as e:
logger.error(f"Error calculating training metrics: {e}")
return {
"mae": 0.0,
"mse": 0.0,
"rmse": 0.0,
"mape": 0.0,
"r2_score": 0.0,
"mean_actual": 0.0,
"mean_predicted": 0.0
}
def get_model_info(self, tenant_id: str, product_name: str) -> Optional[Dict[str, Any]]:
    """
    Look up the cached metadata of one tenant/product model.

    Returns:
        The metadata stored under ``"<tenant_id>:<product_name>"``, or
        ``None`` when nothing is cached for that key.
    """
    cache_key = f"{tenant_id}:{product_name}"
    cached_metadata = self.model_metadata
    return cached_metadata.get(cache_key)
def list_models(self, tenant_id: str) -> List[Dict[str, Any]]:
    """
    Collect the cached metadata of every model belonging to a tenant.

    Returns:
        Metadata dictionaries whose ``'tenant_id'`` equals ``tenant_id``,
        in cache insertion order; an empty list when there are none.
    """
    return [
        entry
        for entry in self.model_metadata.values()
        if entry['tenant_id'] == tenant_id
    ]
async def cleanup_old_models(self, days_old: int = 30):
    """
    Delete model files older than ``days_old`` days, together with their metadata.

    Only ``*.pkl`` files directly inside ``settings.MODEL_STORAGE_PATH`` are
    considered; sub-directories are not scanned. A file counts as old when
    its modification time lies before the cutoff. Errors are logged and
    swallowed, and the first error stops the sweep.

    Args:
        days_old: Minimum age in days for a model file to be removed.
    """
    try:
        # Computed once; it does not change while iterating.
        cutoff_ts = (datetime.now() - timedelta(days=days_old)).timestamp()

        for model_path in Path(settings.MODEL_STORAGE_PATH).glob("*.pkl"):
            if model_path.stat().st_mtime >= cutoff_ts:
                continue

            model_path.unlink()

            # Metadata may sit next to the model under either name:
            # '<stem>.json' or '<stem>_metadata.json'. Checking only the
            # former left '<stem>_metadata.json' files behind.
            metadata_candidates = (
                model_path.with_suffix('.json'),
                model_path.with_name(f"{model_path.stem}_metadata.json"),
            )
            for metadata_path in metadata_candidates:
                if metadata_path.exists():
                    metadata_path.unlink()

            logger.info(f"Cleaned up old model: {model_path}")

    except Exception as e:
        logger.error(f"Error during model cleanup: {e}")
logger.warning(f"Could not load Spanish holidays: {str(e)}")
return pd.DataFrame()

View File

@@ -1,77 +1,76 @@
# services/training/app/ml/trainer.py
"""
ML Trainer for Training Service
Orchestrates the complete training process
ML Trainer - Main ML pipeline coordinator
Receives prepared data and orchestrates the complete ML training process
"""
from typing import Dict, List, Any, Optional, Tuple
from typing import Dict, List, Any, Optional
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from datetime import datetime
import logging
import asyncio
import uuid
from pathlib import Path
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.data_processor import BakeryDataProcessor
from app.ml.prophet_manager import BakeryProphetManager
from app.services.training_orchestrator import TrainingDataSet
from app.core.config import settings
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
class BakeryMLTrainer:
"""
Main ML trainer that orchestrates the complete training process.
Replaces the old Celery-based training system with clean async implementation.
Main ML trainer that orchestrates the complete ML training pipeline.
Receives prepared TrainingDataSet and coordinates data processing and model training.
"""
def __init__(self):
self.prophet_manager = BakeryProphetManager()
def __init__(self, db_session: AsyncSession = None):
self.data_processor = BakeryDataProcessor()
self.prophet_manager = BakeryProphetManager(db_session=db_session)
async def train_tenant_models(self,
tenant_id: str,
sales_data: List[Dict],
weather_data: List[Dict] = None,
traffic_data: List[Dict] = None,
job_id: str = None) -> Dict[str, Any]:
training_dataset: TrainingDataSet,
job_id: Optional[str] = None) -> Dict[str, Any]:
"""
Train models for all products of a tenant.
Train models for all products using prepared training dataset.
Args:
tenant_id: Tenant identifier
sales_data: Historical sales data
weather_data: Weather data (optional)
traffic_data: Traffic data (optional)
training_dataset: Prepared training dataset with aligned dates
job_id: Training job identifier
Returns:
Dictionary with training results for each product
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
job_id = f"ml_training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
logger.info(f"Starting ML training pipeline {job_id} for tenant {tenant_id}")
try:
# Convert input data to DataFrames
sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()
# Convert sales data to DataFrame
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Validate input data
await self._validate_input_data(sales_df, tenant_id)
# Get unique products
# Get unique products from the sales data
products = sales_df['product_name'].unique().tolist()
logger.info(f"Training models for {len(products)} products: {products}")
# Process data for each product
logger.info("Processing data for all products...")
processed_data = await self._process_all_products(
sales_df, weather_df, traffic_df, products
)
# Train models for each product
# Train models for each processed product
logger.info("Training models for all products...")
training_results = await self._train_all_models(
tenant_id, processed_data, job_id
)
@@ -85,50 +84,56 @@ class BakeryMLTrainer:
"status": "completed",
"products_trained": len([r for r in training_results.values() if r.get('status') == 'success']),
"products_failed": len([r for r in training_results.values() if r.get('status') == 'error']),
"products_skipped": len([r for r in training_results.values() if r.get('status') == 'skipped']),
"total_products": len(products),
"training_results": training_results,
"summary": summary,
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training job {job_id} completed successfully")
logger.info(f"ML training pipeline {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
logger.error(f"ML training pipeline {job_id} failed: {str(e)}")
raise
async def train_single_product(self,
tenant_id: str,
product_name: str,
sales_data: List[Dict],
weather_data: List[Dict] = None,
traffic_data: List[Dict] = None,
job_id: str = None) -> Dict[str, Any]:
async def train_single_product_model(self,
tenant_id: str,
product_name: str,
training_dataset: TrainingDataSet,
job_id: Optional[str] = None) -> Dict[str, Any]:
"""
Train model for a single product.
Train model for a single product using prepared training dataset.
Args:
tenant_id: Tenant identifier
product_name: Product name
sales_data: Historical sales data
weather_data: Weather data (optional)
traffic_data: Traffic data (optional)
training_dataset: Prepared training dataset
job_id: Training job identifier
Returns:
Training result for the product
"""
if not job_id:
job_id = f"training_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
job_id = f"single_ml_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting single product training {job_id} for {product_name}")
logger.info(f"Starting single product ML training {job_id} for {product_name}")
try:
# Convert input data to DataFrames
sales_df = pd.DataFrame(sales_data) if sales_data else pd.DataFrame()
weather_df = pd.DataFrame(weather_data) if weather_data else pd.DataFrame()
traffic_df = pd.DataFrame(traffic_data) if traffic_data else pd.DataFrame()
# Convert training data to DataFrames
sales_df = pd.DataFrame(training_dataset.sales_data)
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# Filter sales data for the specific product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
@@ -137,7 +142,7 @@ class BakeryMLTrainer:
if product_sales.empty:
raise ValueError(f"No sales data found for product: {product_name}")
# Prepare training data
# Process data for this specific product
processed_data = await self.data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
@@ -160,29 +165,38 @@ class BakeryMLTrainer:
"status": "success",
"model_info": model_info,
"data_points": len(processed_data),
"data_info": {
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat(),
"duration_days": (training_dataset.date_range.end - training_dataset.date_range.start).days
},
"data_sources": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Single product training {job_id} completed successfully")
logger.info(f"Single product ML training {job_id} completed successfully")
return result
except Exception as e:
logger.error(f"Single product training {job_id} failed: {str(e)}")
logger.error(f"Single product ML training {job_id} failed: {str(e)}")
raise
async def evaluate_model_performance(self,
tenant_id: str,
product_name: str,
model_path: str,
test_data: List[Dict]) -> Dict[str, Any]:
test_dataset: TrainingDataSet) -> Dict[str, Any]:
"""
Evaluate model performance on test data.
Evaluate model performance using test dataset.
Args:
tenant_id: Tenant identifier
product_name: Product name
model_path: Path to the trained model
test_data: Test data for evaluation
test_dataset: Test dataset for evaluation
Returns:
Performance metrics
@@ -190,46 +204,75 @@ class BakeryMLTrainer:
try:
logger.info(f"Evaluating model performance for {product_name}")
# Convert test data to DataFrame
test_df = pd.DataFrame(test_data)
# Convert test data to DataFrames
test_sales_df = pd.DataFrame(test_dataset.sales_data)
test_weather_df = pd.DataFrame(test_dataset.weather_data)
test_traffic_df = pd.DataFrame(test_dataset.traffic_data)
# Prepare test data
test_prepared = await self.data_processor.prepare_prediction_features(
future_dates=test_df['ds'],
weather_forecast=test_df if 'temperature' in test_df.columns else pd.DataFrame(),
traffic_forecast=test_df if 'traffic_volume' in test_df.columns else pd.DataFrame()
# Filter for specific product
product_test_sales = test_sales_df[test_sales_df['product_name'] == product_name].copy()
if product_test_sales.empty:
raise ValueError(f"No test data found for product: {product_name}")
# Process test data
processed_test_data = await self.data_processor.prepare_training_data(
sales_data=product_test_sales,
weather_data=test_weather_df,
traffic_data=test_traffic_df,
product_name=product_name
)
# Get regressor columns
regressor_columns = [col for col in test_prepared.columns if col not in ['ds', 'y']]
# Create future dataframe for prediction
future_dates = processed_test_data[['ds']].copy()
# Add regressor columns
regressor_columns = [col for col in processed_test_data.columns if col not in ['ds', 'y']]
for col in regressor_columns:
future_dates[col] = processed_test_data[col]
# Generate predictions
forecast = await self.prophet_manager.generate_forecast(
model_path=model_path,
future_dates=test_prepared,
future_dates=future_dates,
regressor_columns=regressor_columns
)
# Calculate performance metrics if we have actual values
metrics = {}
if 'y' in test_df.columns:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
y_true = test_df['y'].values
y_pred = forecast['yhat'].values
metrics = {
"mae": float(mean_absolute_error(y_true, y_pred)),
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
"mape": float(np.mean(np.abs((y_true - y_pred) / y_true)) * 100),
"r2_score": float(r2_score(y_true, y_pred))
}
# Calculate performance metrics
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
y_true = processed_test_data['y'].values
y_pred = forecast['yhat'].values
# Ensure arrays are the same length
min_len = min(len(y_true), len(y_pred))
y_true = y_true[:min_len]
y_pred = y_pred[:min_len]
metrics = {
"mae": float(mean_absolute_error(y_true, y_pred)),
"rmse": float(np.sqrt(mean_squared_error(y_true, y_pred))),
"r2_score": float(r2_score(y_true, y_pred))
}
# Calculate MAPE safely
non_zero_mask = y_true > 0.1
if np.sum(non_zero_mask) > 0:
mape = np.mean(np.abs((y_true[non_zero_mask] - y_pred[non_zero_mask]) / y_true[non_zero_mask])) * 100
metrics["mape"] = float(min(mape, 200)) # Cap at 200%
else:
metrics["mape"] = 100.0
result = {
"tenant_id": tenant_id,
"product_name": product_name,
"evaluation_metrics": metrics,
"forecast_samples": len(forecast),
"test_samples": len(processed_test_data),
"prediction_samples": len(forecast),
"test_period": {
"start": test_dataset.date_range.start.isoformat(),
"end": test_dataset.date_range.end.isoformat()
},
"evaluated_at": datetime.now().isoformat()
}
@@ -244,6 +287,7 @@ class BakeryMLTrainer:
if sales_df.empty:
raise ValueError(f"No sales data provided for tenant {tenant_id}")
# Handle quantity column mapping
if 'quantity_sold' in sales_df.columns and 'quantity' not in sales_df.columns:
sales_df['quantity'] = sales_df['quantity_sold']
logger.info("Mapped 'quantity_sold' to 'quantity' column")
@@ -261,14 +305,17 @@ class BakeryMLTrainer:
# Check for valid quantities
if not sales_df['quantity'].dtype in ['int64', 'float64']:
raise ValueError("Quantity column must be numeric")
try:
sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
except Exception:
raise ValueError("Quantity column must be numeric")
async def _process_all_products(self,
sales_df: pd.DataFrame,
weather_df: pd.DataFrame,
traffic_df: pd.DataFrame,
products: List[str]) -> Dict[str, pd.DataFrame]:
"""Process data for all products"""
"""Process data for all products using the data processor"""
processed_data = {}
for product_name in products:
@@ -278,7 +325,11 @@ class BakeryMLTrainer:
# Filter sales data for this product
product_sales = sales_df[sales_df['product_name'] == product_name].copy()
# Process the product data
if product_sales.empty:
logger.warning(f"No sales data found for product: {product_name}")
continue
# Use data processor to prepare training data
processed_product_data = await self.data_processor.prepare_training_data(
sales_data=product_sales,
weather_data=weather_df,
@@ -300,7 +351,7 @@ class BakeryMLTrainer:
tenant_id: str,
processed_data: Dict[str, pd.DataFrame],
job_id: str) -> Dict[str, Any]:
"""Train models for all processed products"""
"""Train models for all processed products using Prophet manager"""
training_results = {}
for product_name, product_data in processed_data.items():
@@ -313,11 +364,13 @@ class BakeryMLTrainer:
'status': 'skipped',
'reason': 'insufficient_data',
'data_points': len(product_data),
'min_required': settings.MIN_TRAINING_DATA_DAYS
'min_required': settings.MIN_TRAINING_DATA_DAYS,
'message': f'Need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(product_data)}'
}
logger.warning(f"Skipping {product_name}: insufficient data ({len(product_data)} < {settings.MIN_TRAINING_DATA_DAYS})")
continue
# Train the model
# Train the model using Prophet manager
model_info = await self.prophet_manager.train_bakery_model(
tenant_id=tenant_id,
product_name=product_name,
@@ -339,7 +392,8 @@ class BakeryMLTrainer:
training_results[product_name] = {
'status': 'error',
'error_message': str(e),
'data_points': len(product_data) if product_data is not None else 0
'data_points': len(product_data) if product_data is not None else 0,
'failed_at': datetime.now().isoformat()
}
return training_results
@@ -360,17 +414,27 @@ class BakeryMLTrainer:
if metrics_list and all(metrics_list):
avg_metrics = {
'avg_mae': np.mean([m.get('mae', 0) for m in metrics_list]),
'avg_rmse': np.mean([m.get('rmse', 0) for m in metrics_list]),
'avg_mape': np.mean([m.get('mape', 0) for m in metrics_list]),
'avg_r2': np.mean([m.get('r2_score', 0) for m in metrics_list])
'avg_mae': round(np.mean([m.get('mae', 0) for m in metrics_list]), 2),
'avg_rmse': round(np.mean([m.get('rmse', 0) for m in metrics_list]), 2),
'avg_mape': round(np.mean([m.get('mape', 0) for m in metrics_list]), 2),
'avg_r2': round(np.mean([m.get('r2', 0) for m in metrics_list]), 3),
'avg_improvement': round(np.mean([m.get('improvement_estimated', 0) for m in metrics_list]), 1)
}
# Calculate data quality insights
data_points_list = [r.get('data_points', 0) for r in training_results.values()]
return {
'total_products': total_products,
'successful_products': successful_products,
'failed_products': failed_products,
'skipped_products': skipped_products,
'success_rate': round(successful_products / total_products * 100, 2) if total_products > 0 else 0,
'average_metrics': avg_metrics
'average_metrics': avg_metrics,
'data_summary': {
'total_data_points': sum(data_points_list),
'avg_data_points_per_product': round(np.mean(data_points_list), 1) if data_points_list else 0,
'min_data_points': min(data_points_list) if data_points_list else 0,
'max_data_points': max(data_points_list) if data_points_list else 0
}
}

View File

@@ -37,37 +37,6 @@ class ModelTrainingLog(Base):
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class TrainedModel(Base):
"""
Table to store information about trained models.
"""
__tablename__ = "trained_models"
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String(255), index=True, nullable=False)
# Model information
model_type = Column(String(50), nullable=False, default="prophet") # prophet, arima, etc.
model_path = Column(String(1000), nullable=False) # Path to stored model file
version = Column(Integer, nullable=False, default=1)
# Training information
training_samples = Column(Integer, nullable=False, default=0)
features = Column(ARRAY(String), nullable=True) # List of features used
hyperparameters = Column(JSON, nullable=True) # Model hyperparameters
training_metrics = Column(JSON, nullable=True) # Training performance metrics
# Data period information
data_period_start = Column(DateTime, nullable=True)
data_period_end = Column(DateTime, nullable=True)
# Status and metadata
is_active = Column(Boolean, default=True, index=True)
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class ModelPerformanceMetric(Base):
"""
Table to track model performance over time.
@@ -150,4 +119,73 @@ class ModelArtifact(Base):
# Metadata
created_at = Column(DateTime, default=datetime.now)
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
class TrainedModel(Base):
    """ORM mapping for the ``trained_models`` table: one row per trained model artifact."""

    __tablename__ = "trained_models"

    # Primary identification; the id defaults to a freshly generated UUID4 string.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    tenant_id = Column(String, nullable=False, index=True)
    product_name = Column(String, nullable=False, index=True)

    # Model information
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)

    # File storage
    model_path = Column(String, nullable=False)  # Path to the .pkl file
    metadata_path = Column(String)  # Path to metadata JSON

    # Training metrics
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # Hyperparameters and features
    hyperparameters = Column(JSON)  # Store optimized parameters
    features_used = Column(JSON)  # List of regressor columns

    # Model status
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime)

    # Training data info
    training_start_date = Column(DateTime)
    training_end_date = Column(DateTime)
    data_quality_score = Column(Float)

    # Additional metadata
    notes = Column(Text)
    created_by = Column(String)  # User who triggered training

    def to_dict(self):
        """
        Return a plain-dict view of the row.

        Datetime columns are rendered with ``isoformat()`` (``None`` when unset).
        ``job_id``, ``metadata_path``, ``notes`` and ``created_by`` are not included.
        """
        def iso_or_none(moment):
            # Unset timestamps are reported as None rather than raising.
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "product_name": self.product_name,
            "model_type": self.model_type,
            "model_version": self.model_version,
            "model_path": self.model_path,
            "mape": self.mape,
            "mae": self.mae,
            "rmse": self.rmse,
            "r2_score": self.r2_score,
            "training_samples": self.training_samples,
            "hyperparameters": self.hyperparameters,
            "features_used": self.features_used,
            "is_active": self.is_active,
            "is_production": self.is_production,
            "created_at": iso_or_none(self.created_at),
            "last_used_at": iso_or_none(self.last_used_at),
            "training_start_date": iso_or_none(self.training_start_date),
            "training_end_date": iso_or_none(self.training_end_date),
            "data_quality_score": self.data_quality_score
        }

View File

@@ -23,8 +23,6 @@ class TrainingStatus(str, Enum):
class TrainingJobRequest(BaseModel):
"""Request schema for starting a training job"""
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
include_weather: bool = Field(True, description="Include weather data in training")
include_traffic: bool = Field(True, description="Include traffic data in training")
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")
@@ -48,8 +46,6 @@ class TrainingJobRequest(BaseModel):
class SingleProductTrainingRequest(BaseModel):
"""Request schema for training a single product"""
include_weather: bool = Field(True, description="Include weather data in training")
include_traffic: bool = Field(True, description="Include traffic data in training")
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")

View File

@@ -0,0 +1,240 @@
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import logging
logger = logging.getLogger(__name__)
class DataSourceType(Enum):
    """Identifiers of the data sources that can contribute to a training run."""

    BAKERY_SALES = "bakery_sales"
    MADRID_TRAFFIC = "madrid_traffic"
    WEATHER_FORECAST = "weather_forecast"
@dataclass
class DateRange:
    """A start/end datetime pair tagged with the data source it belongs to."""

    start: datetime
    end: datetime
    source: DataSourceType

    def duration_days(self) -> int:
        """Return the whole-day component of ``end - start``."""
        span = self.end - self.start
        return span.days

    def overlaps_with(self, other: 'DateRange') -> bool:
        """Return True when the two ranges intersect; both bounds are inclusive."""
        # Two ranges are disjoint only if one begins after the other has ended.
        return not (self.start > other.end or other.start > self.end)
@dataclass
class AlignedDateRange:
    """Result of date alignment: the final window plus how it was derived."""

    # Final training window.
    start: datetime
    end: datetime
    # Sources considered usable for this window.
    available_sources: List[DataSourceType]
    # Human-readable note per adjustment that was applied, keyed by constraint name.
    constraints: Dict[str, str]
class DateAlignmentService:
    """
    Central service for managing and aligning dates across multiple data sources
    for the bakery sales prediction model.
    """

    def __init__(self):
        self.MAX_TRAINING_RANGE_DAYS = 365  # Maximum training data range
        self.MIN_TRAINING_RANGE_DAYS = 30   # Minimum viable training data

    def validate_and_align_dates(
        self,
        user_sales_range: DateRange,
        requested_start: Optional[datetime] = None,
        requested_end: Optional[datetime] = None
    ) -> AlignedDateRange:
        """
        Main method to validate and align dates across all data sources.

        Args:
            user_sales_range: Date range of user-provided sales data
            requested_start: Optional explicit start date for training
            requested_end: Optional explicit end date for training

        Returns:
            AlignedDateRange with validated start/end dates and available sources

        Raises:
            ValueError: if any step fails; the original exception is chained as
                the cause.
        """
        try:
            # Step 1: Determine the base date range
            base_range = self._determine_base_range(
                user_sales_range, requested_start, requested_end
            )

            # Step 2: Apply data source constraints
            aligned_range = self._apply_data_source_constraints(base_range)

            # Step 3: Validate final range
            self._validate_final_range(aligned_range)

            logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
            return aligned_range

        except Exception as e:
            logger.error(f"Date alignment failed: {str(e)}")
            raise ValueError(f"Unable to align dates: {str(e)}") from e

    def _determine_base_range(
        self,
        user_sales_range: DateRange,
        requested_start: Optional[datetime],
        requested_end: Optional[datetime]
    ) -> DateRange:
        """
        Determine the base date range for training.

        When both explicit dates are given they are used as-is (after checking
        their order). Otherwise missing bounds fall back to the sales range and
        the start is pulled forward so the span does not exceed
        ``MAX_TRAINING_RANGE_DAYS``.
        """
        # Use explicit dates if provided
        if requested_start and requested_end:
            if requested_end <= requested_start:
                raise ValueError("End date must be after start date")
            return DateRange(requested_start, requested_end, DataSourceType.BAKERY_SALES)

        # Otherwise, use the user's sales data range as the foundation
        start_date = requested_start or user_sales_range.start
        end_date = requested_end or user_sales_range.end

        # Ensure we don't exceed maximum training range
        if (end_date - start_date).days > self.MAX_TRAINING_RANGE_DAYS:
            start_date = end_date - timedelta(days=self.MAX_TRAINING_RANGE_DAYS)
            logger.warning(f"Limiting training range to {self.MAX_TRAINING_RANGE_DAYS} days")

        return DateRange(start_date, end_date, DataSourceType.BAKERY_SALES)

    def _apply_data_source_constraints(self, base_range: DateRange) -> AlignedDateRange:
        """
        Apply constraints from each data source and determine final aligned range.

        The end date is clamped to the latest date for which traffic and weather
        data exist; the start date is moved back if needed to keep at least
        ``MIN_TRAINING_RANGE_DAYS`` days.
        """
        available_sources = [DataSourceType.BAKERY_SALES]  # Always have sales data
        constraints = {}

        # Madrid Traffic Data Constraint
        madrid_end_date = self._get_madrid_traffic_end_date()
        if base_range.end > madrid_end_date:
            # Requested end reaches into the current month: clamp it
            new_end = madrid_end_date
            constraints["madrid_traffic"] = f"Adjusted end date to {new_end.date()} (latest available traffic data)"
            logger.info(f"Madrid traffic constraint: end date adjusted to {new_end.date()}")
        else:
            new_end = base_range.end

        # Once the end date respects the traffic limit, traffic data covers the
        # range, so the source is usable in both branches above. Previously it
        # was only listed when no clamping was needed, which made the clamp
        # pointless.
        if new_end >= base_range.start:
            available_sources.append(DataSourceType.MADRID_TRAFFIC)

        # Weather Forecast Constraint
        # Weather data available from yesterday backward
        weather_end_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=1)
        if new_end > weather_end_date:
            new_end = weather_end_date
            constraints["weather"] = f"Adjusted end date to {new_end.date()} (latest available weather data)"
            logger.info(f"Weather constraint: end date adjusted to {new_end.date()}")

        if new_end >= base_range.start:
            available_sources.append(DataSourceType.WEATHER_FORECAST)

        # Ensure minimum training period
        final_start = base_range.start
        if (new_end - final_start).days < self.MIN_TRAINING_RANGE_DAYS:
            final_start = new_end - timedelta(days=self.MIN_TRAINING_RANGE_DAYS)
            constraints["minimum_period"] = f"Adjusted start date to ensure {self.MIN_TRAINING_RANGE_DAYS} day minimum training period"
            logger.info(f"Minimum period constraint: start date adjusted to {final_start.date()}")

        return AlignedDateRange(
            start=final_start,
            end=new_end,
            available_sources=available_sources,
            constraints=constraints
        )

    def _get_madrid_traffic_end_date(self, now: Optional[datetime] = None) -> datetime:
        """
        Get the latest available date for Madrid traffic data.

        Data for current month is not available until the following month, so
        the answer is the last day of the previous month (at midnight).

        Args:
            now: Reference moment; defaults to ``datetime.now()``. Exposed so the
                calculation can be checked deterministically.

        Returns:
            Midnight of the last day of the month preceding ``now``.
        """
        reference = datetime.now() if now is None else now
        current_month_start = reference.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # One day before the first of this month is always the last day of the
        # previous month, whatever its length and across year boundaries.
        return current_month_start - timedelta(days=1)

    def _validate_final_range(self, aligned_range: AlignedDateRange) -> None:
        """
        Validate the final aligned date range.

        Raises:
            ValueError: if the range is empty or inverted, shorter than the
                minimum, longer than the maximum, or lacks sales data.
        """
        if aligned_range.start >= aligned_range.end:
            raise ValueError("Invalid date range: start date must be before end date")

        duration = (aligned_range.end - aligned_range.start).days

        if duration < self.MIN_TRAINING_RANGE_DAYS:
            raise ValueError(f"Insufficient training data: {duration} days (minimum: {self.MIN_TRAINING_RANGE_DAYS})")

        if duration > self.MAX_TRAINING_RANGE_DAYS:
            raise ValueError(f"Training period too long: {duration} days (maximum: {self.MAX_TRAINING_RANGE_DAYS})")

        # Ensure we have at least sales data
        if DataSourceType.BAKERY_SALES not in aligned_range.available_sources:
            raise ValueError("No sales data available for the aligned date range")

    def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
        """
        Generate a data collection plan based on the aligned date range.

        Returns:
            Dictionary with collection plans for each data source
        """
        plan = {}

        # Bakery Sales Data
        if DataSourceType.BAKERY_SALES in aligned_range.available_sources:
            plan["sales_data"] = {
                "start_date": aligned_range.start,
                "end_date": aligned_range.end,
                "source": "user_upload",
                "required": True
            }

        # Madrid Traffic Data
        if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
            plan["traffic_data"] = {
                "start_date": aligned_range.start,
                "end_date": aligned_range.end,
                "source": "madrid_opendata",
                "required": False,
                "constraint": "Cannot request current month data"
            }

        # Weather Data
        if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
            plan["weather_data"] = {
                "start_date": aligned_range.start,
                "end_date": aligned_range.end,
                "source": "aemet_api",
                "required": False,
                "constraint": "Available from yesterday backward"
            }

        return plan

    def check_madrid_current_month_constraint(self, end_date: datetime) -> bool:
        """
        Check if the end date violates the Madrid Open Data current month constraint.

        Args:
            end_date: The requested end date

        Returns:
            True if the constraint is violated (end date is in current month)
        """
        now = datetime.now()
        current_month_start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return end_date >= current_month_start

View File

@@ -0,0 +1,706 @@
# services/training/app/services/training_orchestrator.py
"""
Training Data Orchestrator - Enhanced Integration Layer
Orchestrates data collection, date alignment, and preparation for ML training
"""
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
logger = logging.getLogger(__name__)
@dataclass
class TrainingDataSet:
    """Container for all training data with metadata"""

    # Raw record lists, one per source.
    sales_data: List[Dict[str, Any]]
    weather_data: List[Dict[str, Any]]
    traffic_data: List[Dict[str, Any]]
    # The aligned window the records were collected for.
    date_range: AlignedDateRange
    # Free-form bookkeeping about how the dataset was assembled.
    metadata: Dict[str, Any]
class TrainingDataOrchestrator:
"""
Enhanced orchestrator for data collection from multiple sources.
Ensures date alignment, handles data source constraints, and prepares data for ML training.
"""
def __init__(self,
             madrid_client=None,
             weather_client=None,
             date_alignment_service: DateAlignmentService = None):
    """
    Store the optional external clients and build the internal collaborators.

    A ``DataServiceClient`` is always created; a ``DateAlignmentService`` is
    created only when a truthy one is not supplied.
    """
    self.madrid_client = madrid_client
    self.weather_client = weather_client
    self.data_client = DataServiceClient()
    if date_alignment_service:
        self.date_alignment_service = date_alignment_service
    else:
        self.date_alignment_service = DateAlignmentService()
    self.max_concurrent_requests = 3
async def prepare_training_data(
    self,
    tenant_id: str,
    bakery_location: Tuple[float, float],  # (lat, lon)
    requested_start: Optional[datetime] = None,
    requested_end: Optional[datetime] = None,
    job_id: Optional[str] = None
) -> TrainingDataSet:
    """
    Main method to prepare all training data with comprehensive date alignment.

    Sales data is fetched through ``self.data_client`` for the tenant, its date
    range is aligned against the external-source constraints, external data is
    collected for the aligned window, and everything is bundled with quality
    metadata.

    Args:
        tenant_id: Tenant identifier
        bakery_location: Bakery coordinates (lat, lon)
        requested_start: Optional explicit start date
        requested_end: Optional explicit end date
        job_id: Training job identifier for logging

    Returns:
        TrainingDataSet with all aligned and validated data

    Raises:
        ValueError: if any step fails; the original exception is chained as
            the cause.
    """
    import inspect

    logger.info(f"Starting comprehensive training data preparation for tenant {tenant_id}, job {job_id}")

    try:
        # The sibling fetch_* calls on this client are awaited elsewhere in this
        # class; await the sales fetch too when it returns an awaitable, so the
        # steps below receive records rather than a coroutine object. A
        # synchronous client keeps working unchanged.
        sales_data = self.data_client.fetch_sales_data(tenant_id)
        if inspect.isawaitable(sales_data):
            sales_data = await sales_data

        # Step 1: Extract and validate sales data date range
        sales_date_range = self._extract_sales_date_range(sales_data)
        logger.info(f"Sales data range detected: {sales_date_range.start} to {sales_date_range.end}")

        # Step 2: Apply date alignment across all data sources
        aligned_range = self.date_alignment_service.validate_and_align_dates(
            user_sales_range=sales_date_range,
            requested_start=requested_start,
            requested_end=requested_end
        )
        logger.info(f"Date alignment completed: {aligned_range.start} to {aligned_range.end}")
        if aligned_range.constraints:
            logger.info(f"Applied constraints: {aligned_range.constraints}")

        # Step 3: Filter sales data to aligned date range
        filtered_sales = self._filter_sales_data(sales_data, aligned_range)

        # Step 4: Collect external data sources concurrently
        logger.info("Collecting external data sources...")
        weather_data, traffic_data = await self._collect_external_data(
            aligned_range, bakery_location
        )

        # Step 5: Validate data quality
        data_quality_results = self._validate_data_sources(
            filtered_sales, weather_data, traffic_data, aligned_range
        )

        # Step 6: Create comprehensive training dataset
        training_dataset = TrainingDataSet(
            sales_data=filtered_sales,
            weather_data=weather_data,
            traffic_data=traffic_data,
            date_range=aligned_range,
            metadata={
                "tenant_id": tenant_id,
                "job_id": job_id,
                "bakery_location": bakery_location,
                "data_sources_used": aligned_range.available_sources,
                "constraints_applied": aligned_range.constraints,
                "data_quality": data_quality_results,
                "preparation_timestamp": datetime.now().isoformat(),
                "original_sales_range": {
                    "start": sales_date_range.start.isoformat(),
                    "end": sales_date_range.end.isoformat()
                }
            }
        )

        # Step 7: Final validation
        final_validation = self.validate_training_data_quality(training_dataset)
        training_dataset.metadata["final_validation"] = final_validation

        logger.info("Training data preparation completed successfully:")
        logger.info(f" - Sales records: {len(filtered_sales)}")
        logger.info(f" - Weather records: {len(weather_data)}")
        logger.info(f" - Traffic records: {len(traffic_data)}")
        logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")

        return training_dataset

    except Exception as e:
        logger.error(f"Training data preparation failed: {str(e)}")
        raise ValueError(f"Failed to prepare training data: {str(e)}") from e
def _extract_sales_date_range(self, sales_data: List[Dict[str, Any]]) -> DateRange:
    """
    Return the earliest and latest ``date`` found in the sales records.

    String dates are parsed from the part before any ``T``; ``datetime`` values
    are used unchanged. Records without a ``date`` key, or whose date is of
    another type, are ignored; unparsable dates are logged and skipped.

    Raises:
        ValueError: if ``sales_data`` is empty or yields no usable date.
    """
    if not sales_data:
        raise ValueError("No sales data provided")

    collected = []
    for entry in sales_data:
        try:
            if 'date' not in entry:
                continue
            raw = entry['date']
            if isinstance(raw, str):
                # Only the calendar-date portion matters; any time/zone suffix is dropped.
                collected.append(datetime.fromisoformat(raw.split('T')[0]))
            elif isinstance(raw, datetime):
                collected.append(raw)
        except (ValueError, TypeError) as e:
            logger.warning(f"Invalid date in sales record: {entry.get('date', 'N/A')} - {str(e)}")

    if not collected:
        raise ValueError("No valid dates found in sales data")

    logger.info(f"Processed {len(collected)} valid date records from {len(sales_data)} total records")

    earliest, latest = min(collected), max(collected)
    return DateRange(start=earliest, end=latest, source=DataSourceType.BAKERY_SALES)
def _filter_sales_data(
    self,
    sales_data: List[Dict[str, Any]],
    aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
    """
    Keep the sales records that fall inside the aligned range and pass validation.

    Dates are compared at day granularity: strings are parsed from the part
    before any ``T`` and ``datetime`` values are truncated to midnight. Records
    inside the range that fail ``_validate_sales_record``, and records that
    raise while being processed, are counted and reported in a warning.
    """
    kept = []
    rejected = 0

    for entry in sales_data:
        try:
            if 'date' not in entry:
                continue

            when = entry['date']
            if isinstance(when, str):
                when = datetime.fromisoformat(when.split('T')[0])
            elif isinstance(when, datetime):
                when = when.replace(hour=0, minute=0, second=0, microsecond=0)

            # Records outside the aligned window are dropped without being counted.
            if not (aligned_range.start <= when <= aligned_range.end):
                continue

            if self._validate_sales_record(entry):
                kept.append(entry)
            else:
                rejected += 1
        except Exception as e:
            logger.warning(f"Error processing sales record: {str(e)}")
            rejected += 1

    logger.info(f"Filtered sales data: {len(kept)} records in aligned range")
    if rejected > 0:
        logger.warning(f"Filtered out {rejected} invalid records")

    return kept
def _validate_sales_record(self, record: Dict[str, Any]) -> bool:
    """
    Return True when a sales record is usable.

    A usable record has non-None ``date`` and ``product_name`` values and at
    least one non-None quantity field; the first such quantity field (in the
    order listed below) must convert to a float that is not negative.
    """
    if any(record.get(name) is None for name in ('date', 'product_name')):
        return False

    quantity_fields = ('quantity', 'quantity_sold', 'sales', 'units_sold')
    populated = [record[name] for name in quantity_fields
                 if name in record and record[name] is not None]
    if not populated:
        return False

    # Only the first populated quantity field is inspected.
    try:
        amount = float(populated[0])
    except (ValueError, TypeError):
        return False

    if amount < 0:
        return False
    return True
async def _collect_external_data(
    self,
    aligned_range: AlignedDateRange,
    bakery_location: Tuple[float, float]
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Fetch weather and traffic data concurrently for the aligned range.

    Only sources listed in ``aligned_range.available_sources`` are requested.
    A source that fails (or is not requested) contributes an empty list.

    Returns:
        ``(weather_data, traffic_data)``
    """
    lat, lon = bakery_location
    sources = aligned_range.available_sources

    # Insertion order of this dict is the order results come back from gather().
    pending = {}
    if DataSourceType.WEATHER_FORECAST in sources:
        pending["weather"] = asyncio.create_task(
            self._collect_weather_data_with_timeout(lat, lon, aligned_range)
        )
    if DataSourceType.MADRID_TRAFFIC in sources:
        pending["traffic"] = asyncio.create_task(
            self._collect_traffic_data_with_timeout(lat, lon, aligned_range)
        )

    results = {}
    if pending:
        try:
            outcomes = await asyncio.gather(*pending.values(), return_exceptions=True)
            for task_name, outcome in zip(pending, outcomes):
                if isinstance(outcome, Exception):
                    logger.warning(f"{task_name} data collection failed: {outcome}")
                    results[task_name] = []
                else:
                    results[task_name] = outcome
                    logger.info(f"{task_name} data collection completed: {len(outcome)} records")
        except Exception as e:
            logger.error(f"Error in concurrent data collection: {str(e)}")
            results = {"weather": [], "traffic": []}

    return results.get("weather", []), results.get("traffic", [])
async def _collect_weather_data_with_timeout(
    self,
    lat: float,
    lon: float,
    aligned_range: AlignedDateRange,
    timeout_seconds: float = 30.0
) -> List[Dict[str, Any]]:
    """
    Collect weather data with timeout and fallback.

    Falls back to synthetic weather data when no weather client is configured,
    when the fetch times out or fails, or when the returned data does not pass
    ``_validate_weather_data``.

    Args:
        lat: Bakery latitude
        lon: Bakery longitude
        aligned_range: Date window to fetch
        timeout_seconds: Upper bound for the fetch. The 30 s default is a
            chosen value, not one taken from project configuration.

    Returns:
        Weather records, real or synthetic.
    """
    try:
        if not self.weather_client:
            logger.info("Weather client not configured, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)

        # wait_for requires an explicit timeout; without one the call raised
        # TypeError and the broad handler below silently substituted synthetic data.
        weather_data = await asyncio.wait_for(
            self.data_client.fetch_weather_data(aligned_range.start, aligned_range.end, lat, lon),
            timeout=timeout_seconds,
        )

        # Validate weather data
        if self._validate_weather_data(weather_data):
            logger.info(f"Collected {len(weather_data)} valid weather records")
            return weather_data
        else:
            logger.warning("Invalid weather data received, using synthetic data")
            return self._generate_synthetic_weather_data(aligned_range)

    except asyncio.TimeoutError:
        logger.warning(f"Weather data collection timed out after {timeout_seconds}s, using synthetic data")
        return self._generate_synthetic_weather_data(aligned_range)
    except Exception as e:
        logger.warning(f"Weather data collection failed: {e}, using synthetic data")
        return self._generate_synthetic_weather_data(aligned_range)
async def _collect_traffic_data_with_timeout(
    self,
    lat: float,
    lon: float,
    aligned_range: AlignedDateRange,
    timeout_seconds: float = 30.0
) -> List[Dict[str, Any]]:
    """
    Collect traffic data with timeout and Madrid constraint validation.

    Returns an empty list when no Madrid client is configured, when the range
    end falls in the current month, when the fetch times out or fails, or when
    the returned data does not pass ``_validate_traffic_data``.

    Args:
        lat: Bakery latitude
        lon: Bakery longitude
        aligned_range: Date window to fetch
        timeout_seconds: Upper bound for the fetch. The 30 s default is a
            chosen value, not one taken from project configuration.

    Returns:
        Traffic records, or ``[]`` when none are usable.
    """
    try:
        if not self.madrid_client:
            logger.info("Madrid client not configured, no traffic data available")
            return []

        # Double-check Madrid constraint before making request
        if self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end):
            logger.warning("Madrid current month constraint violation, no traffic data available")
            return []

        # wait_for requires an explicit timeout; without one the call raised
        # TypeError and the broad handler below silently returned no data.
        traffic_data = await asyncio.wait_for(
            self.data_client.fetch_traffic_data(aligned_range.start, aligned_range.end, lat, lon),
            timeout=timeout_seconds,
        )

        # Validate traffic data
        if self._validate_traffic_data(traffic_data):
            logger.info(f"Collected {len(traffic_data)} valid traffic records")
            return traffic_data
        else:
            logger.warning("Invalid traffic data received")
            return []

    except asyncio.TimeoutError:
        logger.warning(f"Traffic data collection timed out after {timeout_seconds}s")
        return []
    except Exception as e:
        logger.warning(f"Traffic data collection failed: {e}")
        return []
def _validate_weather_data(self, weather_data: List[Dict[str, Any]]) -> bool:
    """
    Return True when at least half of the weather records are usable.

    A record is usable when it has a ``date`` key and at least one non-None
    weather measurement under any of the recognised field names. Empty input
    is invalid.
    """
    if not weather_data:
        return False

    measurement_keys = ('temperature', 'temp', 'temperatura', 'precipitation', 'rain', 'lluvia')
    usable = sum(
        1
        for row in weather_data
        if 'date' in row
        and any(key in row and row[key] is not None for key in measurement_keys)
    )

    # Consider valid if at least 50% of records are valid
    passed = (usable / len(weather_data)) >= 0.5
    if not passed:
        logger.warning(f"Weather data validation failed: {usable}/{len(weather_data)} valid records")
    return passed
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
    """
    Return True when at least 30% of the traffic records are usable.

    A record is usable when it has a ``date`` key and at least one non-None
    traffic measurement under any of the recognised field names. Empty input
    is invalid.
    """
    if not traffic_data:
        return False

    measurement_keys = ('traffic_volume', 'traffic_intensity', 'intensidad', 'trafico')
    usable = sum(
        1
        for row in traffic_data
        if 'date' in row
        and any(key in row and row[key] is not None for key in measurement_keys)
    )

    # Consider valid if at least 30% of records are valid (traffic data is often sparse)
    passed = (usable / len(traffic_data)) >= 0.3
    if not passed:
        logger.warning(f"Traffic data validation failed: {usable}/{len(traffic_data)} valid records")
    return passed
def _validate_data_sources(
self,
sales_data: List[Dict[str, Any]],
weather_data: List[Dict[str, Any]],
traffic_data: List[Dict[str, Any]],
aligned_range: AlignedDateRange
) -> Dict[str, Any]:
"""Validate all data sources and provide quality metrics"""
validation_results = {
"sales_data": {
"record_count": len(sales_data),
"is_valid": len(sales_data) > 0,
"coverage_days": (aligned_range.end - aligned_range.start).days,
"quality_score": 0.0
},
"weather_data": {
"record_count": len(weather_data),
"is_valid": self._validate_weather_data(weather_data) if weather_data else False,
"quality_score": 0.0
},
"traffic_data": {
"record_count": len(traffic_data),
"is_valid": self._validate_traffic_data(traffic_data) if traffic_data else False,
"quality_score": 0.0
},
"overall_quality_score": 0.0
}
# Calculate quality scores
# Sales data quality (most important)
if validation_results["sales_data"]["record_count"] > 0:
coverage_ratio = min(1.0, validation_results["sales_data"]["record_count"] / validation_results["sales_data"]["coverage_days"])
validation_results["sales_data"]["quality_score"] = coverage_ratio * 100
# Weather data quality
if validation_results["weather_data"]["record_count"] > 0:
expected_weather_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["weather_data"]["record_count"] / expected_weather_records)
validation_results["weather_data"]["quality_score"] = coverage_ratio * 100
# Traffic data quality
if validation_results["traffic_data"]["record_count"] > 0:
expected_traffic_records = (aligned_range.end - aligned_range.start).days
coverage_ratio = min(1.0, validation_results["traffic_data"]["record_count"] / expected_traffic_records)
validation_results["traffic_data"]["quality_score"] = coverage_ratio * 100
# Overall quality score (weighted by importance)
weights = {"sales_data": 0.7, "weather_data": 0.2, "traffic_data": 0.1}
overall_score = sum(
validation_results[source]["quality_score"] * weight
for source, weight in weights.items()
)
validation_results["overall_quality_score"] = round(overall_score, 2)
return validation_results
def _generate_synthetic_weather_data(
    self,
    aligned_range: AlignedDateRange
) -> List[Dict[str, Any]]:
    """
    Generate realistic synthetic weather data for Madrid.

    Produces one record per day from ``aligned_range.start`` through
    ``aligned_range.end`` inclusive. Values are random draws around a monthly
    base temperature, so output differs between calls.
    """
    import random

    # Madrid seasonal temperature patterns (base value per month)
    monthly_base = {
        1: 9, 2: 11, 3: 15, 4: 17, 5: 21, 6: 26,
        7: 29, 8: 28, 9: 24, 10: 18, 11: 12, 12: 9
    }

    first_day = aligned_range.start
    day_count = (aligned_range.end - first_day).days + 1

    records = []
    for offset in range(day_count):
        day = first_day + timedelta(days=offset)

        # Add some realistic variation (±3°C), never below zero
        temperature = max(0, monthly_base.get(day.month, 15) + random.gauss(0, 3))

        # Precipitation patterns (Madrid is relatively dry): 15% chance of rain
        if random.random() < 0.15:
            precipitation = random.uniform(0.1, 15.0)
        else:
            precipitation = 0.0

        records.append({
            "date": day,
            "temperature": round(temperature, 1),
            "precipitation": round(precipitation, 1),
            "humidity": round(random.uniform(40, 80), 1),
            "wind_speed": round(random.uniform(2, 15), 1),
            "pressure": round(random.uniform(1005, 1025), 1),
            "source": "synthetic_madrid_pattern"
        })

    logger.info(f"Generated {len(records)} synthetic weather records with Madrid patterns")
    return records
def validate_training_data_quality(self, dataset: "TrainingDataSet") -> Dict[str, Any]:
    """Score a prepared training dataset and collect warnings and errors.

    The score starts at 100 and is reduced for each shortcoming found
    (little sales data, short date coverage, missing or sparse external
    data). Any error forces ``is_valid`` to ``False`` and the score to 0.

    Args:
        dataset: Object exposing ``sales_data``, ``weather_data``,
            ``traffic_data`` and ``date_range`` (with ``start``, ``end`` and
            ``constraints``).

    Returns:
        Dict with the keys ``is_valid``, ``warnings``, ``errors``,
        ``data_quality_score``, ``recommendations`` and
        ``quality_assessment`` ("Excellent", "Good", "Fair" or "Poor").
    """
    validation_results = {
        "is_valid": True,
        "warnings": [],
        "errors": [],
        "data_quality_score": 100.0,
        "recommendations": []
    }

    # A missing sales list is handled like an empty one, the same way missing
    # weather/traffic data is tolerated below.
    sales_records = dataset.sales_data or []

    # Check sales data completeness
    sales_count = len(sales_records)
    if sales_count < 30:
        validation_results["warnings"].append(
            f"Limited sales data: {sales_count} records (recommended: 30+)"
        )
        validation_results["data_quality_score"] -= 20
        validation_results["recommendations"].append("Consider collecting more historical sales data")
    elif sales_count < 90:
        validation_results["warnings"].append(
            f"Moderate sales data: {sales_count} records (optimal: 90+)"
        )
        validation_results["data_quality_score"] -= 10

    # Check date coverage
    date_coverage = (dataset.date_range.end - dataset.date_range.start).days
    if date_coverage < 90:
        validation_results["warnings"].append(
            f"Limited date coverage: {date_coverage} days (recommended: 90+)"
        )
        validation_results["data_quality_score"] -= 15
        validation_results["recommendations"].append("Extend date range for better seasonality detection")

    # Check external data availability
    if not dataset.weather_data:
        validation_results["warnings"].append("No weather data available")
        validation_results["data_quality_score"] -= 10
        validation_results["recommendations"].append("Weather data improves forecast accuracy")
    elif len(dataset.weather_data) < date_coverage * 0.5:
        validation_results["warnings"].append("Sparse weather data coverage")
        validation_results["data_quality_score"] -= 5

    if not dataset.traffic_data:
        validation_results["warnings"].append("No traffic data available")
        validation_results["data_quality_score"] -= 5
        validation_results["recommendations"].append("Traffic data can help with location-based patterns")

    # Check data consistency. A record whose product name is None or blank
    # carries no usable product, so it must not satisfy the check below.
    unique_products = set()
    for record in sales_records:
        product_name = record.get('product_name')
        if product_name is not None and str(product_name).strip():
            unique_products.add(product_name)

    if not unique_products:
        validation_results["errors"].append("No product names found in sales data")
        validation_results["is_valid"] = False
    elif len(unique_products) > 50:
        validation_results["warnings"].append(
            f"Many products detected ({len(unique_products)}). Consider training models in batches."
        )
        validation_results["recommendations"].append("Group similar products for better training efficiency")

    # Check for data source constraints
    if dataset.date_range.constraints:
        constraint_info = [
            f"{constraint_type}: {message}"
            for constraint_type, message in dataset.date_range.constraints.items()
        ]
        validation_results["warnings"].append(
            f"Data source constraints applied: {'; '.join(constraint_info)}"
        )

    # Final validation: any error invalidates the dataset outright.
    if validation_results["errors"]:
        validation_results["is_valid"] = False
        validation_results["data_quality_score"] = 0.0

    # Ensure score doesn't go below 0
    validation_results["data_quality_score"] = max(0.0, validation_results["data_quality_score"])

    # Add quality assessment
    score = validation_results["data_quality_score"]
    if score >= 80:
        validation_results["quality_assessment"] = "Excellent"
    elif score >= 60:
        validation_results["quality_assessment"] = "Good"
    elif score >= 40:
        validation_results["quality_assessment"] = "Fair"
    else:
        validation_results["quality_assessment"] = "Poor"

    return validation_results
def get_data_collection_plan(self, aligned_range: AlignedDateRange) -> Dict[str, Dict]:
    """
    Build a data collection plan for the aligned date range.

    The plan has a ``collection_summary`` describing the range and a
    ``data_sources`` mapping with one entry per source listed in
    ``aligned_range.available_sources`` (sales, traffic, weather).
    """
    start_iso = aligned_range.start.isoformat()
    end_iso = aligned_range.end.isoformat()
    duration_days = (aligned_range.end - aligned_range.start).days
    sources = aligned_range.available_sources

    data_sources = {}
    plan = {
        "collection_summary": {
            "start_date": start_iso,
            "end_date": end_iso,
            "duration_days": duration_days,
            "available_sources": [entry.value for entry in sources],
            "constraints": aligned_range.constraints
        },
        "data_sources": data_sources
    }

    # Bakery sales: the only source marked as required.
    if DataSourceType.BAKERY_SALES in sources:
        data_sources["sales_data"] = {
            "start_date": start_iso,
            "end_date": end_iso,
            "source": "user_upload",
            "required": True,
            "priority": "high",
            "expected_records": "variable",
            "data_points": ["date", "product_name", "quantity"],
            "validation": "required_fields_check"
        }

    # Madrid traffic: optional, one record expected per day of the range.
    if DataSourceType.MADRID_TRAFFIC in sources:
        data_sources["traffic_data"] = {
            "start_date": start_iso,
            "end_date": end_iso,
            "source": "madrid_opendata",
            "required": False,
            "priority": "medium",
            "expected_records": duration_days,
            "constraint": "Cannot request current month data",
            "data_points": ["date", "traffic_volume", "congestion_level"],
            "validation": "date_constraint_check"
        }

    # Weather: optional, with a synthetic fallback named in the plan.
    if DataSourceType.WEATHER_FORECAST in sources:
        data_sources["weather_data"] = {
            "start_date": start_iso,
            "end_date": end_iso,
            "source": "aemet_api",
            "required": False,
            "priority": "high",
            "expected_records": duration_days,
            "constraint": "Available from yesterday backward",
            "data_points": ["date", "temperature", "precipitation", "humidity"],
            "validation": "temporal_constraint_check",
            "fallback": "synthetic_madrid_weather"
        }

    return plan
def get_orchestration_summary(self, dataset: "TrainingDataSet") -> Dict[str, Any]:
    """
    Summarise a prepared dataset: identifiers, aligned range, record counts
    per source, and the quality/validation entries stored in its metadata.
    """
    metadata = dataset.metadata
    date_range = dataset.date_range

    sales_count = len(dataset.sales_data)
    weather_count = len(dataset.weather_data)
    traffic_count = len(dataset.traffic_data)

    # A source counts as successful when it delivered at least one record.
    successful_sources = sum(
        1 for count in (sales_count, weather_count, traffic_count) if count > 0
    )

    data_alignment = {
        "original_range": metadata.get("original_sales_range"),
        "aligned_range": {
            "start": date_range.start.isoformat(),
            "end": date_range.end.isoformat(),
            "duration_days": (date_range.end - date_range.start).days
        },
        "constraints_applied": date_range.constraints,
        "available_sources": [entry.value for entry in date_range.available_sources]
    }
    data_collection_results = {
        "sales_records": sales_count,
        "weather_records": weather_count,
        "traffic_records": traffic_count,
        "total_records": sales_count + weather_count + traffic_count
    }
    processing_metadata = {
        "bakery_location": metadata.get("bakery_location"),
        "data_sources_requested": len(date_range.available_sources),
        "data_sources_successful": successful_sources
    }

    return {
        "tenant_id": metadata.get("tenant_id"),
        "job_id": metadata.get("job_id"),
        "orchestration_completed_at": metadata.get("preparation_timestamp"),
        "data_alignment": data_alignment,
        "data_collection_results": data_collection_results,
        "data_quality": metadata.get("data_quality", {}),
        "validation_results": metadata.get("final_validation", {}),
        "processing_metadata": processing_metadata
    }

View File

@@ -1,721 +1,303 @@
# services/training/app/services/training_service.py
"""
Training service business logic
Orchestrates ML training operations and manages job lifecycle
Main Training Service - Coordinates the complete training process
This is the entry point from the API layer
"""
from typing import Dict, List, Any, Optional
import logging
from datetime import datetime, timedelta
import asyncio
import uuid
import logging
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_
import httpx
from app.models.training import ModelTrainingLog, TrainedModel
from app.ml.trainer import BakeryMLTrainer
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.services.messaging import publish_job_completed, publish_job_failed
from app.core.config import settings
from shared.monitoring.metrics import MetricsCollector
from app.services.data_client import DataServiceClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType
from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.database import get_db_session
logger = logging.getLogger(__name__)
metrics = MetricsCollector("training-service")
class TrainingService:
"""
Main service class for managing ML training operations.
Replaces the old Celery-based training system with clean async implementation.
Main training service that coordinates the complete training pipeline.
Entry point from API layer - handles business logic and orchestration.
"""
def __init__(self):
self.ml_trainer = BakeryMLTrainer()
self.data_client = DataServiceClient()
async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
    """Determine the start and end dates covered by the sales records.

    Each record's ``date`` may be a ``datetime`` or a string, either an ISO
    timestamp containing ``T`` (a trailing ``Z`` is accepted) or a plain
    ``YYYY-MM-DD`` date. Records without a usable date are skipped with a
    warning. The resulting range is passed through
    ``_adjust_date_range_for_apis`` before being returned.

    Args:
        sales_data: Sales records to scan.

    Returns:
        ``(start_date, end_date)`` as naive datetimes.

    Raises:
        ValueError: If ``sales_data`` is empty or contains no valid date.
    """
    if not sales_data:
        raise ValueError("No sales data available to determine date range")

    dates = []
    for record in sales_data:
        if 'date' not in record:
            continue
        raw_date = record['date']
        try:
            if isinstance(raw_date, str):
                # Handle various date string formats
                date_str = raw_date.replace('Z', '+00:00')
                if 'T' in date_str:
                    parsed_date = datetime.fromisoformat(date_str)
                else:
                    # Handle date-only strings
                    parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
            elif isinstance(raw_date, datetime):
                parsed_date = raw_date
            else:
                continue

            # Timezone-aware values (e.g. from a "Z" suffix) cannot be compared
            # with naive ones: min()/max() below and the datetime.now() checks
            # in _adjust_date_range_for_apis would raise TypeError. Convert
            # them to naive local time, the convention datetime.now() uses.
            if parsed_date.tzinfo is not None:
                parsed_date = parsed_date.astimezone().replace(tzinfo=None)
        except (ValueError, AttributeError, OverflowError, OSError) as e:
            logger.warning(f"Invalid date format in record: {record['date']} - {e}")
            continue
        dates.append(parsed_date)

    if not dates:
        raise ValueError("No valid dates found in sales data")

    start_date = min(dates)
    end_date = max(dates)

    # Validate and adjust date range for external APIs
    start_date, end_date = self._adjust_date_range_for_apis(start_date, end_date)

    logger.info(f"Determined and adjusted sales date range: {start_date} to {end_date}")
    return start_date, end_date
def _adjust_date_range_for_apis(self, start_date: datetime, end_date: datetime) -> tuple[datetime, datetime]:
"""Adjust date range to comply with external API limits"""
# Weather and traffic APIs have a 90-day limit
MAX_DAYS = 90
# Calculate current range
current_range = (end_date - start_date).days
if current_range > MAX_DAYS:
logger.warning(f"Date range ({current_range} days) exceeds API limit ({MAX_DAYS} days). Adjusting...")
# Keep the most recent data
start_date = end_date - timedelta(days=MAX_DAYS)
logger.info(f"Adjusted start_date to {start_date} to fit within {MAX_DAYS} day limit")
# Ensure dates are not in the future
now = datetime.now()
if end_date > now:
end_date = now.replace(hour=0, minute=0, second=0, microsecond=0)
logger.info(f"Adjusted end_date to {end_date} (cannot be in future)")
if start_date > now:
start_date = now.replace(hour=0, minute=0, second=0, microsecond=0) - timedelta(days=30)
logger.info(f"Adjusted start_date to {start_date} (was in future)")
# Ensure start_date is before end_date
if start_date >= end_date:
start_date = end_date - timedelta(days=30) # Default to 30 days of data
logger.warning(f"start_date was not before end_date. Adjusted start_date to {start_date}")
def __init__(self, db_session: AsyncSession = None):
self.db_session = db_session
self.trainer = BakeryMLTrainer(db_session=db_session) # Pass DB session
self.date_alignment_service = DateAlignmentService()
self.orchestrator = TrainingDataOrchestrator(
date_alignment_service=self.date_alignment_service
)
return start_date, end_date
async def start_training_job(
self,
tenant_id: str,
bakery_location: tuple[float, float] = (40.4168, -3.7038), # Default Madrid
requested_start: Optional[datetime] = None,
requested_end: Optional[datetime] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Start a complete training job for a tenant.
Args:
tenant_id: Tenant identifier
sales_data: Historical sales data
bakery_location: Bakery coordinates (lat, lon)
weather_data: Optional weather data
traffic_data: Optional traffic data
requested_start: Optional explicit start date
requested_end: Optional explicit end date
job_id: Optional job identifier
Returns:
Training job results
"""
if not job_id:
job_id = f"training_{tenant_id}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting training job {job_id} for tenant {tenant_id}")
async def execute_training_job_simple(self, job_id: str, tenant_id_str: str, request: TrainingJobRequest):
"""Simple wrapper that creates its own database session"""
try:
# Import database_manager locally to avoid circular imports
from app.core.database import database_manager
logger.info(f"Starting background training job {job_id} for tenant {tenant_id_str}")
# Create new session for background task
async with database_manager.async_session_local() as session:
await self.execute_training_job(session, job_id, tenant_id_str, request)
await session.commit()
except Exception as e:
logger.error(f"Background training job {job_id} failed: {str(e)}")
# Try to update job status to failed
try:
from app.core.database import database_manager
async with database_manager.async_session_local() as error_session:
await self._update_job_status(
error_session, job_id, "failed", 0,
f"Training failed: {str(e)}", error_message=str(e)
)
await error_session.commit()
except Exception as update_error:
logger.error(f"Failed to update job status: {str(update_error)}")
raise
async def create_training_job(self,
db: AsyncSession,
tenant_id: str,
job_id: str,
config: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training job record"""
try:
training_log = ModelTrainingLog(
job_id=job_id,
# Step 1: Prepare training dataset with date alignment and orchestration
logger.info("Step 1: Preparing and aligning training data")
training_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
status="pending",
progress=0,
current_step="Initializing training job",
start_time=datetime.now(),
config=config
)
db.add(training_log)
await db.commit()
await db.refresh(training_log)
logger.info(f"Created training job {job_id} for tenant {tenant_id}")
return training_log
except Exception as e:
logger.error(f"Failed to create training job: {str(e)}")
await db.rollback()
raise
async def create_single_product_job(self,
                                    db: AsyncSession,
                                    tenant_id: str,
                                    product_name: str,
                                    job_id: str,
                                    config: Dict[str, Any]) -> ModelTrainingLog:
    """Create and persist a pending training job record for one product.

    Args:
        db: Session used to add, commit and refresh the record.
        tenant_id: Tenant the job belongs to.
        product_name: Product to train; stored in the job config under
            ``"single_product"``.
        job_id: Identifier stored on the record.
        config: Training configuration. It is copied, not modified.

    Returns:
        The refreshed ``ModelTrainingLog`` instance.

    Raises:
        Exception: Any error from the session is logged, the transaction is
            rolled back, and the error is re-raised.
    """
    try:
        # Build a new dict so the caller's config is not mutated (and a
        # missing config does not fail on item assignment).
        job_config = {**(config or {}), "single_product": product_name}

        training_log = ModelTrainingLog(
            job_id=job_id,
            tenant_id=tenant_id,
            status="pending",
            progress=0,
            current_step=f"Initializing training for {product_name}",
            start_time=datetime.now(),
            config=job_config
        )

        db.add(training_log)
        await db.commit()
        await db.refresh(training_log)

        logger.info(f"Created single product training job {job_id} for {product_name}")
        return training_log

    except Exception as e:
        logger.error(f"Failed to create single product training job: {str(e)}")
        await db.rollback()
        raise
async def execute_training_job(self,
db: AsyncSession,
job_id: str,
tenant_id: str,
request: TrainingJobRequest):
"""Execute a complete training job"""
try:
logger.info(f"Starting execution of training job {job_id}")
# Update job status to running
await self._update_job_status(db, job_id, "running", 5, "Fetching training data")
# Fetch sales data from data service
sales_data = await self.data_client.fetch_sales_data(tenant_id)
if not sales_data:
raise ValueError("No sales data found for training")
# Determine date range from sales data
start_date, end_date = await self._determine_sales_date_range(sales_data)
# Convert dates to ISO format strings for API calls
start_date_str = start_date.isoformat()
end_date_str = end_date.isoformat()
logger.info(f"Using date range for external APIs: {start_date_str} to {end_date_str}")
# Fetch external data if requested using the sales date range
weather_data = []
traffic_data = []
await self._update_job_status(db, job_id, "running", 15, "Fetching weather data")
try:
weather_data = await self.data_client.fetch_weather_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=40.4168, # Madrid coordinates
longitude=-3.7038
)
logger.info(f"Fetched {len(weather_data)} weather records")
except Exception as e:
logger.warning(f"Failed to fetch weather data: {e}. Continuing without weather data.")
weather_data = []
await self._update_job_status(db, job_id, "running", 25, "Fetching traffic data")
try:
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
end_date=end_date_str,
latitude=40.4168,
longitude=-3.7038
)
logger.info(f"Fetched {len(traffic_data)} traffic records")
except Exception as e:
logger.warning(f"Failed to fetch traffic data: {e}. Continuing without traffic data.")
traffic_data = []
# Execute ML training
await self._update_job_status(db, job_id, "running", 35, "Processing training data")
training_results = await self.ml_trainer.train_tenant_models(
tenant_id=tenant_id,
sales_data=sales_data,
weather_data=weather_data,
traffic_data=traffic_data,
bakery_location=bakery_location,
requested_start=requested_start,
requested_end=requested_end,
job_id=job_id
)
await self._update_job_status(db, job_id, "running", 85, "Storing trained models")
# Store trained models in database
await self._store_trained_models(db, tenant_id, training_results)
await self._update_job_status(
db, job_id, "completed", 100, "Training completed successfully",
results=training_results
# Step 2: Execute ML training pipeline
logger.info("Step 2: Starting ML training pipeline")
training_results = await self.trainer.train_tenant_models(
tenant_id=tenant_id,
training_dataset=training_dataset,
job_id=job_id
)
# Publish completion event
await publish_job_completed(job_id, tenant_id, training_results)
# Step 3: Compile final results
final_result = {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "completed",
"training_results": training_results,
"data_summary": {
"sales_records": len(training_dataset.sales_data),
"weather_records": len(training_dataset.weather_data),
"traffic_records": len(training_dataset.traffic_data),
"date_range": {
"start": training_dataset.date_range.start.isoformat(),
"end": training_dataset.date_range.end.isoformat()
},
"data_sources_used": [source.value for source in training_dataset.date_range.available_sources],
"constraints_applied": training_dataset.date_range.constraints
},
"completed_at": datetime.now().isoformat()
}
logger.info(f"Training results {training_results}")
logger.info(f"Training job {job_id} completed successfully")
metrics.increment_counter("training_jobs_completed")
return final_result
except Exception as e:
logger.error(f"Training job {job_id} failed: {str(e)}")
await self._update_job_status(
db, job_id, "failed", 0, f"Training failed: {str(e)}",
error_message=str(e)
)
# Publish failure event
await publish_job_failed(job_id, tenant_id, str(e))
metrics.increment_counter("training_jobs_failed")
raise
return {
"job_id": job_id,
"tenant_id": tenant_id,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
}
async def execute_single_product_training(self,
db: AsyncSession,
job_id: str,
tenant_id: str,
product_name: str,
request: SingleProductTrainingRequest):
"""Execute training for a single product"""
async def start_single_product_training(
self,
tenant_id: str,
product_name: str,
sales_data: List[Dict[str, Any]],
bakery_location: tuple[float, float] = (40.4168, -3.7038),
weather_data: Optional[List[Dict[str, Any]]] = None,
traffic_data: Optional[List[Dict[str, Any]]] = None,
job_id: Optional[str] = None
) -> Dict[str, Any]:
"""
Train a model for a single product.
Args:
tenant_id: Tenant identifier
product_name: Product name
sales_data: Historical sales data
bakery_location: Bakery coordinates
weather_data: Optional weather data
traffic_data: Optional traffic data
job_id: Optional job identifier
Returns:
Single product training result
"""
if not job_id:
job_id = f"single_{tenant_id}_{product_name}_{uuid.uuid4().hex[:8]}"
logger.info(f"Starting single product training {job_id} for {product_name}")
try:
logger.info(f"Starting single product training {job_id} for {product_name}")
# Filter sales data for the specific product
product_sales = [
record for record in sales_data
if record.get('product_name') == product_name
]
# Update job status
await self._update_job_status(db, job_id, "running", 10, f"Fetching data for {product_name}")
if not product_sales:
raise ValueError(f"No sales data found for product: {product_name}")
# Fetch data
sales_data = await self._fetch_product_sales_data(tenant_id, product_name, request)
weather_data = []
traffic_data = []
if request.include_weather:
await self._update_job_status(db, job_id, "running", 30, "Fetching weather data")
weather_data = await self.data_client.fetch_weather_data(tenant_id, request)
if request.include_traffic:
await self._update_job_status(db, job_id, "running", 50, "Fetching traffic data")
traffic_data = await self.data_client.fetch_traffic_data(tenant_id, request)
# Execute training
await self._update_job_status(db, job_id, "running", 70, f"Training model for {product_name}")
training_result = await self.ml_trainer.train_single_product(
# Use the same pipeline but for single product
return await self.start_training_job(
tenant_id=tenant_id,
product_name=product_name,
sales_data=sales_data,
sales_data=product_sales,
bakery_location=bakery_location,
weather_data=weather_data,
traffic_data=traffic_data,
job_id=job_id
)
# Store model
await self._update_job_status(db, job_id, "running", 90, "Storing trained model")
await self._store_single_trained_model(db, tenant_id, product_name, training_result)
await self._update_job_status(
db, job_id, "completed", 100, f"Training completed for {product_name}",
results=training_result
)
logger.info(f"Single product training {job_id} completed successfully")
metrics.increment_counter("single_product_training_completed")
except Exception as e:
logger.error(f"Single product training {job_id} failed: {str(e)}")
await self._update_job_status(
db, job_id, "failed", 0, f"Training failed: {str(e)}",
error_message=str(e)
)
metrics.increment_counter("single_product_training_failed")
raise
return {
"job_id": job_id,
"tenant_id": tenant_id,
"product_name": product_name,
"status": "failed",
"error_message": str(e),
"failed_at": datetime.now().isoformat()
}
async def get_job_status(self,
db: AsyncSession,
job_id: str,
tenant_id: str) -> Optional[ModelTrainingLog]:
"""Get training job status"""
try:
result = await db.execute(
select(ModelTrainingLog).where(
and_(
ModelTrainingLog.job_id == job_id,
ModelTrainingLog.tenant_id == tenant_id
)
)
)
return result.scalar_one_or_none()
async def validate_training_data(
self,
tenant_id: str,
sales_data: List[Dict[str, Any]],
products: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Validate training data quality before starting training.
Args:
tenant_id: Tenant identifier
sales_data: Sales data to validate
products: Optional list of specific products to validate
except Exception as e:
logger.error(f"Failed to get job status: {str(e)}")
return None
async def list_training_jobs(self,
                             db: AsyncSession,
                             tenant_id: str,
                             limit: int = 10,
                             status_filter: Optional[str] = None) -> List[ModelTrainingLog]:
    """Return a tenant's training jobs ordered by start time, newest first.

    At most ``limit`` rows are returned; ``status_filter``, when given,
    restricts the result to that status. Any failure is logged and an empty
    list is returned.
    """
    try:
        # Collect every WHERE criterion first, then build the statement once.
        criteria = [ModelTrainingLog.tenant_id == tenant_id]
        if status_filter:
            criteria.append(ModelTrainingLog.status == status_filter)

        query = (
            select(ModelTrainingLog)
            .where(*criteria)
            .order_by(ModelTrainingLog.start_time.desc())
            .limit(limit)
        )

        rows = await db.execute(query)
        return rows.scalars().all()

    except Exception as e:
        logger.error(f"Failed to list training jobs: {str(e)}")
        return []
async def cancel_training_job(self,
                              db: AsyncSession,
                              job_id: str,
                              tenant_id: str) -> bool:
    """Mark a pending or running job of the tenant as cancelled.

    Returns ``True`` when a row was updated and ``False`` when no matching
    cancellable job exists or the update fails (the failure is logged and
    rolled back).
    """
    try:
        # Only jobs that have not finished yet may be cancelled.
        cancellable = and_(
            ModelTrainingLog.job_id == job_id,
            ModelTrainingLog.tenant_id == tenant_id,
            ModelTrainingLog.status.in_(["pending", "running"])
        )
        statement = update(ModelTrainingLog).where(cancellable).values(
            status="cancelled",
            end_time=datetime.now(),
            current_step="Training cancelled by user"
        )

        outcome = await db.execute(statement)
        await db.commit()

        was_cancelled = outcome.rowcount > 0
        if not was_cancelled:
            logger.warning(f"Could not cancel training job {job_id} - not found or not cancellable")
            return False

        logger.info(f"Cancelled training job {job_id}")
        return True

    except Exception as e:
        logger.error(f"Failed to cancel training job: {str(e)}")
        await db.rollback()
        return False
async def validate_training_data(self,
db: AsyncSession,
tenant_id: str,
config: Dict[str, Any]) -> Dict[str, Any]:
"""Validate training data before starting a job"""
Returns:
Validation results
"""
try:
logger.info(f"Validating training data for tenant {tenant_id}")
issues = []
recommendations = []
# Fetch a sample of sales data to validate
sales_data = await self._fetch_sales_data(tenant_id, config, limit=1000)
# Extract sales date range for validation
if not sales_data:
issues.append("No sales data found for tenant")
return {
"is_valid": False,
"issues": issues,
"recommendations": ["Upload sales data before training"],
"estimated_time_minutes": 0
"valid": False,
"errors": ["No sales data provided"],
"warnings": []
}
# Analyze data quality
products = set(item.get("product_name") for item in sales_data)
total_records = len(sales_data)
# Check for sufficient data per product
product_counts = {}
for item in sales_data:
product = item.get("product_name")
if product:
product_counts[product] = product_counts.get(product, 0) + 1
insufficient_products = [
product for product, count in product_counts.items()
if count < config.get("min_data_points", 30)
]
if insufficient_products:
issues.append(f"Insufficient data for products: {', '.join(insufficient_products)}")
recommendations.append("Collect more historical data for these products")
# Estimate training time
valid_products = len(products) - len(insufficient_products)
estimated_time = max(5, valid_products * 2) # 2 minutes per product minimum
is_valid = len(issues) == 0
return {
"is_valid": is_valid,
"issues": issues,
"recommendations": recommendations,
"estimated_time_minutes": estimated_time,
"products_analyzed": len(products),
"total_data_points": total_records
}
except Exception as e:
logger.error(f"Failed to validate training data: {str(e)}")
return {
"is_valid": False,
"issues": [f"Validation error: {str(e)}"],
"recommendations": ["Check data service connectivity"],
"estimated_time_minutes": 0
}
async def _update_job_status(self,
db: AsyncSession,
job_id: str,
status: str,
progress: int,
current_step: str,
results: Optional[Dict] = None,
error_message: Optional[str] = None):
"""Update training job status"""
try:
update_values = {
"status": status,
"progress": progress,
"current_step": current_step
}
if status == "completed":
update_values["end_time"] = datetime.now()
if results:
update_values["results"] = results
if error_message:
update_values["error_message"] = error_message
update_values["end_time"] = datetime.now()
await db.execute(
update(ModelTrainingLog)
.where(ModelTrainingLog.job_id == job_id)
.values(**update_values)
# Create a mock training dataset to validate
mock_dataset = await self.orchestrator.prepare_training_data(
tenant_id=tenant_id,
sales_data=sales_data,
bakery_location=(40.4168, -3.7038), # Default Madrid
job_id=f"validation_{uuid.uuid4().hex[:8]}"
)
await db.commit()
# Validate the dataset
validation_results = self.orchestrator.validate_training_data_quality(mock_dataset)
# Add product-specific information
unique_products = list(set(record.get('product_name', 'unknown') for record in sales_data))
product_data_points = {}
for record in sales_data:
product = record.get('product_name', 'unknown')
product_data_points[product] = product_data_points.get(product, 0) + 1
validation_results.update({
"products_found": unique_products,
"product_data_points": product_data_points,
"total_records": len(sales_data),
"date_range_info": {
"start": mock_dataset.date_range.start.isoformat(),
"end": mock_dataset.date_range.end.isoformat(),
"duration_days": (mock_dataset.date_range.end - mock_dataset.date_range.start).days
}
})
return validation_results
except Exception as e:
logger.error(f"Failed to update job status: {str(e)}")
await db.rollback()
logger.error(f"Training data validation failed: {str(e)}")
return {
"valid": False,
"errors": [f"Validation failed: {str(e)}"],
"warnings": []
}
async def _store_trained_models(self,
db: AsyncSession,
tenant_id: str,
training_results: Dict[str, Any]):
"""Store trained models in database"""
async def get_training_recommendations(
self,
tenant_id: str,
sales_data: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Get training recommendations based on data analysis.
Args:
tenant_id: Tenant identifier
sales_data: Historical sales data
Returns:
Training recommendations
"""
try:
models_to_store = []
logger.info(f"Generating training recommendations for tenant {tenant_id}")
for product_name, result in training_results.get("training_results", {}).items():
if result.get("status") == "success":
model_info = result.get("model_info", {})
trained_model = TrainedModel(
tenant_id=tenant_id,
product_name=product_name,
model_id=model_info.get("model_id"),
model_type=model_info.get("type", "prophet"),
model_path=model_info.get("model_path"),
version=1, # Start with version 1
training_samples=model_info.get("training_samples", 0),
features=model_info.get("features", []),
hyperparameters=model_info.get("hyperparameters", {}),
training_metrics=model_info.get("training_metrics", {}),
data_period_start=datetime.fromisoformat(
model_info.get("data_period", {}).get("start_date", datetime.now().isoformat())
),
data_period_end=datetime.fromisoformat(
model_info.get("data_period", {}).get("end_date", datetime.now().isoformat())
),
created_at=datetime.now(),
is_active=True
)
models_to_store.append(trained_model)
# Analyze the data
validation_results = await self.validate_training_data(tenant_id, sales_data)
# Deactivate old models for these products
if models_to_store:
product_names = [model.product_name for model in models_to_store]
await db.execute(
update(TrainedModel)
.where(
and_(
TrainedModel.tenant_id == tenant_id,
TrainedModel.product_name.in_(product_names),
TrainedModel.is_active == True
)
)
.values(is_active=False)
)
# Add new models
db.add_all(models_to_store)
await db.commit()
logger.info(f"Stored {len(models_to_store)} trained models for tenant {tenant_id}")
recommendations = {
"should_retrain": True,
"reasons": [],
"recommended_products": [],
"optimal_config": {
"include_weather": True,
"include_traffic": True,
"min_data_points": 30,
"hyperparameter_optimization": True
}
}
# Analyze data quality and provide recommendations
if validation_results.get("data_quality_score", 0) >= 80:
recommendations["reasons"].append("High quality data detected")
else:
recommendations["reasons"].append("Data quality could be improved")
# Recommend products with sufficient data
product_data_points = validation_results.get("product_data_points", {})
for product, points in product_data_points.items():
if points >= 30: # Minimum viable data points
recommendations["recommended_products"].append(product)
if len(recommendations["recommended_products"]) == 0:
recommendations["should_retrain"] = False
recommendations["reasons"].append("Insufficient data for reliable training")
return recommendations
except Exception as e:
logger.error(f"Failed to store trained models: {str(e)}")
await db.rollback()
raise
async def _store_single_trained_model(self,
                                      db: AsyncSession,
                                      tenant_id: str,
                                      product_name: str,
                                      training_result: Dict[str, Any]):
    """Persist the model from a successful single-product training result.

    Nothing is written unless ``training_result["status"]`` is ``"success"``.
    Otherwise the product's currently active models are deactivated and a new
    active ``TrainedModel`` row is added and committed. Errors are logged,
    rolled back and re-raised.
    """
    try:
        if training_result.get("status") != "success":
            return

        model_info = training_result.get("model_info", {})

        # Deactivate old model for this product
        currently_active = and_(
            TrainedModel.tenant_id == tenant_id,
            TrainedModel.product_name == product_name,
            TrainedModel.is_active == True
        )
        await db.execute(
            update(TrainedModel).where(currently_active).values(is_active=False)
        )

        # Missing period bounds fall back to the current time.
        data_period = model_info.get("data_period", {})
        period_start = datetime.fromisoformat(
            data_period.get("start_date", datetime.now().isoformat())
        )
        period_end = datetime.fromisoformat(
            data_period.get("end_date", datetime.now().isoformat())
        )

        # Create new model record
        db.add(TrainedModel(
            tenant_id=tenant_id,
            product_name=product_name,
            model_id=model_info.get("model_id"),
            model_type=model_info.get("type", "prophet"),
            model_path=model_info.get("model_path"),
            version=1,
            training_samples=model_info.get("training_samples", 0),
            features=model_info.get("features", []),
            hyperparameters=model_info.get("hyperparameters", {}),
            training_metrics=model_info.get("training_metrics", {}),
            data_period_start=period_start,
            data_period_end=period_end,
            created_at=datetime.now(),
            is_active=True
        ))
        await db.commit()

        logger.info(f"Stored trained model for {product_name}")

    except Exception as e:
        logger.error(f"Failed to store trained model: {str(e)}")
        await db.rollback()
        raise
async def get_training_logs(self,
                            db: AsyncSession,
                            job_id: str,
                            tenant_id: str) -> Optional[List[str]]:
    """Build a human-readable log summary for one training job.

    The lines are derived from the job's ``ModelTrainingLog`` row; no
    separate log store is consulted.

    Returns:
        A list of log lines, or ``None`` when no row matches the job/tenant
        pair or when the lookup fails (the failure is logged).
    """
    try:
        job_filter = and_(
            ModelTrainingLog.job_id == job_id,
            ModelTrainingLog.tenant_id == tenant_id
        )
        query_result = await db.execute(select(ModelTrainingLog).where(job_filter))
        record = query_result.scalar_one_or_none()
        if not record:
            return None

        lines = [
            f"Job started at: {record.start_time}",
            f"Current status: {record.status}",
            f"Progress: {record.progress}%",
            f"Current step: {record.current_step}"
        ]
        # Optional sections are appended only when the row carries them.
        if record.end_time:
            lines.append(f"Job completed at: {record.end_time}")
        if record.error_message:
            lines.append(f"Error: {record.error_message}")
        if record.results:
            summary = record.results
            lines.extend([
                f"Models trained: {summary.get('products_trained', 0)}",
                f"Models failed: {summary.get('products_failed', 0)}"
            ])
        return lines
    except Exception as e:
        logger.error(f"Failed to get training logs: {str(e)}")
        return None
async def _determine_sales_date_range(self, sales_data: List[Dict]) -> tuple[datetime, datetime]:
    """Determine start and end dates from sales data.

    Each record's ``'date'`` may be an ISO-8601 string (a trailing ``Z`` is
    read as ``+00:00``) or a ``datetime``; records without a usable
    ``'date'`` are ignored.

    Returns:
        ``(start_date, end_date)``: the earliest and latest parsed values,
        returned unchanged (naive values stay naive, aware values stay aware).

    Raises:
        ValueError: If ``sales_data`` is empty, no record yields a date, or a
            date string is not valid ISO format.
    """
    # Local import: only this method needs timezone.
    from datetime import timezone

    if not sales_data:
        raise ValueError("No sales data available to determine date range")

    dates = []
    for record in sales_data:
        if 'date' not in record:
            continue
        raw_date = record['date']
        if isinstance(raw_date, str):
            dates.append(datetime.fromisoformat(raw_date.replace('Z', '+00:00')))
        elif isinstance(raw_date, datetime):
            dates.append(raw_date)

    if not dates:
        raise ValueError("No valid dates found in sales data")

    def _comparable(moment):
        # Naive and aware datetimes cannot be ordered against each other.
        # Naive values are treated as UTC for comparison only; the value
        # returned to the caller is never modified.
        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment

    start_date = min(dates, key=_comparable)
    end_date = max(dates, key=_comparable)
    logger.info(f"Determined sales date range: {start_date} to {end_date}")
    return start_date, end_date
logger.error(f"Failed to generate training recommendations: {str(e)}")
return {
"should_retrain": False,
"reasons": [f"Error analyzing data: {str(e)}"],
"recommended_products": [],
"optimal_config": {}
}

View File

@@ -47,4 +47,7 @@ psutil==5.9.0
# Utilities
python-dateutil==2.8.2
pytz==2023.3
pytz==2023.3
# Hyperparameter optimization
optuna==3.4.0

View File

@@ -1,311 +0,0 @@
# services/training/tests/conftest.py
"""
Test configuration and fixtures for training service ML components
"""
import pytest
import asyncio
import os
import tempfile
import pandas as pd
import numpy as np
from unittest.mock import Mock, AsyncMock, patch
from typing import Dict, List, Any, Generator
from datetime import datetime, timedelta
import uuid
# Configure test environment
os.environ["MODEL_STORAGE_PATH"] = "/tmp/test_models"
os.environ["TRAINING_DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
# Create test event loop
@pytest.fixture(scope="session")
def event_loop():
    """Provide a single event loop, created from the default policy, for the test session."""
    session_loop = asyncio.get_event_loop_policy().new_event_loop()
    yield session_loop
    # Executed at fixture teardown.
    session_loop.close()
# ================================================================
# PYTEST CONFIGURATION
# ================================================================
def pytest_configure(config):
    """Register the suite's custom markers on the given pytest config."""
    marker_lines = (
        "unit: Unit tests",
        "integration: Integration tests",
        "ml: Machine learning tests",
        "slow: Slow-running tests",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)
# ================================================================
# MOCK SETTINGS AND CONFIGURATION
# ================================================================
@pytest.fixture(autouse=True)
def mock_settings():
    """Patch ``app.core.config.settings`` with fixed test values for every test."""
    storage_dir = "/tmp/test_models"
    overrides = {
        "MODEL_STORAGE_PATH": storage_dir,
        "MIN_TRAINING_DATA_DAYS": 30,
        "PROPHET_SEASONALITY_MODE": "additive",
        "PROPHET_CHANGEPOINT_PRIOR_SCALE": 0.05,
        "PROPHET_SEASONALITY_PRIOR_SCALE": 10.0,
        "PROPHET_HOLIDAYS_PRIOR_SCALE": 10.0,
        "ENABLE_SPANISH_HOLIDAYS": True,
        "ENABLE_MADRID_HOLIDAYS": True,
    }
    with patch('app.core.config.settings') as patched_settings:
        for setting_name, setting_value in overrides.items():
            setattr(patched_settings, setting_name, setting_value)
        # The model directory must exist before any test writes to it.
        os.makedirs(storage_dir, exist_ok=True)
        yield patched_settings
# ================================================================
# MOCK ML COMPONENTS
# ================================================================
@pytest.fixture
def mock_prophet_manager():
    """Return an ``AsyncMock`` standing in for BakeryProphetManager with canned return values."""
    manager = AsyncMock()

    # Canned result of train_bakery_model.
    trained_model_info = {
        'model_id': f'test-model-{uuid.uuid4().hex[:8]}',
        'model_path': '/tmp/test_models/test_model.pkl',
        'type': 'prophet',
        'training_samples': 100,
        'features': ['temperature', 'humidity', 'day_of_week'],
        'training_metrics': {
            'mae': 5.2,
            'rmse': 7.8,
            'r2': 0.85
        },
        'created_at': datetime.now().isoformat()
    }
    manager.train_bakery_model.return_value = trained_model_info

    manager._validate_training_data = AsyncMock()

    # Canned seven-day forecast.
    forecast_frame = pd.DataFrame({
        'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
        'yhat': [50.0] * 7,
        'yhat_lower': [45.0] * 7,
        'yhat_upper': [55.0] * 7
    })
    manager.generate_forecast.return_value = forecast_frame

    holiday_frame = pd.DataFrame({
        'holiday': ['new_year', 'christmas'],
        'ds': [datetime(2024, 1, 1), datetime(2024, 12, 25)]
    })
    manager._get_spanish_holidays.return_value = holiday_frame
    manager._extract_regressor_columns.return_value = ['temperature', 'humidity']
    return manager
@pytest.fixture
def mock_data_processor():
    """Return an ``AsyncMock`` standing in for BakeryDataProcessor with canned frames."""
    processor = AsyncMock()

    # 35 days of Prophet-format training rows.
    training_frame = pd.DataFrame({
        'ds': pd.date_range('2024-01-01', periods=35, freq='D'),
        'y': [45 + 5 * np.sin(i / 7) for i in range(35)],
        'temperature': [15.0] * 35,
        'humidity': [65.0] * 35,
        'day_of_week': [i % 7 for i in range(35)],
        'is_weekend': [1 if i % 7 >= 5 else 0 for i in range(35)],
        'month': [1] * 35,
        'is_holiday': [0] * 35
    })
    processor.prepare_training_data.return_value = training_frame

    # Seven days of future feature rows.
    prediction_frame = pd.DataFrame({
        'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
        'temperature': [18.0] * 7,
        'humidity': [65.0] * 7,
        'day_of_week': [i % 7 for i in range(7)],
        'is_weekend': [1 if i % 7 >= 5 else 0 for i in range(7)],
        'month': [2] * 7,
        'is_holiday': [0] * 7
    })
    processor.prepare_prediction_features.return_value = prediction_frame

    # Canned output for the private helpers.
    temporal_frame = pd.DataFrame({
        'date': pd.date_range('2024-01-01', periods=10, freq='D'),
        'day_of_week': [i % 7 for i in range(10)],
        'is_weekend': [1 if i % 7 >= 5 else 0 for i in range(10)],
        'month': [1] * 10,
        'season': ['winter'] * 10,
        'week_of_year': [1] * 10,
        'quarter': [1] * 10,
        'is_holiday': [0] * 10,
        'is_school_holiday': [0] * 10
    })
    processor._add_temporal_features.return_value = temporal_frame
    processor._is_spanish_holiday.return_value = False
    return processor
# ================================================================
# SAMPLE DATA FIXTURES
# ================================================================
@pytest.fixture
def sample_sales_data():
    """Return a 35-day 'Pan Integral' sales DataFrame: a sine trend plus Gaussian noise."""
    day_index = pd.date_range('2024-01-01', periods=35, freq='D')
    rows = [
        {
            'date': day,
            'product_name': 'Pan Integral',
            'quantity': 40 + (5 * np.sin(offset / 7)) + np.random.normal(0, 2)
        }
        for offset, day in enumerate(day_index)
    ]
    return pd.DataFrame(rows)
@pytest.fixture
def sample_weather_data():
    """Return a 60-day weather DataFrame with temperature, precipitation and humidity columns."""
    n_days = 60
    day_index = pd.date_range('2024-01-01', periods=n_days, freq='D')
    # Columns are drawn in the same order as they appear in the frame.
    temperatures = [
        15 + 5 * np.sin(2 * np.pi * i / 365) + np.random.normal(0, 2)
        for i in range(n_days)
    ]
    precipitation = [max(0, np.random.exponential(1)) for _ in range(n_days)]
    humidity = [60 + np.random.normal(0, 10) for _ in range(n_days)]
    return pd.DataFrame({
        'date': day_index,
        'temperature': temperatures,
        'precipitation': precipitation,
        'humidity': humidity
    })
@pytest.fixture
def sample_traffic_data():
    """Return a 60-day traffic DataFrame with a noisy ``traffic_volume`` column."""
    n_days = 60
    day_index = pd.date_range('2024-01-01', periods=n_days, freq='D')
    volumes = [100 + np.random.normal(0, 20) for _ in range(n_days)]
    return pd.DataFrame({'date': day_index, 'traffic_volume': volumes})
@pytest.fixture
def sample_prophet_data():
    """Return 100 days of Prophet-format data (``ds``, ``y``) with two regressor columns."""
    n_days = 100
    day_index = pd.date_range('2024-01-01', periods=n_days, freq='D')
    # Columns are drawn in the same order as they appear in the frame.
    target = [
        45 + 10 * np.sin(2 * np.pi * i / 7) + np.random.normal(0, 5)
        for i in range(n_days)
    ]
    temperature = [15 + 5 * np.sin(2 * np.pi * i / 365) for i in range(n_days)]
    humidity = [60 + np.random.normal(0, 10) for _ in range(n_days)]
    return pd.DataFrame({
        'ds': day_index,
        'y': target,
        'temperature': temperature,
        'humidity': humidity
    })
@pytest.fixture
def sample_sales_records():
    """Return five sales records (list of dicts) covering two products."""
    raw_rows = (
        ("2024-01-01", "Pan Integral", 45),
        ("2024-01-02", "Pan Integral", 50),
        ("2024-01-03", "Pan Integral", 48),
        ("2024-01-04", "Croissant", 25),
        ("2024-01-05", "Croissant", 30),
    )
    return [
        {"date": day, "product_name": product, "quantity": sold}
        for day, product, sold in raw_rows
    ]
# ================================================================
# UTILITY FIXTURES
# ================================================================
@pytest.fixture
def temp_model_dir():
    """Yield the path of a temporary directory that is removed when the fixture finishes."""
    with tempfile.TemporaryDirectory() as scratch_dir:
        yield scratch_dir
@pytest.fixture
def test_tenant_id():
    """Return a tenant ID of the form ``test-tenant-<8 hex chars>``."""
    suffix = uuid.uuid4().hex[:8]
    return f"test-tenant-{suffix}"
@pytest.fixture
def test_job_id():
    """Return a job ID of the form ``test-job-<8 hex chars>``."""
    suffix = uuid.uuid4().hex[:8]
    return f"test-job-{suffix}"
# ================================================================
# MOCK EXTERNAL DEPENDENCIES (Simplified)
# ================================================================
@pytest.fixture
def mock_prophet_model():
    """Return a ``Mock`` Prophet model whose ``predict`` yields a fixed seven-day forecast."""
    canned_forecast = pd.DataFrame({
        'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
        'yhat': [50.0] * 7,
        'yhat_lower': [45.0] * 7,
        'yhat_upper': [55.0] * 7
    })
    model = Mock()
    model.fit.return_value = None
    model.predict.return_value = canned_forecast
    model.add_regressor.return_value = None
    return model
# ================================================================
# DATABASE MOCKS
# ================================================================
@pytest.fixture
def mock_db_session():
    """Return an ``AsyncMock`` database session; ``add`` is a plain ``Mock``, the rest are ``AsyncMock``."""
    session = AsyncMock()
    for method_name in ("commit", "rollback", "close"):
        setattr(session, method_name, AsyncMock())
    session.add = Mock()
    for method_name in ("execute", "scalar", "scalars"):
        setattr(session, method_name, AsyncMock())
    return session
# ================================================================
# PERFORMANCE TESTING
# ================================================================
@pytest.fixture
def performance_tracker():
    """Return a small stopwatch object for timing operations inside tests."""

    class PerformanceTracker:
        """Records the duration in milliseconds of named operations."""

        def __init__(self):
            self.start_time = None
            self.measurements = {}

        def start(self, operation_name: str = "default"):
            """Begin timing ``operation_name``."""
            self.start_time = datetime.now()
            self.operation_name = operation_name

        def stop(self) -> float:
            """Store and return the elapsed milliseconds; 0.0 if ``start`` was never called."""
            if not self.start_time:
                return 0.0
            elapsed = datetime.now() - self.start_time
            duration = elapsed.total_seconds() * 1000
            self.measurements[self.operation_name] = duration
            return duration

        def assert_performance(self, max_duration_ms: float, operation_name: str = "default"):
            """Assert the recorded duration is within budget; an unrecorded operation counts as infinite."""
            duration = self.measurements.get(operation_name, float('inf'))
            assert duration <= max_duration_ms, f"Operation {operation_name} took {duration:.0f}ms, expected <= {max_duration_ms}ms"

    return PerformanceTracker()
# ================================================================
# CLEANUP
# ================================================================
@pytest.fixture(autouse=True)
def cleanup_after_test():
    """After each test, delete the entries left in the shared test model directory."""
    yield
    models_dir = "/tmp/test_models"
    if not os.path.exists(models_dir):
        return
    for entry in os.listdir(models_dir):
        try:
            os.remove(os.path.join(models_dir, entry))
        except (OSError, PermissionError):
            # Best effort: anything that cannot be removed is left in place.
            pass

View File

@@ -1,47 +0,0 @@
# services/training/pytest.ini
# NOTE: in a pytest.ini file the section must be named [pytest];
# [tool:pytest] is only recognised inside setup.cfg.
[pytest]
# Minimal pytest configuration for the training service ML tests

# Test discovery
python_files = test_*.py *_test.py
python_classes = Test*
python_functions = test_*

# Test directories
testpaths = tests

# Markers (continuation lines must be indented)
markers =
    unit: Unit tests (fast, isolated)
    integration: Integration tests (slower, with dependencies)
    ml: Machine learning specific tests
    slow: Slow-running tests
    api: API endpoint tests
    performance: Performance tests

# Asyncio configuration (option provided by the pytest-asyncio plugin)
asyncio_mode = auto

# Output configuration
addopts =
    -v
    --tb=short
    --strict-markers
    --disable-warnings
    --color=yes

# Minimum pytest version (this option does not constrain the Python version)
minversion = 3.8

# Ignore certain warnings
filterwarnings =
    ignore::DeprecationWarning
    ignore::PendingDeprecationWarning
    ignore::UserWarning:prophet.*
    ignore::UserWarning:pandas.*

# Test timeout in seconds (option provided by the pytest-timeout plugin)
timeout = 300

# Coverage (if pytest-cov is installed)
# addopts = -v --tb=short --strict-markers --disable-warnings --color=yes --cov=app --cov-report=term-missing

View File

@@ -1,734 +0,0 @@
# services/training/tests/test_ml.py
"""
Tests for ML components: trainer, prophet_manager, and data_processor
"""
import pytest
import pandas as pd
import numpy as np
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime, timedelta
import os
import tempfile
from app.ml.trainer import BakeryMLTrainer
from app.ml.prophet_manager import BakeryProphetManager
from app.ml.data_processor import BakeryDataProcessor
class TestBakeryDataProcessor:
    """Test the data processor component"""

    @pytest.fixture
    def data_processor(self):
        """Provide a fresh BakeryDataProcessor for each test."""
        return BakeryDataProcessor()

    @pytest.mark.asyncio
    async def test_prepare_training_data_basic(
        self,
        data_processor,
        sample_sales_data,
        sample_weather_data,
        sample_traffic_data
    ):
        """Test basic data preparation"""
        result = await data_processor.prepare_training_data(
            sales_data=sample_sales_data,
            weather_data=sample_weather_data,
            traffic_data=sample_traffic_data,
            product_name="Pan Integral"
        )

        # Check result structure
        assert isinstance(result, pd.DataFrame)
        assert 'ds' in result.columns
        assert 'y' in result.columns
        assert len(result) > 0

        # Check Prophet format
        assert result['ds'].dtype == 'datetime64[ns]'
        assert pd.api.types.is_numeric_dtype(result['y'])

        # Temporal, weather and traffic features must all be present
        expected_features = [
            'day_of_week', 'is_weekend', 'month', 'is_holiday',
            'temperature', 'precipitation', 'humidity',
            'traffic_volume',
        ]
        for feature in expected_features:
            assert feature in result.columns

    @pytest.mark.asyncio
    async def test_prepare_training_data_empty_weather(
        self,
        data_processor,
        sample_sales_data
    ):
        """Test data preparation with empty weather data"""
        result = await data_processor.prepare_training_data(
            sales_data=sample_sales_data,
            weather_data=pd.DataFrame(),
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        # Should still work with default values
        assert isinstance(result, pd.DataFrame)
        assert 'ds' in result.columns
        assert 'y' in result.columns

        # Should have default weather values
        assert 'temperature' in result.columns
        assert result['temperature'].iloc[0] == 15.0  # Default value

    @pytest.mark.asyncio
    async def test_prepare_prediction_features(self, data_processor):
        """Test preparation of prediction features"""
        future_dates = pd.date_range('2024-02-01', periods=7, freq='D')
        weather_forecast = pd.DataFrame({
            'ds': future_dates,
            'temperature': [18.0] * 7,
            'precipitation': [0.0] * 7,
            'humidity': [65.0] * 7
        })

        result = await data_processor.prepare_prediction_features(
            future_dates=future_dates,
            weather_forecast=weather_forecast,
            traffic_forecast=pd.DataFrame()
        )

        assert isinstance(result, pd.DataFrame)
        assert len(result) == 7
        assert 'ds' in result.columns

        # Check temporal features are added
        assert 'day_of_week' in result.columns
        assert 'is_weekend' in result.columns

        # Check weather features
        assert 'temperature' in result.columns
        assert all(result['temperature'] == 18.0)

    def test_add_temporal_features(self, data_processor):
        """Test temporal feature engineering"""
        dates = pd.date_range('2024-01-01', periods=10, freq='D')
        df = pd.DataFrame({'date': dates})

        result = data_processor._add_temporal_features(df)

        # Check temporal features
        for feature in (
            'day_of_week', 'is_weekend', 'month', 'season',
            'week_of_year', 'quarter', 'is_holiday', 'is_school_holiday',
        ):
            assert feature in result.columns

        # Check weekend detection
        # 2024-01-01 was a Monday (day_of_week = 0)
        assert result.iloc[0]['day_of_week'] == 0
        assert result.iloc[0]['is_weekend'] == 0
        # 2024-01-06 was a Saturday (day_of_week = 5)
        assert result.iloc[5]['day_of_week'] == 5
        assert result.iloc[5]['is_weekend'] == 1

    def test_spanish_holiday_detection(self, data_processor):
        """Test Spanish holiday detection"""
        # Known Spanish holidays: New Year, Epiphany, Labour Day, Christmas
        known_holidays = [
            datetime(2024, 1, 1),
            datetime(2024, 1, 6),
            datetime(2024, 5, 1),
            datetime(2024, 12, 25),
        ]
        for holiday in known_holidays:
            # Truthiness rather than "== True" (PEP 8, E712)
            assert data_processor._is_spanish_holiday(holiday), holiday

        # Test non-holiday
        regular_day = datetime(2024, 3, 15)
        assert not data_processor._is_spanish_holiday(regular_day)

    @pytest.mark.asyncio
    async def test_prepare_training_data_insufficient_data(self, data_processor):
        """Test handling of insufficient training data.

        The processor may either reject the tiny dataset with a ValueError
        that names the problem, or return a correspondingly small frame.
        """
        # Create very small dataset (less than 30 days minimum)
        small_sales_data = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=5, freq='D'),
            'product_name': ['Pan Integral'] * 5,
            'quantity': [45, 50, 48, 52, 49]
        })

        try:
            result = await data_processor.prepare_training_data(
                sales_data=small_sales_data,
                weather_data=pd.DataFrame(),
                traffic_data=pd.DataFrame(),
                product_name="Pan Integral"
            )
        except ValueError as exc:
            # Only a ValueError is accepted, and it must say why. The try body
            # holds no assertion, so a failed check can no longer be swallowed.
            message = str(exc).lower()
            assert "insufficient" in message or "minimum" in message
        else:
            assert len(result) <= 30, "Should have limited data for small dataset"
class TestBakeryProphetManager:
    """Tests for the Prophet manager component."""

    @pytest.fixture
    def prophet_manager(self, temp_model_dir):
        """Build a manager while MODEL_STORAGE_PATH is patched to a temp directory."""
        with patch('app.ml.prophet_manager.settings.MODEL_STORAGE_PATH', temp_model_dir):
            manager = BakeryProphetManager()
            return manager

    @pytest.mark.asyncio
    async def test_train_bakery_model_success(self, prophet_manager, sample_prophet_data):
        """Training with Prophet and joblib.dump patched returns the expected model info."""
        with patch('app.ml.prophet_manager.Prophet') as prophet_cls, \
             patch('app.ml.prophet_manager.joblib.dump') as dump_mock:
            fake_model = Mock()
            fake_model.fit.return_value = None
            fake_model.add_regressor.return_value = None
            prophet_cls.return_value = fake_model

            result = await prophet_manager.train_bakery_model(
                tenant_id="test-tenant",
                product_name="Pan Integral",
                df=sample_prophet_data,
                job_id="test-job-123"
            )

            # Result structure
            assert isinstance(result, dict)
            for key in ('model_id', 'model_path', 'type'):
                assert key in result
            assert result['type'] == 'prophet'
            for key in ('training_samples', 'features', 'training_metrics'):
                assert key in result

            # The model was built, fitted and persisted exactly once
            prophet_cls.assert_called_once()
            fake_model.fit.assert_called_once()
            dump_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(self, prophet_manager, sample_prophet_data):
        """Valid data passes validation without raising."""
        await prophet_manager._validate_training_data(sample_prophet_data, "Pan Integral")

    @pytest.mark.asyncio
    async def test_validate_training_data_insufficient(self, prophet_manager):
        """Five rows of data are rejected as insufficient."""
        too_short = pd.DataFrame({
            'ds': pd.date_range('2024-01-01', periods=5, freq='D'),
            'y': [45, 50, 48, 52, 49]
        })

        with pytest.raises(ValueError, match="Insufficient training data"):
            await prophet_manager._validate_training_data(too_short, "Pan Integral")

    @pytest.mark.asyncio
    async def test_validate_training_data_missing_columns(self, prophet_manager):
        """A frame without the ``ds``/``y`` columns is rejected."""
        wrong_columns = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=50, freq='D'),
            'quantity': [45] * 50
        })

        with pytest.raises(ValueError, match="Missing required columns"):
            await prophet_manager._validate_training_data(wrong_columns, "Pan Integral")

    def test_get_spanish_holidays(self, prophet_manager):
        """A non-empty holiday frame has the Prophet columns and the known holidays."""
        holidays = prophet_manager._get_spanish_holidays()

        if holidays.empty:
            return

        assert 'holiday' in holidays.columns
        assert 'ds' in holidays.columns

        present = holidays['holiday'].unique()
        for expected in ['new_year', 'christmas', 'may_day']:
            assert expected in present

    def test_extract_regressor_columns(self, prophet_manager, sample_prophet_data):
        """Regressor extraction keeps feature columns and drops ``ds`` and ``y``."""
        regressors = prophet_manager._extract_regressor_columns(sample_prophet_data)

        assert isinstance(regressors, list)
        assert 'temperature' in regressors
        assert 'humidity' in regressors
        assert 'ds' not in regressors  # Should be excluded
        assert 'y' not in regressors  # Should be excluded

    @pytest.mark.asyncio
    async def test_generate_forecast(self, prophet_manager):
        """Forecast generation loads the model from disk and calls ``predict`` once."""
        # A real (empty) file path; joblib.load itself is patched below.
        with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as handle:
            model_path = handle.name

        try:
            with patch('app.ml.prophet_manager.joblib.load') as load_mock:
                horizon = pd.date_range('2024-02-01', periods=7, freq='D')
                fake_model = Mock()
                fake_model.predict.return_value = pd.DataFrame({
                    'ds': horizon,
                    'yhat': [50.0] * 7,
                    'yhat_lower': [45.0] * 7,
                    'yhat_upper': [55.0] * 7
                })
                load_mock.return_value = fake_model

                future_data = pd.DataFrame({
                    'ds': pd.date_range('2024-02-01', periods=7, freq='D'),
                    'temperature': [18.0] * 7,
                    'humidity': [65.0] * 7
                })

                result = await prophet_manager.generate_forecast(
                    model_path=model_path,
                    future_dates=future_data,
                    regressor_columns=['temperature', 'humidity']
                )

                assert isinstance(result, pd.DataFrame)
                assert len(result) == 7
                load_mock.assert_called_once_with(model_path)
                fake_model.predict.assert_called_once()
        finally:
            # Cleanup
            try:
                os.unlink(model_path)
            except FileNotFoundError:
                pass
class TestBakeryMLTrainer:
    """Test the ML trainer component"""

    @pytest.fixture
    def ml_trainer(self):
        """Provide a trainer whose prophet manager and data processor are plain mocks."""
        trainer = BakeryMLTrainer()
        trainer.prophet_manager = Mock()
        trainer.data_processor = Mock()
        return trainer

    @pytest.mark.asyncio
    async def test_train_tenant_models_success(
        self,
        ml_trainer,
        sample_sales_records,
        mock_prophet_manager,
        mock_data_processor
    ):
        """Test successful training of tenant models"""
        # Configure mocks
        ml_trainer.prophet_manager = mock_prophet_manager
        ml_trainer.data_processor = mock_data_processor

        result = await ml_trainer.train_tenant_models(
            tenant_id="test-tenant",
            sales_data=sample_sales_records,
            weather_data=[],
            traffic_data=[],
            job_id="test-job-123"
        )

        # Check result structure
        assert isinstance(result, dict)
        for key in ('job_id', 'tenant_id', 'status', 'training_results', 'summary'):
            assert key in result

        assert result['status'] == 'completed'
        assert result['tenant_id'] == 'test-tenant'

    @pytest.mark.asyncio
    async def test_train_single_product_success(
        self,
        ml_trainer,
        sample_sales_records,
        mock_prophet_manager,
        mock_data_processor
    ):
        """Test successful single product training"""
        # Configure mocks
        ml_trainer.prophet_manager = mock_prophet_manager
        ml_trainer.data_processor = mock_data_processor

        product_sales = [item for item in sample_sales_records if item['product_name'] == 'Pan Integral']

        result = await ml_trainer.train_single_product(
            tenant_id="test-tenant",
            product_name="Pan Integral",
            sales_data=product_sales,
            weather_data=[],
            traffic_data=[],
            job_id="test-job-123"
        )

        # Check result structure
        assert isinstance(result, dict)
        for key in ('job_id', 'tenant_id', 'product_name', 'status', 'model_info'):
            assert key in result

        assert result['status'] == 'success'
        assert result['product_name'] == 'Pan Integral'

    @pytest.mark.asyncio
    async def test_train_single_product_no_data(self, ml_trainer):
        """Test single product training with no data.

        Either a ValueError/KeyError is raised, or the returned result must
        report the failure.
        """
        try:
            result = await ml_trainer.train_single_product(
                tenant_id="test-tenant",
                product_name="Nonexistent Product",
                sales_data=[],
                weather_data=[],
                traffic_data=[],
                job_id="test-job-123"
            )
        except (ValueError, KeyError):
            # Raising for empty input is an accepted outcome.
            return

        # No exception: the status must indicate failure
        assert result.get('status') in ['error', 'failed'] or 'error' in result

    @pytest.mark.asyncio
    async def test_validate_input_data_valid(self, ml_trainer, sample_sales_records):
        """Test input data validation with valid data"""
        df = pd.DataFrame(sample_sales_records)

        # Should not raise exception
        await ml_trainer._validate_input_data(df, "test-tenant")

    @pytest.mark.asyncio
    async def test_validate_input_data_empty(self, ml_trainer):
        """Test input data validation with empty data"""
        empty_df = pd.DataFrame()

        with pytest.raises(ValueError, match="No sales data provided"):
            await ml_trainer._validate_input_data(empty_df, "test-tenant")

    @pytest.mark.asyncio
    async def test_validate_input_data_missing_columns(self, ml_trainer):
        """Test input data validation with missing columns"""
        invalid_df = pd.DataFrame([
            {"invalid_column": "value1"},
            {"invalid_column": "value2"}
        ])

        with pytest.raises(ValueError, match="Missing required columns"):
            await ml_trainer._validate_input_data(invalid_df, "test-tenant")

    def test_calculate_training_summary(self, ml_trainer):
        """Test training summary calculation"""
        training_results = {
            "Pan Integral": {
                "status": "success",
                "model_info": {"training_metrics": {"mae": 5.0, "rmse": 7.0}}
            },
            "Croissant": {
                "status": "error",
                "error_message": "Insufficient data"
            },
            "Baguette": {
                "status": "skipped",
                "reason": "insufficient_data"
            }
        }

        summary = ml_trainer._calculate_training_summary(training_results)

        assert summary['total_products'] == 3
        assert summary['successful_products'] == 1
        assert summary['failed_products'] == 1
        assert summary['skipped_products'] == 1
        # 1/3 * 100; compared with a tolerance so the test does not depend on
        # whether the implementation rounds to two decimals.
        assert summary['success_rate'] == pytest.approx(33.33, abs=0.01)
class TestIntegrationML:
    """Integration tests that run the ML components together."""

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_end_to_end_training_flow(self, sample_sales_data, sample_weather_data):
        """Prepare real features, then train through a patched Prophet."""
        processor = BakeryDataProcessor()

        # Step 1: data preparation with the real processor
        prepared_data = await processor.prepare_training_data(
            sales_data=sample_sales_data,
            weather_data=sample_weather_data,
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        assert isinstance(prepared_data, pd.DataFrame)
        assert len(prepared_data) > 0
        for column in ('ds', 'y'):
            assert column in prepared_data.columns

        # Step 2: training, with Prophet and joblib.dump replaced by mocks
        with patch('app.ml.prophet_manager.Prophet') as prophet_cls, \
             patch('app.ml.prophet_manager.joblib.dump') as dump_mock:
            fake_model = Mock()
            fake_model.fit.return_value = None
            fake_model.add_regressor.return_value = None
            prophet_cls.return_value = fake_model

            manager = BakeryProphetManager()

            result = await manager.train_bakery_model(
                tenant_id="test-tenant",
                product_name="Pan Integral",
                df=prepared_data,
                job_id="integration-test"
            )

            assert result['type'] == 'prophet'
            assert 'model_path' in result
            prophet_cls.assert_called_once()
            fake_model.fit.assert_called_once()

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_data_pipeline_integration(self, sample_sales_data, sample_weather_data):
        """The processor's output has what the Prophet manager consumes."""
        processor = BakeryDataProcessor()

        prepared_data = await processor.prepare_training_data(
            sales_data=sample_sales_data,
            weather_data=sample_weather_data,
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        # Prophet-format columns and enough rows
        assert 'ds' in prepared_data.columns
        assert 'y' in prepared_data.columns
        assert len(prepared_data) >= 30  # Minimum training data

        # Feature columns
        for col in ['temperature', 'humidity', 'day_of_week', 'is_weekend']:
            assert col in prepared_data.columns

    @pytest.mark.unit
    def test_temporal_feature_consistency(self):
        """Temporal features are produced, within valid ranges, for week/month/year spans."""
        processor = BakeryDataProcessor()

        expected_features = [
            'day_of_week', 'is_weekend', 'month', 'season',
            'week_of_year', 'quarter', 'is_holiday', 'is_school_holiday'
        ]
        # (column, lowest allowed value, highest allowed value)
        value_bounds = (
            ('day_of_week', 0, 6),
            ('month', 1, 12),
            ('quarter', 1, 4),
        )

        # 7 = week, 31 = month, 365 = year
        for span_days in (7, 31, 365):
            dates = pd.date_range('2024-01-01', periods=span_days, freq='D')
            result = processor._add_temporal_features(pd.DataFrame({'date': dates}))

            for feature in expected_features:
                assert feature in result.columns, f"Missing feature: {feature}"

            for column, lowest, highest in value_bounds:
                assert result[column].min() >= lowest
                assert result[column].max() <= highest
            assert result['is_weekend'].isin([0, 1]).all()
            assert result['is_holiday'].isin([0, 1]).all()
class TestMLPerformance:
    """Performance tests for ML components"""

    @pytest.mark.slow
    @pytest.mark.asyncio
    async def test_data_processing_performance(self, performance_tracker):
        """Test data processing performance with larger datasets"""
        # Create larger dataset
        dates = pd.date_range('2023-01-01', periods=365, freq='D')
        large_sales_data = pd.DataFrame({
            'date': dates,
            'product_name': ['Pan Integral'] * 365,
            'quantity': [45 + 10 * np.sin(2 * np.pi * i / 7) for i in range(365)]
        })

        large_weather_data = pd.DataFrame({
            'date': dates,
            'temperature': [15 + 5 * np.sin(2 * np.pi * i / 365) for i in range(365)],
            'precipitation': [max(0, np.random.exponential(1)) for _ in range(365)],
            'humidity': [60 + np.random.normal(0, 10) for _ in range(365)]
        })

        data_processor = BakeryDataProcessor()

        # Measure performance
        performance_tracker.start("data_processing")

        result = await data_processor.prepare_training_data(
            sales_data=large_sales_data,
            weather_data=large_weather_data,
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        # stop() records the measurement that assert_performance reads back
        performance_tracker.stop()

        # Assert performance (should process 365 days in reasonable time)
        performance_tracker.assert_performance(5000, "data_processing")  # 5 seconds max

        # Verify result quality
        assert len(result) == 365
        assert result['y'].notna().all()

    @pytest.mark.unit
    def test_memory_efficiency(self):
        """Test memory efficiency with multiple datasets"""
        import gc

        # Skip only when psutil itself is missing. A function-wide
        # "except ImportError" would also hide import failures raised by the
        # code under test.
        psutil = pytest.importorskip(
            "psutil",
            reason="psutil not available, skipping memory efficiency test"
        )

        process = psutil.Process()
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        data_processor = BakeryDataProcessor()

        # Process multiple datasets
        for i in range(10):
            dates = pd.date_range('2024-01-01', periods=100, freq='D')
            # NOTE(review): this frame is never passed to the processor;
            # presumably kept only as allocation pressure. Confirm.
            sales_data = pd.DataFrame({
                'date': dates,
                'product_name': [f'Product_{i}'] * 100,
                'quantity': [45] * 100
            })

            # Only the synchronous feature step is exercised here
            temporal_features = data_processor._add_temporal_features(
                pd.DataFrame({'date': dates})
            )

            assert len(temporal_features) == 100

        # Force garbage collection
        gc.collect()

        final_memory = process.memory_info().rss / 1024 / 1024  # MB
        memory_increase = final_memory - initial_memory

        # Memory increase should be reasonable (less than 100MB for this test)
        assert memory_increase < 100, f"Memory increased by {memory_increase:.1f}MB"
class TestMLErrorHandling:
    """Test error handling and edge cases."""

    @pytest.mark.asyncio
    async def test_corrupted_data_handling(self):
        """Prepare training data from sales where every fifth quantity is NaN."""
        data_processor = BakeryDataProcessor()

        # Test with NaN values
        corrupted_sales = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=35, freq='D'),
            'product_name': ['Pan Integral'] * 35,
            'quantity': [np.nan if i % 5 == 0 else 45 for i in range(35)]
        })

        result = await data_processor.prepare_training_data(
            sales_data=corrupted_sales,
            weather_data=pd.DataFrame(),
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        # Should handle NaN values appropriately
        assert not result['y'].isna().all()  # Some values should be preserved

    @pytest.mark.asyncio
    async def test_missing_product_data(self):
        """Expect ValueError or KeyError when the requested product has no rows."""
        data_processor = BakeryDataProcessor()

        sales_data = pd.DataFrame({
            'date': pd.date_range('2024-01-01', periods=35, freq='D'),
            'product_name': ['Other Product'] * 35,
            'quantity': [45] * 35
        })

        with pytest.raises((ValueError, KeyError)):
            await data_processor.prepare_training_data(
                sales_data=sales_data,
                weather_data=pd.DataFrame(),
                traffic_data=pd.DataFrame(),
                product_name="Pan Integral"  # This product doesn't exist
            )

    @pytest.mark.asyncio
    async def test_date_format_variations(self):
        """Prepare training data from sales whose dates are ISO strings."""
        data_processor = BakeryDataProcessor()

        # Test with string dates: 36 distinct consecutive days. Repeating a
        # three-date list twelve times would give 36 rows but only 3 days.
        n_days = 36
        string_dates = (
            pd.date_range('2024-01-01', periods=n_days, freq='D')
            .strftime('%Y-%m-%d')
            .tolist()
        )
        string_date_sales = pd.DataFrame({
            'date': string_dates,
            'product_name': ['Pan Integral'] * n_days,
            'quantity': [45] * n_days
        })

        result = await data_processor.prepare_training_data(
            sales_data=string_date_sales,
            weather_data=pd.DataFrame(),
            traffic_data=pd.DataFrame(),
            product_name="Pan Integral"
        )

        # Should convert and handle string dates. Check the dtype family, not
        # the literal 'datetime64[ns]', so other resolutions are accepted too.
        assert pd.api.types.is_datetime64_any_dtype(result['ds'])
        assert len(result) > 0

View File

@@ -647,14 +647,7 @@ fi
# Training request with real products
# The JSON request body is built as one double-quoted shell string so that
# $TENANT_ID and $REAL_PRODUCTS are expanded by the shell; the JSON's own
# double quotes are backslash-escaped for that reason.
# NOTE(review): as shown here, the "model_type" entry is followed directly by
# a second "tenant_id" entry with no comma, which is not valid JSON. This looks
# like old and new lines of a diff hunk interleaved - confirm against the
# actual script before relying on this payload.
TRAINING_DATA="{
\"tenant_id\": \"$TENANT_ID\",
\"selected_products\": [$REAL_PRODUCTS],
\"include_weather\": \"True\",
\"include_traffic\": \"True\",
\"training_parameters\": {
\"forecast_horizon\": 7,
\"validation_split\": 0.2,
\"model_type\": \"lstm\"
\"tenant_id\": \"$TENANT_ID\"
}
}"
@@ -682,57 +675,6 @@ fi
# Only monitor progress when the training request returned a task ID.
if [ -n "$TRAINING_TASK_ID" ]; then
log_success "Training started successfully - Task ID: $TRAINING_TASK_ID"
log_step "4.2. Monitoring training progress"
# Poll training status (limited polling for test)
# At most MAX_POLLS requests with a 2-second sleep after each one, i.e. up to
# about 200 seconds of waiting plus the request time.
MAX_POLLS=100
POLL_COUNT=0
while [ $POLL_COUNT -lt $MAX_POLLS ]; do
echo "Polling training status... ($((POLL_COUNT+1))/$MAX_POLLS)"
# Fetch the job status for this tenant and task.
STATUS_RESPONSE=$(curl -s -X GET "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs/$TRAINING_TASK_ID" \
-H "Authorization: Bearer $ACCESS_TOKEN" \
-H "X-Tenant-ID: $TENANT_ID")
echo "Status Response:"
# Pretty-print the response when it parses as JSON, otherwise print it raw.
echo "$STATUS_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$STATUS_RESPONSE"
# extract_json_field is defined elsewhere in the script; presumably it
# returns the value of the named field from the JSON text - TODO confirm.
STATUS=$(extract_json_field "$STATUS_RESPONSE" "status")
PROGRESS=$(extract_json_field "$STATUS_RESPONSE" "progress")
if [ -n "$PROGRESS" ]; then
echo " Progress: $PROGRESS%"
fi
# Terminal states leave the loop; in-progress and unknown states keep polling.
case "$STATUS" in
"completed"|"success")
log_success "Training completed successfully!"
break
;;
"failed"|"error")
log_error "Training failed!"
echo "Status response: $STATUS_RESPONSE"
break
;;
"running"|"in_progress"|"pending")
echo " Status: $STATUS (continuing...)"
;;
*)
log_warning "Unknown status: $STATUS"
;;
esac
POLL_COUNT=$((POLL_COUNT+1))
sleep 2
done
# POLL_COUNT reaches MAX_POLLS only when the loop ran out without a break.
# NOTE(review): the else branch is also taken after the "failed"/"error"
# break, so a failed training is still followed by a success-level
# "Training monitoring completed" message - confirm this is intended.
if [ $POLL_COUNT -eq $MAX_POLLS ]; then
log_warning "Training status polling completed - may still be in progress"
else
log_success "Training monitoring completed"
fi
else
log_warning "Could not start training - task ID not found"
fi