diff --git a/scripts/test_unified_auth.sh b/scripts/test_unified_auth.sh
new file mode 100755
index 00000000..c439dd81
--- /dev/null
+++ b/scripts/test_unified_auth.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+#
+# Smoke test for the unified authentication flow: log in through the gateway,
+# then call the data and training endpoints with the returned bearer token.
+#
+# Requires curl and jq. GATEWAY_URL, DIRECT_SERVICE_URL and TENANT_ID may be
+# overridden from the environment; the defaults are the values that were
+# previously hard-coded.
+
+GATEWAY_URL="${GATEWAY_URL:-http://localhost:8000}"
+DIRECT_SERVICE_URL="${DIRECT_SERVICE_URL:-http://localhost:8002}"
+TENANT_ID="${TENANT_ID:-test-tenant}"
+
+echo "Testing Unified Authentication System"
+
+# 1. Get auth token
+echo "1. Getting authentication token..."
+# -S keeps curl's error message visible despite -s; "// empty" makes jq print
+# nothing (instead of the string "null") when the field is missing.
+TOKEN=$(curl -sS -X POST "$GATEWAY_URL/api/v1/auth/login" \
+  -H "Content-Type: application/json" \
+  -d '{"email": "test@bakery.com", "password": "testpass123"}' \
+  | jq -r '.access_token // empty')
+
+# The authenticated calls below are meaningless without a token, so stop
+# here rather than sending "Bearer null" to the gateway.
+if [ -z "$TOKEN" ]; then
+  echo "ERROR: login failed, no access_token in response" >&2
+  exit 1
+fi
+
+echo "Token obtained: ${TOKEN:0:20}..."
+
+# 2. Test data service through gateway
+echo -e "\n2. Testing data service through gateway..."
+curl -sS -X GET "$GATEWAY_URL/api/v1/data/sales" \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "X-Tenant-ID: $TENANT_ID" \
+  | jq '.'
+
+# 3. Test training service through gateway
+# NOTE(review): min_data_points is no longer sent because TrainingJobRequest
+# no longer declares it -- confirm that is the body schema for this endpoint.
+echo -e "\n3. Testing training service through gateway..."
+curl -sS -X POST "$GATEWAY_URL/api/v1/training/jobs" \
+  -H "Authorization: Bearer $TOKEN" \
+  -H "X-Tenant-ID: $TENANT_ID" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "include_weather": true,
+    "include_traffic": false
+  }' \
+  | jq '.'
+
+# 4. Test direct service call (bypasses the gateway; no auth headers are sent)
+echo -e "\n4. Testing direct service call..."
+curl -sS -X GET "$DIRECT_SERVICE_URL/health" \
+  | jq '.'
+
+echo -e "\nUnified authentication test complete!"
\ No newline at end of file diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index 218564c6..f0b89110 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -11,7 +11,7 @@ import structlog from app.schemas.training import ( TrainingJobRequest, TrainingJobResponse, - TrainingJobStatus, + TrainingStatus, SingleProductTrainingRequest, TrainingJobProgress, DataValidationRequest, @@ -89,7 +89,7 @@ async def start_training_job( @router.get("/jobs", response_model=List[TrainingJobResponse]) async def get_training_jobs( - status: Optional[TrainingJobStatus] = Query(None), + status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"), limit: int = Query(100, ge=1, le=1000), offset: int = Query(0, ge=0), tenant_id: str = Depends(get_current_tenant_id_dep), diff --git a/services/training/app/schemas/training.py b/services/training/app/schemas/training.py index 9d1cd244..a9c898a6 100644 --- a/services/training/app/schemas/training.py +++ b/services/training/app/schemas/training.py @@ -1,13 +1,15 @@ # services/training/app/schemas/training.py """ -Pydantic schemas for training service +Complete schema definitions for training service +Includes all request/response schemas used by the API endpoints """ from pydantic import BaseModel, Field, validator -from typing import Dict, List, Any, Optional +from typing import List, Optional, Dict, Any, Union from datetime import datetime from enum import Enum + class TrainingStatus(str, Enum): """Training job status enumeration""" PENDING = "pending" @@ -16,34 +18,33 @@ class TrainingStatus(str, Enum): FAILED = "failed" CANCELLED = "cancelled" + class TrainingJobRequest(BaseModel): """Request schema for starting a training job""" - products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)") + products: Optional[List[str]] = Field(None, description="Specific products to train (if None, 
trains all)") include_weather: bool = Field(True, description="Include weather data in training") include_traffic: bool = Field(True, description="Include traffic data in training") start_date: Optional[datetime] = Field(None, description="Start date for training data") end_date: Optional[datetime] = Field(None, description="End date for training data") - min_data_points: int = Field(30, description="Minimum data points required per product") - estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes") # Prophet-specific parameters - seasonality_mode: str = Field("additive", description="Prophet seasonality mode") + seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$") daily_seasonality: bool = Field(True, description="Enable daily seasonality") weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") + # Advanced configuration + force_retrain: bool = Field(False, description="Force retraining even if recent model exists") + parallel_training: bool = Field(True, description="Train products in parallel") + max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10) + @validator('seasonality_mode') def validate_seasonality_mode(cls, v): if v not in ['additive', 'multiplicative']: - raise ValueError('seasonality_mode must be additive or multiplicative') - return v - - @validator('min_data_points') - def validate_min_data_points(cls, v): - if v < 7: - raise ValueError('min_data_points must be at least 7') + raise ValueError('seasonality_mode must be either "additive" or "multiplicative"') return v + class SingleProductTrainingRequest(BaseModel): """Request schema for training a single product""" include_weather: bool = Field(True, description="Include weather data in training") @@ -57,6 +58,7 @@ class SingleProductTrainingRequest(BaseModel): 
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") + class TrainingJobResponse(BaseModel): """Response schema for training job creation""" job_id: str = Field(..., description="Unique training job identifier") @@ -66,17 +68,60 @@ class TrainingJobResponse(BaseModel): created_at: datetime = Field(..., description="Job creation timestamp") estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes") -class TrainingStatusResponse(BaseModel): - """Response schema for training job status""" + +class TrainingJobStatus(BaseModel): + """Response schema for training job status checks""" job_id: str = Field(..., description="Training job identifier") status: TrainingStatus = Field(..., description="Current job status") progress: int = Field(0, description="Progress percentage (0-100)") current_step: str = Field("", description="Current processing step") started_at: datetime = Field(..., description="Job start timestamp") completed_at: Optional[datetime] = Field(None, description="Job completion timestamp") - results: Optional[Dict[str, Any]] = Field(None, description="Training results") + products_total: int = Field(0, description="Total number of products to train") + products_completed: int = Field(0, description="Number of products completed") + products_failed: int = Field(0, description="Number of products that failed") error_message: Optional[str] = Field(None, description="Error message if failed") + +class TrainingJobProgress(BaseModel): + """Schema for real-time training job progress updates""" + job_id: str = Field(..., description="Training job identifier") + status: TrainingStatus = Field(..., description="Current job status") + progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100) + current_step: str = Field(..., description="Current processing step") + current_product: Optional[str] = 
Field(None, description="Currently training product") + products_completed: int = Field(0, description="Number of products completed") + products_total: int = Field(0, description="Total number of products") + estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining") + timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp") + + +class DataValidationRequest(BaseModel): + """Request schema for validating training data""" + products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)") + min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000) + start_date: Optional[datetime] = Field(None, description="Start date for data validation") + end_date: Optional[datetime] = Field(None, description="End date for data validation") + + @validator('min_data_points') + def validate_min_data_points(cls, v): + if v < 10: + raise ValueError('min_data_points must be at least 10') + return v + + +class DataValidationResponse(BaseModel): + """Response schema for data validation results""" + is_valid: bool = Field(..., description="Whether the data is valid for training") + issues: List[str] = Field(default_factory=list, description="List of data quality issues") + recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement") + estimated_time_minutes: int = Field(..., description="Estimated training time in minutes") + products_analyzed: int = Field(..., description="Number of products analyzed") + total_data_points: int = Field(..., description="Total data points available") + products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data") + data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0) + + class ModelInfo(BaseModel): """Schema for trained model 
information""" model_id: str = Field(..., description="Unique model identifier") @@ -89,6 +134,7 @@ class ModelInfo(BaseModel): trained_at: datetime = Field(..., description="Training completion timestamp") data_period: Dict[str, str] = Field(..., description="Training data period") + class ProductTrainingResult(BaseModel): """Schema for individual product training result""" product_name: str = Field(..., description="Product name") @@ -97,6 +143,8 @@ class ProductTrainingResult(BaseModel): data_points: int = Field(..., description="Number of data points used") error_message: Optional[str] = Field(None, description="Error message if failed") trained_at: datetime = Field(..., description="Training completion timestamp") + training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds") + class TrainingResultsResponse(BaseModel): """Response schema for complete training results""" @@ -110,6 +158,7 @@ class TrainingResultsResponse(BaseModel): summary: Dict[str, Any] = Field(..., description="Training summary statistics") completed_at: datetime = Field(..., description="Job completion timestamp") + class TrainingValidationResult(BaseModel): """Schema for training data validation results""" is_valid: bool = Field(..., description="Whether the data is valid for training") @@ -119,6 +168,7 @@ class TrainingValidationResult(BaseModel): products_analyzed: int = Field(..., description="Number of products analyzed") total_data_points: int = Field(..., description="Total data points available") + class TrainingMetrics(BaseModel): """Schema for training performance metrics""" mae: float = Field(..., description="Mean Absolute Error") @@ -129,6 +179,7 @@ class TrainingMetrics(BaseModel): mean_actual: float = Field(..., description="Mean of actual values") mean_predicted: float = Field(..., description="Mean of predicted values") + class ExternalDataConfig(BaseModel): """Configuration for external data sources""" weather_enabled: bool = 
Field(True, description="Enable weather data") @@ -142,6 +193,7 @@ class ExternalDataConfig(BaseModel): description="Traffic features to include" ) + class TrainingJobConfig(BaseModel): """Complete training job configuration""" external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig) @@ -162,7 +214,8 @@ class TrainingJobConfig(BaseModel): default_factory=lambda: {"min_data_points": 30}, description="Data validation parameters" ) - + + class TrainedModelResponse(BaseModel): """Response schema for trained model information""" model_id: str = Field(..., description="Unique model identifier") @@ -178,4 +231,61 @@ class TrainedModelResponse(BaseModel): is_active: bool = Field(..., description="Whether model is active") created_at: datetime = Field(..., description="Model creation timestamp") data_period_start: Optional[datetime] = Field(None, description="Training data start date") - data_period_end: Optional[datetime] = Field(None, description="Training data end date") \ No newline at end of file + data_period_end: Optional[datetime] = Field(None, description="Training data end date") + + +class ModelTrainingStats(BaseModel): + """Schema for model training statistics""" + total_models: int = Field(..., description="Total number of trained models") + active_models: int = Field(..., description="Number of active models") + last_training_date: Optional[datetime] = Field(None, description="Last training date") + avg_training_time_minutes: float = Field(..., description="Average training time in minutes") + success_rate: float = Field(..., description="Training success rate (0-1)") + + +class BulkTrainingRequest(BaseModel): + """Request schema for bulk training operations""" + tenant_ids: List[str] = Field(..., description="List of tenant IDs to train") + config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration") + priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10) + 
schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time") + + +class TrainingScheduleResponse(BaseModel): + """Response schema for scheduled training jobs""" + schedule_id: str = Field(..., description="Unique schedule identifier") + tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs") + scheduled_time: datetime = Field(..., description="Scheduled execution time") + status: str = Field(..., description="Schedule status") + created_at: datetime = Field(..., description="Schedule creation timestamp") + + +# WebSocket response schemas for real-time updates +class TrainingProgressUpdate(BaseModel): + """WebSocket message for training progress updates""" + type: str = Field("training_progress", description="Message type") + job_id: str = Field(..., description="Training job identifier") + progress: TrainingJobProgress = Field(..., description="Progress information") + + +class TrainingCompletedUpdate(BaseModel): + """WebSocket message for training completion""" + type: str = Field("training_completed", description="Message type") + job_id: str = Field(..., description="Training job identifier") + results: TrainingResultsResponse = Field(..., description="Training results") + + +class TrainingErrorUpdate(BaseModel): + """WebSocket message for training errors""" + type: str = Field("training_error", description="Message type") + job_id: str = Field(..., description="Training job identifier") + error: str = Field(..., description="Error message") + timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp") + + +# Union type for all WebSocket messages +TrainingWebSocketMessage = Union[ + TrainingProgressUpdate, + TrainingCompletedUpdate, + TrainingErrorUpdate +] \ No newline at end of file diff --git a/test_data.py b/tests/test_data.py similarity index 100% rename from test_data.py rename to tests/test_data.py