Improve training code 4

This commit is contained in:
Urtzi Alfaro
2025-07-29 07:53:30 +02:00
parent c788c7e406
commit 71216f8ec9
3 changed files with 84 additions and 135 deletions

View File

@@ -0,0 +1,80 @@
# services/training/app/models/training_models.py
"""
Database models for trained ML models
"""
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON
from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
import uuid
Base = declarative_base()
class TrainedModel(Base):
    """ORM row describing one trained ML model for a (tenant, product) pair.

    Stores where the serialized model lives on disk, the metrics measured
    during training, the hyperparameters used, and lifecycle flags.
    """

    __tablename__ = "trained_models"

    # --- Identity ---
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    tenant_id = Column(String, nullable=False, index=True)
    product_name = Column(String, nullable=False, index=True)

    # --- Model information ---
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)  # training job that produced this model

    # --- File storage ---
    model_path = Column(String, nullable=False)  # path to the serialized .pkl file
    metadata_path = Column(String)  # optional path to the metadata JSON

    # --- Training metrics ---
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # --- Hyperparameters and features ---
    hyperparameters = Column(JSON)  # optimized parameter set
    features_used = Column(JSON)  # list of regressor columns

    # --- Lifecycle flags ---
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # --- Timestamps ---
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime)

    # --- Training data window / quality ---
    training_start_date = Column(DateTime)
    training_end_date = Column(DateTime)
    data_quality_score = Column(Float)

    # --- Additional metadata ---
    notes = Column(Text)
    created_by = Column(String)  # user who triggered training

    def to_dict(self):
        """Serialize the row into a JSON-friendly dict.

        DateTime columns become ISO-8601 strings (None when unset).
        NOTE(review): job_id, metadata_path, notes and created_by are not
        exposed here — presumably internal-only; confirm with API callers.
        """
        _DATETIME_FIELDS = {
            "created_at",
            "last_used_at",
            "training_start_date",
            "training_end_date",
        }
        serialized = {}
        # Field order matches the historical payload layout.
        for field in (
            "id", "tenant_id", "product_name", "model_type", "model_version",
            "model_path", "mape", "mae", "rmse", "r2_score", "training_samples",
            "hyperparameters", "features_used", "is_active", "is_production",
            "created_at", "last_used_at", "training_start_date",
            "training_end_date", "data_quality_score",
        ):
            value = getattr(self, field)
            if field in _DATETIME_FIELDS:
                value = value.isoformat() if value else None
            serialized[field] = value
        return serialized

View File

@@ -26,23 +26,6 @@ class TrainingJobRequest(BaseModel):
start_date: Optional[datetime] = Field(None, description="Start date for training data")
end_date: Optional[datetime] = Field(None, description="End date for training data")
# Prophet-specific parameters
seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$")
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
# Advanced configuration
force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
parallel_training: bool = Field(True, description="Train products in parallel")
max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)
@validator('seasonality_mode')
def validate_seasonality_mode(cls, v):
    """Accept only Prophet's two supported seasonality modes."""
    allowed_modes = ('additive', 'multiplicative')
    if v in allowed_modes:
        return v
    raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
class SingleProductTrainingRequest(BaseModel):
"""Request schema for training a single product"""

View File

@@ -326,9 +326,7 @@ BAKERY_DATA="{
\"address\": \"Calle Gran Vía 123\", \"address\": \"Calle Gran Vía 123\",
\"city\": \"Madrid\", \"city\": \"Madrid\",
\"postal_code\": \"28001\", \"postal_code\": \"28001\",
\"phone\": \"+34600123456\", \"phone\": \"+34600123456\"
\"latitude\": $MOCK_LATITUDE,
\"longitude\": $MOCK_LONGITUDE
}" }"
echo "Bakery Data with mock coordinates:" echo "Bakery Data with mock coordinates:"
@@ -562,63 +560,6 @@ if [ "$HTTP_CODE" = "200" ]; then
fi fi
fi fi
# --- Step 3.3: verify that the imported sales records are actually queryable ---
log_step "3.3. Verifying imported sales data"

# Fetch the tenant's sales records through the public API.
SALES_LIST_RESPONSE=$(curl -s -X GET "$API_BASE/api/v1/tenants/$TENANT_ID/sales" \
    -H "Authorization: Bearer $ACCESS_TOKEN")

echo "Sales Data Response:"
# Pretty-print when the body is valid JSON; otherwise dump it raw.
echo "$SALES_LIST_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$SALES_LIST_RESPONSE"

