Improve gateway service 2
This commit is contained in:
39
scripts/test_unified_auth.sh
Executable file
39
scripts/test_unified_auth.sh
Executable file
@@ -0,0 +1,39 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "Testing Unified Authentication System"
|
||||||
|
|
||||||
|
# 1. Get auth token
|
||||||
|
echo "1. Getting authentication token..."
|
||||||
|
TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/login \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"email": "test@bakery.com", "password": "testpass123"}' \
|
||||||
|
| jq -r '.access_token')
|
||||||
|
|
||||||
|
echo "Token obtained: ${TOKEN:0:20}..."
|
||||||
|
|
||||||
|
# 2. Test data service through gateway
|
||||||
|
echo -e "\n2. Testing data service through gateway..."
|
||||||
|
curl -s -X GET http://localhost:8000/api/v1/data/sales \
|
||||||
|
-H "Authorization: Bearer $TOKEN" \
|
||||||
|
-H "X-Tenant-ID: test-tenant" \
|
||||||
|
| jq '.'
|
||||||
|
|
||||||
|
# 3. Test training service through gateway
|
||||||
|
echo -e "\n3. Testing training service through gateway..."
|
||||||
|
curl -s -X POST http://localhost:8000/api/v1/training/jobs \
|
||||||
|
-H "Authorization: Bearer $TOKEN" \
|
||||||
|
-H "X-Tenant-ID: test-tenant" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"include_weather": true,
|
||||||
|
"include_traffic": false,
|
||||||
|
"min_data_points": 30
|
||||||
|
}' \
|
||||||
|
| jq '.'
|
||||||
|
|
||||||
|
# 4. Test direct service call (should work with headers)
|
||||||
|
echo -e "\n4. Testing direct service call..."
|
||||||
|
curl -s -X GET http://localhost:8002/health \
|
||||||
|
| jq '.'
|
||||||
|
|
||||||
|
echo -e "\nUnified authentication test complete!"
|
||||||
@@ -11,7 +11,7 @@ import structlog
|
|||||||
from app.schemas.training import (
|
from app.schemas.training import (
|
||||||
TrainingJobRequest,
|
TrainingJobRequest,
|
||||||
TrainingJobResponse,
|
TrainingJobResponse,
|
||||||
TrainingJobStatus,
|
TrainingStatus,
|
||||||
SingleProductTrainingRequest,
|
SingleProductTrainingRequest,
|
||||||
TrainingJobProgress,
|
TrainingJobProgress,
|
||||||
DataValidationRequest,
|
DataValidationRequest,
|
||||||
@@ -89,7 +89,7 @@ async def start_training_job(
|
|||||||
|
|
||||||
@router.get("/jobs", response_model=List[TrainingJobResponse])
|
@router.get("/jobs", response_model=List[TrainingJobResponse])
|
||||||
async def get_training_jobs(
|
async def get_training_jobs(
|
||||||
status: Optional[TrainingJobStatus] = Query(None),
|
status: Optional[TrainingStatus] = Query(None, description="Filter jobs by status"),
|
||||||
limit: int = Query(100, ge=1, le=1000),
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
tenant_id: str = Depends(get_current_tenant_id_dep),
|
tenant_id: str = Depends(get_current_tenant_id_dep),
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
# services/training/app/schemas/training.py
|
# services/training/app/schemas/training.py
|
||||||
"""
|
"""
|
||||||
Pydantic schemas for training service
|
Complete schema definitions for training service
|
||||||
|
Includes all request/response schemas used by the API endpoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import List, Optional, Dict, Any, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
class TrainingStatus(str, Enum):
|
class TrainingStatus(str, Enum):
|
||||||
"""Training job status enumeration"""
|
"""Training job status enumeration"""
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
@@ -16,33 +18,32 @@ class TrainingStatus(str, Enum):
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
CANCELLED = "cancelled"
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
class TrainingJobRequest(BaseModel):
|
class TrainingJobRequest(BaseModel):
|
||||||
"""Request schema for starting a training job"""
|
"""Request schema for starting a training job"""
|
||||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, train all)")
|
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
|
||||||
include_weather: bool = Field(True, description="Include weather data in training")
|
include_weather: bool = Field(True, description="Include weather data in training")
|
||||||
include_traffic: bool = Field(True, description="Include traffic data in training")
|
include_traffic: bool = Field(True, description="Include traffic data in training")
|
||||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||||
min_data_points: int = Field(30, description="Minimum data points required per product")
|
|
||||||
estimated_duration: Optional[int] = Field(None, description="Estimated duration in minutes")
|
|
||||||
|
|
||||||
# Prophet-specific parameters
|
# Prophet-specific parameters
|
||||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode")
|
seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$")
|
||||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||||
|
|
||||||
|
# Advanced configuration
|
||||||
|
force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
|
||||||
|
parallel_training: bool = Field(True, description="Train products in parallel")
|
||||||
|
max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)
|
||||||
|
|
||||||
@validator('seasonality_mode')
|
@validator('seasonality_mode')
|
||||||
def validate_seasonality_mode(cls, v):
|
def validate_seasonality_mode(cls, v):
|
||||||
if v not in ['additive', 'multiplicative']:
|
if v not in ['additive', 'multiplicative']:
|
||||||
raise ValueError('seasonality_mode must be additive or multiplicative')
|
raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@validator('min_data_points')
|
|
||||||
def validate_min_data_points(cls, v):
|
|
||||||
if v < 7:
|
|
||||||
raise ValueError('min_data_points must be at least 7')
|
|
||||||
return v
|
|
||||||
|
|
||||||
class SingleProductTrainingRequest(BaseModel):
|
class SingleProductTrainingRequest(BaseModel):
|
||||||
"""Request schema for training a single product"""
|
"""Request schema for training a single product"""
|
||||||
@@ -57,6 +58,7 @@ class SingleProductTrainingRequest(BaseModel):
|
|||||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||||
|
|
||||||
|
|
||||||
class TrainingJobResponse(BaseModel):
|
class TrainingJobResponse(BaseModel):
|
||||||
"""Response schema for training job creation"""
|
"""Response schema for training job creation"""
|
||||||
job_id: str = Field(..., description="Unique training job identifier")
|
job_id: str = Field(..., description="Unique training job identifier")
|
||||||
@@ -66,17 +68,60 @@ class TrainingJobResponse(BaseModel):
|
|||||||
created_at: datetime = Field(..., description="Job creation timestamp")
|
created_at: datetime = Field(..., description="Job creation timestamp")
|
||||||
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
estimated_duration_minutes: int = Field(..., description="Estimated completion time in minutes")
|
||||||
|
|
||||||
class TrainingStatusResponse(BaseModel):
|
|
||||||
"""Response schema for training job status"""
|
class TrainingJobStatus(BaseModel):
|
||||||
|
"""Response schema for training job status checks"""
|
||||||
job_id: str = Field(..., description="Training job identifier")
|
job_id: str = Field(..., description="Training job identifier")
|
||||||
status: TrainingStatus = Field(..., description="Current job status")
|
status: TrainingStatus = Field(..., description="Current job status")
|
||||||
progress: int = Field(0, description="Progress percentage (0-100)")
|
progress: int = Field(0, description="Progress percentage (0-100)")
|
||||||
current_step: str = Field("", description="Current processing step")
|
current_step: str = Field("", description="Current processing step")
|
||||||
started_at: datetime = Field(..., description="Job start timestamp")
|
started_at: datetime = Field(..., description="Job start timestamp")
|
||||||
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
|
completed_at: Optional[datetime] = Field(None, description="Job completion timestamp")
|
||||||
results: Optional[Dict[str, Any]] = Field(None, description="Training results")
|
products_total: int = Field(0, description="Total number of products to train")
|
||||||
|
products_completed: int = Field(0, description="Number of products completed")
|
||||||
|
products_failed: int = Field(0, description="Number of products that failed")
|
||||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingJobProgress(BaseModel):
|
||||||
|
"""Schema for real-time training job progress updates"""
|
||||||
|
job_id: str = Field(..., description="Training job identifier")
|
||||||
|
status: TrainingStatus = Field(..., description="Current job status")
|
||||||
|
progress: int = Field(0, description="Progress percentage (0-100)", ge=0, le=100)
|
||||||
|
current_step: str = Field(..., description="Current processing step")
|
||||||
|
current_product: Optional[str] = Field(None, description="Currently training product")
|
||||||
|
products_completed: int = Field(0, description="Number of products completed")
|
||||||
|
products_total: int = Field(0, description="Total number of products")
|
||||||
|
estimated_time_remaining_minutes: Optional[int] = Field(None, description="Estimated time remaining")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="Progress update timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class DataValidationRequest(BaseModel):
|
||||||
|
"""Request schema for validating training data"""
|
||||||
|
products: Optional[List[str]] = Field(None, description="Specific products to validate (if None, validates all)")
|
||||||
|
min_data_points: int = Field(30, description="Minimum required data points per product", ge=10, le=1000)
|
||||||
|
start_date: Optional[datetime] = Field(None, description="Start date for data validation")
|
||||||
|
end_date: Optional[datetime] = Field(None, description="End date for data validation")
|
||||||
|
|
||||||
|
@validator('min_data_points')
|
||||||
|
def validate_min_data_points(cls, v):
|
||||||
|
if v < 10:
|
||||||
|
raise ValueError('min_data_points must be at least 10')
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class DataValidationResponse(BaseModel):
|
||||||
|
"""Response schema for data validation results"""
|
||||||
|
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
||||||
|
issues: List[str] = Field(default_factory=list, description="List of data quality issues")
|
||||||
|
recommendations: List[str] = Field(default_factory=list, description="Recommendations for improvement")
|
||||||
|
estimated_time_minutes: int = Field(..., description="Estimated training time in minutes")
|
||||||
|
products_analyzed: int = Field(..., description="Number of products analyzed")
|
||||||
|
total_data_points: int = Field(..., description="Total data points available")
|
||||||
|
products_with_insufficient_data: List[str] = Field(default_factory=list, description="Products with insufficient data")
|
||||||
|
data_quality_score: float = Field(0.0, description="Overall data quality score (0-1)", ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
"""Schema for trained model information"""
|
"""Schema for trained model information"""
|
||||||
model_id: str = Field(..., description="Unique model identifier")
|
model_id: str = Field(..., description="Unique model identifier")
|
||||||
@@ -89,6 +134,7 @@ class ModelInfo(BaseModel):
|
|||||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||||
data_period: Dict[str, str] = Field(..., description="Training data period")
|
data_period: Dict[str, str] = Field(..., description="Training data period")
|
||||||
|
|
||||||
|
|
||||||
class ProductTrainingResult(BaseModel):
|
class ProductTrainingResult(BaseModel):
|
||||||
"""Schema for individual product training result"""
|
"""Schema for individual product training result"""
|
||||||
product_name: str = Field(..., description="Product name")
|
product_name: str = Field(..., description="Product name")
|
||||||
@@ -97,6 +143,8 @@ class ProductTrainingResult(BaseModel):
|
|||||||
data_points: int = Field(..., description="Number of data points used")
|
data_points: int = Field(..., description="Number of data points used")
|
||||||
error_message: Optional[str] = Field(None, description="Error message if failed")
|
error_message: Optional[str] = Field(None, description="Error message if failed")
|
||||||
trained_at: datetime = Field(..., description="Training completion timestamp")
|
trained_at: datetime = Field(..., description="Training completion timestamp")
|
||||||
|
training_duration_seconds: Optional[float] = Field(None, description="Training duration in seconds")
|
||||||
|
|
||||||
|
|
||||||
class TrainingResultsResponse(BaseModel):
|
class TrainingResultsResponse(BaseModel):
|
||||||
"""Response schema for complete training results"""
|
"""Response schema for complete training results"""
|
||||||
@@ -110,6 +158,7 @@ class TrainingResultsResponse(BaseModel):
|
|||||||
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
summary: Dict[str, Any] = Field(..., description="Training summary statistics")
|
||||||
completed_at: datetime = Field(..., description="Job completion timestamp")
|
completed_at: datetime = Field(..., description="Job completion timestamp")
|
||||||
|
|
||||||
|
|
||||||
class TrainingValidationResult(BaseModel):
|
class TrainingValidationResult(BaseModel):
|
||||||
"""Schema for training data validation results"""
|
"""Schema for training data validation results"""
|
||||||
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
is_valid: bool = Field(..., description="Whether the data is valid for training")
|
||||||
@@ -119,6 +168,7 @@ class TrainingValidationResult(BaseModel):
|
|||||||
products_analyzed: int = Field(..., description="Number of products analyzed")
|
products_analyzed: int = Field(..., description="Number of products analyzed")
|
||||||
total_data_points: int = Field(..., description="Total data points available")
|
total_data_points: int = Field(..., description="Total data points available")
|
||||||
|
|
||||||
|
|
||||||
class TrainingMetrics(BaseModel):
|
class TrainingMetrics(BaseModel):
|
||||||
"""Schema for training performance metrics"""
|
"""Schema for training performance metrics"""
|
||||||
mae: float = Field(..., description="Mean Absolute Error")
|
mae: float = Field(..., description="Mean Absolute Error")
|
||||||
@@ -129,6 +179,7 @@ class TrainingMetrics(BaseModel):
|
|||||||
mean_actual: float = Field(..., description="Mean of actual values")
|
mean_actual: float = Field(..., description="Mean of actual values")
|
||||||
mean_predicted: float = Field(..., description="Mean of predicted values")
|
mean_predicted: float = Field(..., description="Mean of predicted values")
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataConfig(BaseModel):
|
class ExternalDataConfig(BaseModel):
|
||||||
"""Configuration for external data sources"""
|
"""Configuration for external data sources"""
|
||||||
weather_enabled: bool = Field(True, description="Enable weather data")
|
weather_enabled: bool = Field(True, description="Enable weather data")
|
||||||
@@ -142,6 +193,7 @@ class ExternalDataConfig(BaseModel):
|
|||||||
description="Traffic features to include"
|
description="Traffic features to include"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainingJobConfig(BaseModel):
|
class TrainingJobConfig(BaseModel):
|
||||||
"""Complete training job configuration"""
|
"""Complete training job configuration"""
|
||||||
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
|
external_data: ExternalDataConfig = Field(default_factory=ExternalDataConfig)
|
||||||
@@ -163,6 +215,7 @@ class TrainingJobConfig(BaseModel):
|
|||||||
description="Data validation parameters"
|
description="Data validation parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainedModelResponse(BaseModel):
|
class TrainedModelResponse(BaseModel):
|
||||||
"""Response schema for trained model information"""
|
"""Response schema for trained model information"""
|
||||||
model_id: str = Field(..., description="Unique model identifier")
|
model_id: str = Field(..., description="Unique model identifier")
|
||||||
@@ -179,3 +232,60 @@ class TrainedModelResponse(BaseModel):
|
|||||||
created_at: datetime = Field(..., description="Model creation timestamp")
|
created_at: datetime = Field(..., description="Model creation timestamp")
|
||||||
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
data_period_start: Optional[datetime] = Field(None, description="Training data start date")
|
||||||
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
data_period_end: Optional[datetime] = Field(None, description="Training data end date")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTrainingStats(BaseModel):
|
||||||
|
"""Schema for model training statistics"""
|
||||||
|
total_models: int = Field(..., description="Total number of trained models")
|
||||||
|
active_models: int = Field(..., description="Number of active models")
|
||||||
|
last_training_date: Optional[datetime] = Field(None, description="Last training date")
|
||||||
|
avg_training_time_minutes: float = Field(..., description="Average training time in minutes")
|
||||||
|
success_rate: float = Field(..., description="Training success rate (0-1)")
|
||||||
|
|
||||||
|
|
||||||
|
class BulkTrainingRequest(BaseModel):
|
||||||
|
"""Request schema for bulk training operations"""
|
||||||
|
tenant_ids: List[str] = Field(..., description="List of tenant IDs to train")
|
||||||
|
config: TrainingJobConfig = Field(default_factory=TrainingJobConfig, description="Training configuration")
|
||||||
|
priority: int = Field(1, description="Training priority (1-10)", ge=1, le=10)
|
||||||
|
schedule_time: Optional[datetime] = Field(None, description="Schedule training for specific time")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingScheduleResponse(BaseModel):
|
||||||
|
"""Response schema for scheduled training jobs"""
|
||||||
|
schedule_id: str = Field(..., description="Unique schedule identifier")
|
||||||
|
tenant_ids: List[str] = Field(..., description="Scheduled tenant IDs")
|
||||||
|
scheduled_time: datetime = Field(..., description="Scheduled execution time")
|
||||||
|
status: str = Field(..., description="Schedule status")
|
||||||
|
created_at: datetime = Field(..., description="Schedule creation timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
# WebSocket response schemas for real-time updates
|
||||||
|
class TrainingProgressUpdate(BaseModel):
|
||||||
|
"""WebSocket message for training progress updates"""
|
||||||
|
type: str = Field("training_progress", description="Message type")
|
||||||
|
job_id: str = Field(..., description="Training job identifier")
|
||||||
|
progress: TrainingJobProgress = Field(..., description="Progress information")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingCompletedUpdate(BaseModel):
|
||||||
|
"""WebSocket message for training completion"""
|
||||||
|
type: str = Field("training_completed", description="Message type")
|
||||||
|
job_id: str = Field(..., description="Training job identifier")
|
||||||
|
results: TrainingResultsResponse = Field(..., description="Training results")
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingErrorUpdate(BaseModel):
|
||||||
|
"""WebSocket message for training errors"""
|
||||||
|
type: str = Field("training_error", description="Message type")
|
||||||
|
job_id: str = Field(..., description="Training job identifier")
|
||||||
|
error: str = Field(..., description="Error message")
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.now, description="Error timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
# Union type for all WebSocket messages
|
||||||
|
TrainingWebSocketMessage = Union[
|
||||||
|
TrainingProgressUpdate,
|
||||||
|
TrainingCompletedUpdate,
|
||||||
|
TrainingErrorUpdate
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user