REFACTOR - Database logic
This commit is contained in:
@@ -24,7 +24,8 @@ warnings.filterwarnings('ignore')
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from app.models.training import TrainedModel
|
||||
from app.core.database import get_db_session
|
||||
from shared.database.base import create_database_manager
|
||||
from app.repositories import ModelRepository
|
||||
|
||||
# Simple optimization import
|
||||
import optuna
|
||||
@@ -40,10 +41,11 @@ class BakeryProphetManager:
|
||||
Drop-in replacement for the existing manager - optimization runs automatically.
|
||||
"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession = None):
|
||||
def __init__(self, database_manager=None):
|
||||
self.models = {} # In-memory model storage
|
||||
self.model_metadata = {} # Store model metadata
|
||||
self.db_session = db_session # Add database session
|
||||
self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
|
||||
self.db_session = None # Will be set when session is available
|
||||
|
||||
# Ensure model storage directory exists
|
||||
os.makedirs(settings.MODEL_STORAGE_PATH, exist_ok=True)
|
||||
@@ -84,15 +86,15 @@ class BakeryProphetManager:
|
||||
# Fit the model
|
||||
model.fit(prophet_data)
|
||||
|
||||
# Store model and calculate metrics (same as before)
|
||||
model_id = f"{job_id}_{product_name}_{uuid.uuid4().hex[:8]}"
|
||||
model_path = await self._store_model(
|
||||
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params
|
||||
)
|
||||
|
||||
# Calculate enhanced training metrics
|
||||
# Calculate enhanced training metrics first
|
||||
training_metrics = await self._calculate_training_metrics(model, prophet_data, best_params)
|
||||
|
||||
# Store model and metrics - Generate proper UUID for model_id
|
||||
model_id = str(uuid.uuid4())
|
||||
model_path = await self._store_model(
|
||||
tenant_id, product_name, model, model_id, prophet_data, regressor_columns, best_params, training_metrics
|
||||
)
|
||||
|
||||
# Return same format as before, but with optimization info
|
||||
model_info = {
|
||||
"model_id": model_id,
|
||||
@@ -517,11 +519,11 @@ class BakeryProphetManager:
|
||||
self.models[model_key] = model
|
||||
self.model_metadata[model_key] = metadata
|
||||
|
||||
# 🆕 NEW: Store in database
|
||||
if self.db_session:
|
||||
try:
|
||||
# 🆕 NEW: Store in database using new session
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
# Deactivate previous models for this product
|
||||
await self._deactivate_previous_models(tenant_id, product_name)
|
||||
await self._deactivate_previous_models_with_session(db_session, tenant_id, product_name)
|
||||
|
||||
# Create new database record
|
||||
db_model = TrainedModel(
|
||||
@@ -536,8 +538,8 @@ class BakeryProphetManager:
|
||||
features_used=regressor_columns,
|
||||
is_active=True,
|
||||
is_production=True, # New models are production-ready
|
||||
training_start_date=training_data['ds'].min(),
|
||||
training_end_date=training_data['ds'].max(),
|
||||
training_start_date=training_data['ds'].min().to_pydatetime().replace(tzinfo=None) if training_data['ds'].min().tz is None else training_data['ds'].min().to_pydatetime(),
|
||||
training_end_date=training_data['ds'].max().to_pydatetime().replace(tzinfo=None) if training_data['ds'].max().tz is None else training_data['ds'].max().to_pydatetime(),
|
||||
training_samples=len(training_data)
|
||||
)
|
||||
|
||||
@@ -549,44 +551,39 @@ class BakeryProphetManager:
|
||||
db_model.r2_score = training_metrics.get('r2')
|
||||
db_model.data_quality_score = training_metrics.get('data_quality_score')
|
||||
|
||||
self.db_session.add(db_model)
|
||||
await self.db_session.commit()
|
||||
db_session.add(db_model)
|
||||
await db_session.commit()
|
||||
|
||||
logger.info(f"Model {model_id} stored in database successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store model in database: {str(e)}")
|
||||
await self.db_session.rollback()
|
||||
# Continue execution - file storage succeeded
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store model in database: {str(e)}")
|
||||
# Continue execution - file storage succeeded
|
||||
|
||||
logger.info(f"Optimized model stored at: {model_path}")
|
||||
return str(model_path)
|
||||
|
||||
async def _deactivate_previous_models(self, tenant_id: str, product_name: str):
|
||||
"""Deactivate previous models for the same product"""
|
||||
if self.db_session:
|
||||
try:
|
||||
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
|
||||
query = text("""
|
||||
UPDATE trained_models
|
||||
SET is_active = false, is_production = false
|
||||
WHERE tenant_id = :tenant_id AND product_name = :product_name
|
||||
""")
|
||||
|
||||
await self.db_session.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
})
|
||||
|
||||
# ✅ ADD: Commit the transaction
|
||||
await self.db_session.commit()
|
||||
|
||||
logger.info(f"Successfully deactivated previous models for {product_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deactivate previous models: {str(e)}")
|
||||
# ✅ ADD: Rollback on error
|
||||
await self.db_session.rollback()
|
||||
async def _deactivate_previous_models_with_session(self, db_session, tenant_id: str, product_name: str):
|
||||
"""Deactivate previous models for the same product using provided session"""
|
||||
try:
|
||||
# ✅ FIX: Wrap SQL string with text() for SQLAlchemy 2.0
|
||||
query = text("""
|
||||
UPDATE trained_models
|
||||
SET is_active = false, is_production = false
|
||||
WHERE tenant_id = :tenant_id AND product_name = :product_name
|
||||
""")
|
||||
|
||||
await db_session.execute(query, {
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
})
|
||||
|
||||
# Note: Don't commit here, let the calling method handle the transaction
|
||||
logger.info(f"Successfully deactivated previous models for {product_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deactivate previous models: {str(e)}")
|
||||
raise
|
||||
|
||||
# Keep all existing methods unchanged
|
||||
async def generate_forecast(self,
|
||||
|
||||
Reference in New Issue
Block a user