Improve gateway service 2

This commit is contained in:
Urtzi Alfaro
2025-07-20 07:43:45 +02:00
parent 8cd433c0cd
commit 5f56c2fd00
4 changed files with 169 additions and 20 deletions

39
scripts/test_unified_auth.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
echo "Testing Unified Authentication System"
# 1. Get auth token
echo "1. Getting authentication token..."
TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/login \
-H "Content-Type: application/json" \
-d '{"email": "test@bakery.com", "password": "testpass123"}' \
| jq -r '.access_token')
echo "Token obtained: ${TOKEN:0:20}..."
# 2. Test data service through gateway
echo -e "\n2. Testing data service through gateway..."
curl -s -X GET http://localhost:8000/api/v1/data/sales \
-H "Authorization: Bearer $TOKEN" \
-H "X-Tenant-ID: test-tenant" \
| jq '.'
# 3. Test training service through gateway
echo -e "\n3. Testing training service through gateway..."
curl -s -X POST http://localhost:8000/api/v1/training/jobs \
-H "Authorization: Bearer $TOKEN" \
-H "X-Tenant-ID: test-tenant" \
-H "Content-Type: application/json" \
-d '{
"include_weather": true,
"include_traffic": false,
"min_data_points": 30
}' \
| jq '.'
# 4. Test direct service call (should work with headers)
echo -e "\n4. Testing direct service call..."
curl -s -X GET http://localhost:8002/health \
| jq '.'
echo -e "\nUnified authentication test complete!"

View File

