From 71216f8ec9a03e47ef34fcc2ee197ad88ec90707 Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Tue, 29 Jul 2025 07:53:30 +0200 Subject: [PATCH] Improve training code 4 --- .../training/app/models/training_models.py | 80 ++++++++++++ services/training/app/schemas/training.py | 17 --- tests/test_onboarding_flow.sh | 122 +----------------- 3 files changed, 84 insertions(+), 135 deletions(-) create mode 100644 services/training/app/models/training_models.py diff --git a/services/training/app/models/training_models.py b/services/training/app/models/training_models.py new file mode 100644 index 00000000..1b94aeae --- /dev/null +++ b/services/training/app/models/training_models.py @@ -0,0 +1,80 @@ +# services/training/app/models/training_models.py +""" +Database models for trained ML models +""" + +from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON +from sqlalchemy.ext.declarative import declarative_base +from datetime import datetime +import uuid + +Base = declarative_base() + +class TrainedModel(Base): + __tablename__ = "trained_models" + + # Primary identification + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + tenant_id = Column(String, nullable=False, index=True) + product_name = Column(String, nullable=False, index=True) + + # Model information + model_type = Column(String, default="prophet_optimized") + model_version = Column(String, default="1.0") + job_id = Column(String, nullable=False) + + # File storage + model_path = Column(String, nullable=False) # Path to the .pkl file + metadata_path = Column(String) # Path to metadata JSON + + # Training metrics + mape = Column(Float) + mae = Column(Float) + rmse = Column(Float) + r2_score = Column(Float) + training_samples = Column(Integer) + + # Hyperparameters and features + hyperparameters = Column(JSON) # Store optimized parameters + features_used = Column(JSON) # List of regressor columns + + # Model status + is_active = Column(Boolean, default=True) + is_production = Column(Boolean, default=False) + + # Timestamps + created_at = Column(DateTime, default=datetime.utcnow) + last_used_at = Column(DateTime) + + # Training data info + training_start_date = Column(DateTime) + training_end_date = Column(DateTime) + data_quality_score = Column(Float) + + # Additional metadata + notes = Column(Text) + created_by = Column(String) # User who triggered training + + def to_dict(self): + return { + "id": self.id, + "tenant_id": self.tenant_id, + "product_name": self.product_name, + "model_type": self.model_type, + "model_version": self.model_version, + "model_path": self.model_path, + "mape": self.mape, + "mae": self.mae, + "rmse": self.rmse, + "r2_score": self.r2_score, + "training_samples": self.training_samples, + "hyperparameters": self.hyperparameters, + "features_used": self.features_used, + "is_active": self.is_active, + "is_production": self.is_production, + "created_at": self.created_at.isoformat() if self.created_at else None, + "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None, + "training_start_date": self.training_start_date.isoformat() if self.training_start_date else None, + "training_end_date": self.training_end_date.isoformat() if self.training_end_date else None, + "data_quality_score": self.data_quality_score + } \ No newline at end of file diff --git a/services/training/app/schemas/training.py b/services/training/app/schemas/training.py index 153e492d..52a95cf8 100644 --- a/services/training/app/schemas/training.py +++ b/services/training/app/schemas/training.py @@ -25,23 +25,6 @@ class TrainingJobRequest(BaseModel): products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)") start_date: Optional[datetime] = Field(None, description="Start date for training data") end_date: Optional[datetime] = Field(None, description="End date for training data") - - # Prophet-specific parameters - seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$") - daily_seasonality: bool = Field(True, description="Enable daily seasonality") - weekly_seasonality: bool = Field(True, description="Enable weekly seasonality") - yearly_seasonality: bool = Field(True, description="Enable yearly seasonality") - - # Advanced configuration - force_retrain: bool = Field(False, description="Force retraining even if recent model exists") - parallel_training: bool = Field(True, description="Train products in parallel") - max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10) - - @validator('seasonality_mode') - def validate_seasonality_mode(cls, v): - if v not in ['additive', 'multiplicative']: - raise ValueError('seasonality_mode must be either "additive" or "multiplicative"') - return v class SingleProductTrainingRequest(BaseModel): diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh index 0ac8aadd..20e2c560 100755 --- a/tests/test_onboarding_flow.sh +++ b/tests/test_onboarding_flow.sh @@ -326,9 +326,7 @@ BAKERY_DATA="{ \"address\": \"Calle Gran Vía 123\", \"city\": \"Madrid\", \"postal_code\": \"28001\", - \"phone\": \"+34600123456\", - \"latitude\": $MOCK_LATITUDE, - \"longitude\": $MOCK_LONGITUDE + \"phone\": \"+34600123456\" }" echo "Bakery Data with mock coordinates:" @@ -562,63 +560,6 @@ if [ "$HTTP_CODE" = "200" ]; then fi fi -log_step "3.3. Verifying imported sales data" - -SALES_LIST_RESPONSE=$(curl -s -X GET "$API_BASE/api/v1/tenants/$TENANT_ID/sales" \ - -H "Authorization: Bearer $ACCESS_TOKEN") - -echo "Sales Data Response:" -echo "$SALES_LIST_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$SALES_LIST_RESPONSE" - -# Check if we actually got any sales data -SALES_COUNT=$(echo "$SALES_LIST_RESPONSE" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - if isinstance(data, list): - print(len(data)) - elif isinstance(data, dict) and 'data' in data: - print(len(data['data']) if isinstance(data['data'], list) else 0) - else: - print(0) -except: - print(0) -" 2>/dev/null) - -if [ "$SALES_COUNT" -gt 0 ]; then - log_success "Sales data successfully retrieved!" - echo " Records found: $SALES_COUNT" - - # Show some sample products found - echo " Sample products found:" - echo "$SALES_LIST_RESPONSE" | python3 -c " -import json, sys -try: - data = json.load(sys.stdin) - records = data if isinstance(data, list) else data.get('data', []) - products = set() - for record in records[:5]: # First 5 records - if isinstance(record, dict) and 'product_name' in record: - products.add(record['product_name']) - for product in sorted(products): - print(f' - {product}') -except: - pass -" 2>/dev/null -else - log_warning "No sales data found in database" - - if [ -n "$RECORDS_CREATED" ] && [ "$RECORDS_CREATED" -gt 0 ]; then - log_error "Inconsistency detected: Import reported $RECORDS_CREATED records created, but none found in database" - echo "This could indicate:" - echo " 1. Records were created but failed timezone validation and were rolled back" - echo " 2. Database transaction was not committed" - echo " 3. Records were created in a different tenant/schema" - else - echo "This is expected if the import failed due to timezone or other errors." - fi -fi - echo "" # ================================================================= @@ -631,40 +572,10 @@ echo "" log_step "4.1. Starting model training process with real data products" -# Get unique products from the imported data for training -# Extract some real product names from the CSV for training -REAL_PRODUCTS_RAW=$(tail -n +2 "$PREPARED_CSV" | cut -d',' -f2 | sort | uniq | head -3 | tr '\n' ',' | sed 's/,$//') - -if [ -z "$REAL_PRODUCTS_RAW" ]; then - # Fallback to default products if extraction fails - REAL_PRODUCTS_ARRAY='["Pan de molde","Croissants","Magdalenas"]' - log_warning "Could not extract real product names, using defaults" -else - # Format for JSON array properly - REAL_PRODUCTS_ARRAY='['$(echo "$REAL_PRODUCTS_RAW" | sed 's/,/","/g' | sed 's/^/"/' | sed 's/$/"/')']' - log_success "Extracted real products for training: $REAL_PRODUCTS_ARRAY" -fi - -# ✅ FIXED: Training request with correct data types matching TrainingJobRequest schema -TRAINING_DATA="{ - \"products\": $REAL_PRODUCTS_ARRAY, - \"max_workers\": 4, - \"seasonality_mode\": \"additive\", - \"daily_seasonality\": true, - \"weekly_seasonality\": true, - \"yearly_seasonality\": true, - \"force_retrain\": false, - \"parallel_training\": true -}" - -echo "Training Request:" -echo "$TRAINING_DATA" | python3 -m json.tool - TRAINING_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $ACCESS_TOKEN" \ - -H "X-Tenant-ID: $TENANT_ID" \ - -d "$TRAINING_DATA") + -H "Authorization: Bearer $ACCESS_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{}') # Extract HTTP code and response HTTP_CODE=$(echo "$TRAINING_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2) @@ -685,31 +596,6 @@ if [ "$HTTP_CODE" = "422" ]; then echo "" echo "Response details:" echo "$TRAINING_RESPONSE" - - # Try a minimal request that should work - log_step "4.2. Attempting minimal training request as fallback" - - MINIMAL_TRAINING_DATA='{"seasonality_mode": "additive"}' - - FALLBACK_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer $ACCESS_TOKEN" \ - -H "X-Tenant-ID: $TENANT_ID" \ - -d "$MINIMAL_TRAINING_DATA") - - FALLBACK_HTTP_CODE=$(echo "$FALLBACK_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2) - FALLBACK_RESPONSE=$(echo "$FALLBACK_RESPONSE" | sed '/HTTP_CODE:/d') - - echo "Fallback HTTP Status Code: $FALLBACK_HTTP_CODE" - echo "Fallback Response:" - echo "$FALLBACK_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$FALLBACK_RESPONSE" - - if [ "$FALLBACK_HTTP_CODE" = "200" ] || [ "$FALLBACK_HTTP_CODE" = "201" ]; then - log_success "Minimal training request succeeded" - TRAINING_TASK_ID=$(extract_json_field "$FALLBACK_RESPONSE" "job_id") - else - log_error "Both training requests failed" - fi else # Original success handling TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "job_id")