# Check if we actually got any sales data.
# The endpoint may return either a bare list or an envelope {"data": [...]};
# count records in both shapes, defaulting to 0 on any parse error.
SALES_COUNT=$(echo "$SALES_LIST_RESPONSE" | python3 -c "
import json, sys
try:
    data = json.load(sys.stdin)
    if isinstance(data, list):
        print(len(data))
    elif isinstance(data, dict) and 'data' in data:
        print(len(data['data']) if isinstance(data['data'], list) else 0)
    else:
        print(0)
except:
    print(0)
" 2>/dev/null)

if [ "$SALES_COUNT" -gt 0 ]; then
    log_success "Sales data successfully retrieved!"
    echo " Records found: $SALES_COUNT"
    # Show some sample products found (distinct product names from the first
    # five records, printed sorted).
    echo " Sample products found:"
    echo "$SALES_LIST_RESPONSE" | python3 -c "
import json, sys
try:
    data = json.load(sys.stdin)
    records = data if isinstance(data, list) else data.get('data', [])
    products = set()
    for record in records[:5]: # First 5 records
        if isinstance(record, dict) and 'product_name' in record:
            products.add(record['product_name'])
    for product in sorted(products):
        print(f' - {product}')
except:
    pass
" 2>/dev/null
else
    log_warning "No sales data found in database"
    # Cross-check against the record count the import endpoint reported
    # earlier ($RECORDS_CREATED): a non-zero report with zero queryable
    # records indicates a silent failure, not merely an empty import.
    if [ -n "$RECORDS_CREATED" ] && [ "$RECORDS_CREATED" -gt 0 ]; then
        log_error "Inconsistency detected: Import reported $RECORDS_CREATED records created, but none found in database"
        echo "This could indicate:"
        echo " 1. Records were created but failed timezone validation and were rolled back"
        echo " 2. Database transaction was not committed"
        echo " 3. Records were created in a different tenant/schema"
    else
        echo "This is expected if the import failed due to timezone or other errors."
    fi
fi
echo "" echo ""
# ================================================================= # =================================================================
@@ -631,40 +572,10 @@ echo ""
log_step "4.1. Starting model training process with real data products" log_step "4.1. Starting model training process with real data products"
# Get unique products from the imported data for training.
# Extract some real product names from the CSV for training: skip the header
# row, take column 2 (product name), dedupe, keep the first three, and join
# them into a comma-separated string.
REAL_PRODUCTS_RAW=$(tail -n +2 "$PREPARED_CSV" | cut -d',' -f2 | sort | uniq | head -3 | tr '\n' ',' | sed 's/,$//')

if [ -z "$REAL_PRODUCTS_RAW" ]; then
    # Fallback to default products if extraction fails
    REAL_PRODUCTS_ARRAY='["Pan de molde","Croissants","Magdalenas"]'
    log_warning "Could not extract real product names, using defaults"
else
    # Format for JSON array properly: quote each element and wrap in brackets.
    # NOTE(review): product names containing commas or quotes would break this
    # naive quoting — confirm the CSV's product column is free of them.
    REAL_PRODUCTS_ARRAY='['$(echo "$REAL_PRODUCTS_RAW" | sed 's/,/","/g' | sed 's/^/"/' | sed 's/$/"/')']'
    log_success "Extracted real products for training: $REAL_PRODUCTS_ARRAY"
fi

# ✅ FIXED: Training request with correct data types matching TrainingJobRequest schema
TRAINING_DATA="{
    \"products\": $REAL_PRODUCTS_ARRAY,
    \"max_workers\": 4,
    \"seasonality_mode\": \"additive\",
    \"daily_seasonality\": true,
    \"weekly_seasonality\": true,
    \"yearly_seasonality\": true,
    \"force_retrain\": false,
    \"parallel_training\": true
}"

echo "Training Request:"
# Pretty-print the request payload (also validates it is well-formed JSON).
echo "$TRAINING_DATA" | python3 -m json.tool
TRAINING_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \ TRAINING_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \
-H "Content-Type: application/json" \ -H "Authorization: Bearer $ACCESS_TOKEN" \
-H "Authorization: Bearer $ACCESS_TOKEN" \ -H "Content-Type: application/json" \
-H "X-Tenant-ID: $TENANT_ID" \ -d '{}')
-d "$TRAINING_DATA")
# Extract HTTP code and response # Extract HTTP code and response
HTTP_CODE=$(echo "$TRAINING_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2) HTTP_CODE=$(echo "$TRAINING_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2)
@@ -685,31 +596,6 @@ if [ "$HTTP_CODE" = "422" ]; then
echo "" echo ""
echo "Response details:" echo "Response details:"
echo "$TRAINING_RESPONSE" echo "$TRAINING_RESPONSE"
# Try a minimal request that should work: if the full payload was rejected
# with a validation error, retry with only one field and rely on the
# schema's defaults for everything else.
log_step "4.2. Attempting minimal training request as fallback"
MINIMAL_TRAINING_DATA='{"seasonality_mode": "additive"}'

FALLBACK_RESPONSE=$(curl -s -w "\nHTTP_CODE:%{http_code}" -X POST "$API_BASE/api/v1/tenants/$TENANT_ID/training/jobs" \
    -H "Content-Type: application/json" \
    -H "Authorization: Bearer $ACCESS_TOKEN" \
    -H "X-Tenant-ID: $TENANT_ID" \
    -d "$MINIMAL_TRAINING_DATA")

# Split the curl output into the status code (appended via -w) and the body.
FALLBACK_HTTP_CODE=$(echo "$FALLBACK_RESPONSE" | grep "HTTP_CODE:" | cut -d: -f2)
FALLBACK_RESPONSE=$(echo "$FALLBACK_RESPONSE" | sed '/HTTP_CODE:/d')

echo "Fallback HTTP Status Code: $FALLBACK_HTTP_CODE"
echo "Fallback Response:"
# Pretty-print when valid JSON; otherwise dump raw.
echo "$FALLBACK_RESPONSE" | python3 -m json.tool 2>/dev/null || echo "$FALLBACK_RESPONSE"

if [ "$FALLBACK_HTTP_CODE" = "200" ] || [ "$FALLBACK_HTTP_CODE" = "201" ]; then
    log_success "Minimal training request succeeded"
    FALLBACK_TASK_ID_SOURCE="$FALLBACK_RESPONSE"
    TRAINING_TASK_ID=$(extract_json_field "$FALLBACK_RESPONSE" "job_id")
else
    log_error "Both training requests failed"
fi
else else
# Original success handling # Original success handling
TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "job_id") TRAINING_TASK_ID=$(extract_json_field "$TRAINING_RESPONSE" "job_id")