@@ -11,7 +11,7 @@ import structlog
from app.schemas.training import ( from app.schemas.training import (
TrainingJobRequest, TrainingJobRequest,
TrainingJobResponse, TrainingJobResponse,
TrainingJobStatus, TrainingStatus,
SingleProductTrainingRequest, SingleProductTrainingRequest,
TrainingJobProgress, TrainingJobProgress,
DataValidationRequest, DataValidationRequest,
@@ -89,7 +89,7 @@ async def start_training_job(
@router.get("/jobs", response_model=List[TrainingJobResponse]) @router.get("/jobs", response_model=List[TrainingJobResponse])
async def get_training_jobs( async def get_training_jobs(
status: Optional[TrainingJobStatus] = Query(None), status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
limit: int = Query(100, ge=1, le=1000), limit: int = Query(100, ge=1, le=1000),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
tenant_id: str = Depends(get_current_tenant_id_dep), tenant_id: str = Depends(get_current_tenant_id_dep),

View File

@@ -1,13 +1,15 @@
# services/training/app/schemas/training.py # services/training/app/schemas/training.py
""" """
Pydantic schemas for training service Complete schema definitions for training service
Includes all request/response schemas used by the API endpoints
""" """
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from typing import Dict, List, Any, Optional from typing import List, Optional, Dict, Any, Union
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
class TrainingStatus(str, Enum): class TrainingStatus(str, Enum):
"""Training job status enumeration""" """Training job status enumeration"""
PENDING = "pending" PENDING = "pending"
@@ -16,34 +18,33 @@ class TrainingStatus(str, Enum):
FAILED = "failed" FAILED = "failed"
CANCELLED = "cancelled" CANCELLED = "cancelled"
class TrainingJobRequest(BaseModel): class TrainingJobRequest(BaseModel):
"""Request schema for starting a training job""" """Request schema for starting a training job"""
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)") products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
include_weather: bool = Field(True, description="Include weather data in training") include_weather: bool = Field(True, description="Include weather data in training")
include_traffic: bool = Field(True, description="Include traffic data in training") include_traffic: bool = Field(True, description="Include traffic data in training")
start_date: Optional[datetime] = Field(None, description="Start date for training data") start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data") end_date: Optional[datetime] = Field(None, description="End date for training data")
min_data_points: int = Field(30, description="Minimum data points required per product")
estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")
# Prophet-specific parameters # Prophet-specific parameters
seasonality_mode: str = Field("additive", description="Prophet seasonality mode") seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$")
daily_seasonality: bool = Field(True, description="Enable daily seasonality") daily_seasonality: bool = Field(True, description="Enable daily seasonality")
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
# Advanced configuration
force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
parallel_training: bool = Field(True, description="Train products in parallel")
max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)
@validator('seasonality_mode') @validator('seasonality_mode')
def validate_seasonality_mode(cls, v): def validate_seasonality_mode(cls, v):
if v not in ['additive', 'multiplicative']: if v not in ['additive', 'multiplicative']:
raise ValueError('seasonality_mode must be additive or multiplicative') raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
return v
@validator('min_data_points')
def validate_min_data_points(cls, v):
if v < 7:
raise ValueError('min_data_points must be at least 7')
return v return v
class SingleProductTrainingRequest(BaseModel): class SingleProductTrainingRequest(BaseModel):
"""Request schema for training a single product""" """Request schema for training a single product"""
include_weather: bool = Field(True, description="Include weather data in training") include_weather: bool = Field(True, description="Include weather data in training")
@@ -57,6 +58,7 @@ class SingleProductTrainingRequest(BaseModel):
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
class TrainingJobResponse(BaseModel): class TrainingJobResponse(BaseModel):
"""Response schema for training job creation""" """Response schema for training job creation"""
job_id: str = Field(..., description="Unique training job identifier") job_id: str = Field(..., description="Unique training job identifier")
@@ -66,17 +68,60 @@ class TrainingJobResponse(BaseModel):
created_at: datetime = Field(..., description="Job creation timestamp") created_at: datetime = Field(..., description="Job creation timestamp")
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes") estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
class TrainingStatusResponse(BaseModel):
"""Response schema for training job status""" class TrainingJobStatus(BaseModel):
"""Response schema for training job status checks"""
job_id: str = Field(..., description="Training job identifier") job_id: str = Field(..., description="Training job identifier")
status: TrainingStatus = Field(..., description="Current job status") status: TrainingStatus = Field(..., description="Current job status")
progress: int = Field(0, description="Progress percentage (0-100)") progress: int = Field(0, description="Progress percentage (0-100)")
current_step: str = Field("", description="Current processing step") current_step: str = Field("", description="Current processing step")
started_at: datetime = Field(..., description="Job start timestamp") started_at: datetime = Field(..., description="Job start timestamp")
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp") completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
results: Optional[Dict[str, Any]] = Field(None, description="Training results") products_total: int = Field(0, description="Total number of products to train")
products_completed: int = Field(0, description="Number of products completed")
products_failed: int = Field(0, description="Number of products that failed")
error_message: Optional[str] = Field(None, description="Error message if failed") error_message: Optional[str] = Field(None, description="Error message if failed")
class TrainingJobProgress(BaseModel):
"""Schema for real-time training job progress updates"""
job_id: str = Field(..., description="Training job identifier")
status: TrainingStatus = Field(..., description="Current job status")
progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
current_step: str = Field(..., description="Current processing step")
current_product: Optional[str] = Field(None, description="Currently training product")
products_completed: int = Field(0, description="Number of products completed")
products_total: int = Field(0, description="Total number of products")
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
class DataValidationRequest(BaseModel):
"""Request schema for validating training data"""
products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)")
min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000)
start_date: Optional[datetime] = Field(None, description="Start date for data validation")
end_date: Optional[datetime] = Field(None, description="End date for data validation")
@validator('min_data_points')
def validate_min_data_points(cls, v):
if v < 10:
raise ValueError('min_data_points must be at least 10')
return v
class DataValidationResponse(BaseModel):
"""Response schema for data validation results"""
is_valid: bool = Field(..., description="Whether the data is valid for training")
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
products_analyzed: int = Field(..., description="Number of products analyzed")
total_data_points: int = Field(..., description="Total data points available")
products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data")
data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
"""Schema for trained model information""" """Schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier") model_id: str = Field(..., description="Unique model identifier")
@@ -89,6 +134,7 @@ class ModelInfo(BaseModel):
trained_at: datetime = Field(..., description="Training completion timestamp") trained_at: datetime = Field(..., description="Training completion timestamp")
data_period: Dict[str, str] = Field(..., description="Training data period") data_period: Dict[str, str] = Field(..., description="Training data period")
class ProductTrainingResult(BaseModel): class ProductTrainingResult(BaseModel):
"""Schema for individual product training result""" """Schema for individual product training result"""
product_name: str = Field(..., description="Product name") product_name: str = Field(..., description="Product name")
@@ -97,6 +143,8 @@ class ProductTrainingResult(BaseModel):
data_points: int = Field(..., description="Number of data points used") data_points: int = Field(..., description="Number of data points used")
error_message: Optional[str] = Field(None, description="Error message if failed") error_message: Optional[str] = Field(None, description="Error message if failed")
trained_at: datetime = Field(..., description="Training completion timestamp") trained_at: datetime = Field(..., description="Training completion timestamp")
training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds")
class TrainingResultsResponse(BaseModel): class TrainingResultsResponse(BaseModel):
"""Response schema for complete training results""" """Response schema for complete training results"""
@@ -110,6 +158,7 @@ class TrainingResultsResponse(BaseModel):
summary: Dict[str, Any] = Field(..., description="Training summary statistics") summary: Dict[str, Any] = Field(..., description="Training summary statistics")
completed_at: datetime = Field(..., description="Job completion timestamp") completed_at: datetime = Field(..., description="Job completion timestamp")
class TrainingValidationResult(BaseModel): class TrainingValidationResult(BaseModel):
"""Schema for training data validation results""" """Schema for training data validation results"""
is_valid: bool = Field(..., description="Whether the data is valid for training") is_valid: bool = Field(..., description="Whether the data is valid for training")
@@ -119,6 +168,7 @@ class TrainingValidationResult(BaseModel):
products_analyzed: int = Field(..., description="Number of products analyzed") products_analyzed: int = Field(..., description="Number of products analyzed")
total_data_points: int = Field(..., description="Total data points available") total_data_points: int = Field(..., description="Total data points available")
class TrainingMetrics(BaseModel): class TrainingMetrics(BaseModel):
"""Schema for training performance metrics""" """Schema for training performance metrics"""
mae: float = Field(..., description="Mean Absolute Error") mae: float = Field(..., description="Mean Absolute Error")
@@ -129,6 +179,7 @@ class TrainingMetrics(BaseModel):
mean_actual: float = Field(..., description="Mean of actual values") mean_actual: float = Field(..., description="Mean of actual values")
mean_predicted: float = Field(..., description="Mean of predicted values") mean_predicted: float = Field(..., description="Mean of predicted values")
class ExternalDataConfig(BaseModel): class ExternalDataConfig(BaseModel):
"""Configuration for external data sources""" """Configuration for external data sources"""
weather_enabled: bool = Field(True, description="Enable weather data") weather_enabled: bool = Field(True, description="Enable weather data")
@@ -142,6 +193,7 @@ class ExternalDataConfig(BaseModel):
description="Traffic features to include" description="Traffic features to include"
) )
class TrainingJobConfig(BaseModel): class TrainingJobConfig(BaseModel):
"""Complete training job configuration""" """Complete training job configuration"""
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig) external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
@@ -162,7 +214,8 @@ class TrainingJobConfig(BaseModel):
default_factory=lambda: {"min_data_points": 30}, default_factory=lambda: {"min_data_points": 30},
description="Data validation parameters" description="Data validation parameters"
) )
class TrainedModelResponse(BaseModel): class TrainedModelResponse(BaseModel):
"""Response schema for trained model information""" """Response schema for trained model information"""
model_id: str = Field(..., description="Unique model identifier") model_id: str = Field(..., description="Unique model identifier")
@@ -178,4 +231,61 @@ class TrainedModelResponse(BaseModel):
is_active: bool = Field(..., description="Whether model is active") is_active: bool = Field(..., description="Whether model is active")
created_at: datetime = Field(..., description="Model creation timestamp") created_at: datetime = Field(..., description="Model creation timestamp")
data_period_start: Optional[datetime] = Field(None, description="Training data start date") data_period_start: Optional[datetime] = Field(None, description="Training data start date")
data_period_end: Optional[datetime] = Field(None, description="Training data end date") data_period_end: Optional[datetime] = Field(None, description="Training data end date")
class ModelTrainingStats(BaseModel):
"""Schema for model training statistics"""
total_models: int = Field(..., description="Total number of trained models")
active_models: int = Field(..., description="Number of active models")
last_training_date: Optional[datetime] = Field(None, description="Last training date")
avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
success_rate: float = Field(..., description="Training success rate (0-1)")
class BulkTrainingRequest(BaseModel):
"""Request schema for bulk training operations"""
tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration")
priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10)
schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time")
class TrainingScheduleResponse(BaseModel):
"""Response schema for scheduled training jobs"""
schedule_id: str = Field(..., description="Unique schedule identifier")
tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
scheduled_time: datetime = Field(..., description="Scheduled execution time")
status: str = Field(..., description="Schedule status")
created_at: datetime = Field(..., description="Schedule creation timestamp")
# WebSocket response schemas for real-time updates
class TrainingProgressUpdate(BaseModel):
"""WebSocket message for training progress updates"""
type: str = Field("training_progress", description="Message type")
job_id: str = Field(..., description="Training job identifier")
progress: TrainingJobProgress = Field(..., description="Progress information")
class TrainingCompletedUpdate(BaseModel):
"""WebSocket message for training completion"""
type: str = Field("training_completed", description="Message type")
job_id: str = Field(..., description="Training job identifier")
results: TrainingResultsResponse = Field(..., description="Training results")
class TrainingErrorUpdate(BaseModel):
"""WebSocket message for training errors"""
type: str = Field("training_error", description="Message type")
job_id: str = Field(..., description="Training job identifier")
error: str = Field(..., description="Error message")
timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp")
# Union type for all WebSocket messages
TrainingWebSocketMessage = Union[
TrainingProgressUpdate,
TrainingCompletedUpdate,
TrainingErrorUpdate
]