REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -24,7 +24,8 @@ warnings.filterwarnings('ignore')
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from app.models.training import TrainedModel
from app.core.database import get_db_session
from shared.database.base import create_database_manager
from app.repositories import ModelRepository
# Simple optimization import
import optuna
@@ -40,10 +41,11 @@ class BakeryProphetManager:
Drop-in replacement for the existing manager - optimization runs automatically.
"""
def __init__(self, db_session: AsyncSession = None):
def __init__(self, database_manager=None):
self.models = {} # In-memory model storage
self.model_metadata = {} # Store model metadata
self.db_session = db_session # Add database session
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
self.db_session = None # Will be set when session is available
# Ensure model storage directory exists
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
@@ -84,15 +86,15 @@ class BakeryProphetManager:
# Fit the model
model.fit(prophet_data)
# Store model and calculate metrics (same as before)
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params
)
# Calculate enhanced training metrics
# Calculate enhanced training metrics first
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
# Store model and metrics - Generate proper UUID for model_id
model_id = str(uuid.uuid4())
model_path = await self._store_model(
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
)
# Return same format as before, but with optimization info
model_info = {
"model_id": model_id,
@@ -517,11 +519,11 @@ class BakeryProphetManager:
self.models[model_key] = model
self.model_metadata[model_key] = metadata
# 🆕 NEW: Store in database
if self.db_session:
try:
# 🆕 NEW: Store in database using new session
try:
async with self.database_manager.get_session() as db_session:
# Deactivate previous models for this product
await self._deactivate_previous_models(tenant_id, product_name)
await self._deactivate_previous_models_with_session(db_session, tenant_id, product_name)
# Create new database record
db_model = TrainedModel(
@@ -536,8 +538,8 @@ class BakeryProphetManager:
features_used=regressor_columns,
is_active=True,
is_production=True, # New models are production-ready
training_start_date=training_data['ds'].min(),
training_end_date=training_data['ds'].max(),
training_start_date=training_data['ds'].min().to_pydatetime().replace(tzinfo=None) if training_data['ds'].min().tz is None else training_data['ds'].min().to_pydatetime(),
training_end_date=training_data['ds'].max().to_pydatetime().replace(tzinfo=None) if training_data['ds'].max().tz is None else training_data['ds'].max().to_pydatetime(),
training_samples=len(training_data)
)
@@ -549,44 +551,39 @@ class BakeryProphetManager:
db_model.r2_score = training_metrics.get('r2')
db_model.data_quality_score = training_metrics.get('data_quality_score')
self.db_session.add(db_model)
await self.db_session.commit()
db_session.add(db_model)
await db_session.commit()
logger.info(f"Model {model_id} stored in database successfully")
except Exception as e:
logger.error(f"Failed to store model in database: {str(e)}")
await self.db_session.rollback()
# Continue execution - file storage succeeded
except Exception as e:
logger.error(f"Failed to store model in database: {str(e)}")
# Continue execution - file storage succeeded
logger.info(f"Optimized model stored at: {model_path}")
return str(model_path)
async def _deactivate_previous_models(self, tenant_id: str, product_name: str):
"""Deactivate previous models for the same product"""
if self.db_session:
try:
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
query = text("""
UPDATE trained_models
SET is_active = false, is_production = false
WHERE tenant_id = :tenant_id AND product_name = :product_name
""")
await self.db_session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
# ✅ ADD: Commit the transaction
await self.db_session.commit()
logger.info(f"Successfully deactivated previous models for {product_name}")
except Exception as e:
logger.error(f"Failed to deactivate previous models: {str(e)}")
# ✅ ADD: Rollback on error
await self.db_session.rollback()
async def _deactivate_previous_models_with_session(self, db_session, tenant_id: str, product_name: str):
"""Deactivate previous models for the same product using provided session"""
try:
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
query = text("""
UPDATE trained_models
SET is_active = false, is_production = false
WHERE tenant_id = :tenant_id AND product_name = :product_name
""")
await db_session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name
})
# Note: Don't commit here, let the calling method handle the transaction
logger.info(f"Successfully deactivated previous models for {product_name}")
except Exception as e:
logger.error(f"Failed to deactivate previous models: {str(e)}")
raise
# Keep all existing methods unchanged
async def generate_forecast(self,