Improve the sales import

Urtzi Alfaro
2025-10-15 21:09:42 +02:00
parent 8f9e9a7edc
commit dbb48d8e2c
21 changed files with 992 additions and 409 deletions

View File

@@ -26,7 +26,7 @@ from shared.monitoring.decorators import track_execution_time
from shared.monitoring.metrics import get_metrics_collector
from app.core.config import settings
from shared.routing import RouteBuilder
from shared.auth.access_control import require_user_role, enterprise_tier_required
from shared.auth.access_control import require_user_role, analytics_tier_required
route_builder = RouteBuilder('forecasting')
logger = structlog.get_logger()
@@ -44,7 +44,7 @@ def get_enhanced_forecasting_service():
response_model=ScenarioSimulationResponse
)
@require_user_role(['admin', 'owner'])
@enterprise_tier_required
@analytics_tier_required
@track_execution_time("scenario_simulation_duration_seconds", "forecasting-service")
async def simulate_scenario(
request: ScenarioSimulationRequest,
@@ -406,6 +406,7 @@ def _generate_insights(
response_model=ScenarioComparisonResponse
)
@require_user_role(['viewer', 'member', 'admin', 'owner'])
@analytics_tier_required
async def compare_scenarios(
request: ScenarioComparisonRequest,
tenant_id: str = Path(..., description="Tenant ID")

View File

@@ -365,3 +365,93 @@ async def classify_products_batch(
logger.error("Failed batch classification",
error=str(e), products_count=len(request.products), tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Batch classification failed: {str(e)}")
class BatchProductResolutionRequest(BaseModel):
"""Request for batch product resolution or creation"""
products: List[Dict[str, Any]] = Field(..., description="Products to resolve or create")
class BatchProductResolutionResponse(BaseModel):
"""Response with product name to inventory ID mappings"""
product_mappings: Dict[str, str] = Field(..., description="Product name to inventory product ID mapping")
created_count: int = Field(..., description="Number of products created")
resolved_count: int = Field(..., description="Number of existing products resolved")
failed_count: int = Field(0, description="Number of products that failed")
@router.post(
route_builder.build_operations_route("resolve-or-create-products-batch"),
response_model=BatchProductResolutionResponse
)
async def resolve_or_create_products_batch(
request: BatchProductResolutionRequest,
tenant_id: UUID = Path(..., description="Tenant ID"),
current_user: Dict[str, Any] = Depends(get_current_user_dep),
db: AsyncSession = Depends(get_db)
):
"""Resolve or create multiple products in a single optimized operation for sales import"""
try:
if not request.products:
raise HTTPException(status_code=400, detail="No products provided")
service = InventoryService()
product_mappings = {}
created_count = 0
resolved_count = 0
failed_count = 0
for product_data in request.products:
product_name = product_data.get('name', product_data.get('product_name', ''))
if not product_name:
failed_count += 1
continue
try:
existing = await service.search_ingredients_by_name(product_name, tenant_id, db)
if existing:
product_mappings[product_name] = str(existing.id)
resolved_count += 1
logger.debug("Resolved existing product", product=product_name, tenant_id=tenant_id)
else:
category = product_data.get('category', 'general')
ingredient_data = {
'name': product_name,
'type': 'finished_product',
'unit': 'unit',
'current_stock': 0,
'reorder_point': 0,
'cost_per_unit': 0,
'category': category
}
created = await service.create_ingredient_fast(ingredient_data, tenant_id, db)
product_mappings[product_name] = str(created.id)
created_count += 1
logger.debug("Created new product", product=product_name, tenant_id=tenant_id)
except Exception as e:
logger.warning("Failed to resolve/create product",
product=product_name, error=str(e), tenant_id=tenant_id)
failed_count += 1
continue
logger.info("Batch product resolution complete",
total=len(request.products),
created=created_count,
resolved=resolved_count,
failed=failed_count,
tenant_id=tenant_id)
return BatchProductResolutionResponse(
product_mappings=product_mappings,
created_count=created_count,
resolved_count=resolved_count,
failed_count=failed_count
)
    except HTTPException:
        raise
    except Exception as e:
logger.error("Batch product resolution failed",
error=str(e), tenant_id=tenant_id)
raise HTTPException(status_code=500, detail=f"Batch resolution failed: {str(e)}")

View File

@@ -753,6 +753,67 @@ class InventoryService:
)
raise
# ===== BATCH OPERATIONS FOR SALES IMPORT =====
async def search_ingredients_by_name(
self,
product_name: str,
tenant_id: UUID,
db
) -> Optional[Ingredient]:
"""Search for an ingredient by name (case-insensitive exact match)"""
try:
repository = IngredientRepository(db)
ingredients = await repository.search_ingredients(
tenant_id=tenant_id,
search_term=product_name,
skip=0,
limit=10
)
product_name_lower = product_name.lower().strip()
for ingredient in ingredients:
if ingredient.name.lower().strip() == product_name_lower:
return ingredient
return None
except Exception as e:
logger.warning("Error searching ingredients by name",
product_name=product_name, error=str(e), tenant_id=tenant_id)
return None
async def create_ingredient_fast(
self,
ingredient_data: Dict[str, Any],
tenant_id: UUID,
db
) -> Ingredient:
"""Create ingredient without full validation for batch operations"""
try:
repository = IngredientRepository(db)
ingredient_create = IngredientCreate(
name=ingredient_data.get('name'),
product_type=ingredient_data.get('type', 'finished_product'),
unit_of_measure=ingredient_data.get('unit', 'units'),
low_stock_threshold=ingredient_data.get('current_stock', 0),
reorder_point=max(ingredient_data.get('reorder_point', 1),
ingredient_data.get('current_stock', 0) + 1),
average_cost=ingredient_data.get('cost_per_unit', 0.0),
ingredient_category=ingredient_data.get('category') if ingredient_data.get('type') == 'ingredient' else None,
product_category=ingredient_data.get('category') if ingredient_data.get('type') == 'finished_product' else None
)
ingredient = await repository.create_ingredient(ingredient_create, tenant_id)
logger.debug("Created ingredient fast", ingredient_id=ingredient.id, name=ingredient.name)
return ingredient
except Exception as e:
logger.error("Failed to create ingredient fast",
error=str(e), ingredient_data=ingredient_data, tenant_id=tenant_id)
raise
# ===== PRIVATE HELPER METHODS =====
async def _validate_ingredient_data(self, ingredient_data: IngredientCreate, tenant_id: UUID):

View File

@@ -265,18 +265,60 @@ class SalesRepository(BaseRepository[SalesData, SalesDataCreate, SalesDataUpdate
record = await self.get_by_id(record_id)
if not record:
raise ValueError(f"Sales record {record_id} not found")
update_data = {
'is_validated': True,
'validation_notes': validation_notes
}
updated_record = await self.update(record_id, update_data)
logger.info("Validated sales record", record_id=record_id)
return updated_record
except Exception as e:
logger.error("Failed to validate sales record", error=str(e), record_id=record_id)
raise
async def create_sales_records_bulk(
self,
sales_data_list: List[SalesDataCreate],
tenant_id: UUID
) -> int:
"""Bulk insert sales records for performance optimization"""
try:
from uuid import uuid4
records = []
for sales_data in sales_data_list:
is_weekend = sales_data.date.weekday() >= 5 if sales_data.date else False
record = SalesData(
id=uuid4(),
tenant_id=tenant_id,
date=sales_data.date,
inventory_product_id=sales_data.inventory_product_id,
quantity_sold=sales_data.quantity_sold,
unit_price=sales_data.unit_price,
revenue=sales_data.revenue,
location_id=sales_data.location_id,
sales_channel=sales_data.sales_channel,
source=sales_data.source,
is_weekend=is_weekend,
is_validated=getattr(sales_data, 'is_validated', False)
)
records.append(record)
self.session.add_all(records)
await self.session.flush()
logger.info(
"Bulk created sales records",
count=len(records),
tenant_id=tenant_id
)
return len(records)
except Exception as e:
logger.error("Failed to bulk create sales records", error=str(e), tenant_id=tenant_id)
raise
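A minimal usage sketch of the bulk method, assuming SalesDataCreate accepts the same fields the loop above copies onto SalesData (values here are illustrative). Note that the method only flushes, so the surrounding unit of work still commits:

from datetime import date
from uuid import UUID, uuid4

async def bulk_import_example(repository: SalesRepository, tenant_id: UUID) -> int:
    batch = [
        SalesDataCreate(
            tenant_id=tenant_id,
            date=date(2025, 10, 14),
            inventory_product_id=uuid4(),  # normally resolved via the inventory service
            quantity_sold=3,
            unit_price=2.50,
            revenue=7.50,
            sales_channel="in_store",
            source="csv",
        )
    ]
    created = await repository.create_sales_records_bulk(batch, tenant_id)
    # Rows are staged with flush(); commit happens in the caller's transaction scope
    return created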

View File

@@ -442,17 +442,17 @@ class DataImportService:
)
async def _process_csv_data(
self,
tenant_id: str,
csv_content: str,
repository: SalesRepository,
self,
tenant_id: str,
csv_content: str,
repository: SalesRepository,
filename: Optional[str] = None
) -> Dict[str, Any]:
"""Enhanced CSV processing with batch product resolution for better reliability"""
"""Optimized CSV processing with true batch operations"""
try:
reader = csv.DictReader(io.StringIO(csv_content))
rows = list(reader)
if not rows:
return {
"success": False,
@@ -461,19 +461,18 @@ class DataImportService:
"errors": ["CSV file is empty"],
"warnings": []
}
# Enhanced column mapping
column_mapping = self._detect_columns(list(rows[0].keys()))
# Pre-process to extract unique products for batch creation
unique_products = set()
parsed_rows = []
logger.info(f"Pre-processing {len(rows)} records to identify unique products")
errors = []
warnings = []
logger.info(f"Parsing {len(rows)} CSV records")
for index, row in enumerate(rows):
try:
# Enhanced data parsing and validation
parsed_data = await self._parse_row_data(row, column_mapping, index + 1)
if not parsed_data.get("skip"):
unique_products.add((
@@ -481,38 +480,52 @@ class DataImportService:
parsed_data.get("product_category", "general")
))
parsed_rows.append((index, parsed_data))
else:
errors.extend(parsed_data.get("errors", []))
warnings.extend(parsed_data.get("warnings", []))
except Exception as e:
logger.warning(f"Failed to parse row {index + 1}: {e}")
errors.append(f"Row {index + 1}: Parse error - {str(e)}")
continue
logger.info(f"Found {len(unique_products)} unique products, attempting batch resolution")
# Try to resolve/create all unique products in batch
await self._batch_resolve_products(unique_products, tenant_id)
# Now process the actual sales records
records_created = 0
errors = []
warnings = []
logger.info(f"Processing {len(parsed_rows)} validated records for sales creation")
logger.info(f"Batch resolving {len(unique_products)} unique products")
products_batch = [
{"name": name, "category": category}
for name, category in unique_products
]
batch_result = await self.inventory_client.resolve_or_create_products_batch(
products_batch,
tenant_id
)
if batch_result and 'product_mappings' in batch_result:
self.product_cache.update(batch_result['product_mappings'])
logger.info(f"Resolved {len(batch_result['product_mappings'])} products in single batch call")
else:
logger.error("Batch product resolution failed")
return {
"success": False,
"total_rows": len(rows),
"records_created": 0,
"errors": ["Failed to resolve products in inventory"],
"warnings": warnings
}
sales_records_batch = []
for index, parsed_data in parsed_rows:
product_name = parsed_data["product_name"]
if product_name not in self.product_cache:
errors.append(f"Row {index + 1}: Product '{product_name}' not found in cache")
continue
try:
# Resolve product name to inventory_product_id (should be cached now)
inventory_product_id = await self._resolve_product_to_inventory_id(
parsed_data["product_name"],
parsed_data.get("product_category"),
tenant_id
)
if not inventory_product_id:
error_msg = f"Row {index + 1}: Could not resolve product '{parsed_data['product_name']}' to inventory ID"
errors.append(error_msg)
logger.warning("Product resolution failed", error=error_msg)
continue
# Create sales record with enhanced data
from uuid import UUID
inventory_product_id = UUID(self.product_cache[product_name])
sales_data = SalesDataCreate(
tenant_id=tenant_id,
date=parsed_data["date"],
@@ -523,32 +536,35 @@ class DataImportService:
location_id=parsed_data.get("location_id"),
source="csv"
)
created_record = await repository.create_sales_record(sales_data, tenant_id)
records_created += 1
# Enhanced progress logging
if records_created % 100 == 0:
logger.info(f"Enhanced processing: {records_created}/{len(rows)} records completed...")
sales_records_batch.append(sales_data)
except Exception as e:
error_msg = f"Row {index + 1}: {str(e)}"
errors.append(error_msg)
logger.warning("Enhanced record processing failed", error=error_msg)
errors.append(f"Row {index + 1}: {str(e)}")
continue
if sales_records_batch:
logger.info(f"Bulk inserting {len(sales_records_batch)} sales records")
records_created = await repository.create_sales_records_bulk(
sales_records_batch,
tenant_id
)
else:
records_created = 0
success_rate = (records_created / len(rows)) * 100 if rows else 0
return {
"success": records_created > 0,
"total_rows": len(rows),
"records_created": records_created,
"success_rate": success_rate,
"errors": errors,
"warnings": warnings
"errors": errors[:50],
"warnings": warnings[:50]
}
except Exception as e:
logger.error("Enhanced CSV processing failed", error=str(e))
logger.error("CSV processing failed", error=str(e))
raise
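Condensed, the rewritten pipeline is: parse every row once, resolve all unique products in one batch call, then bulk insert the valid records. A rough sketch of that shape, where _to_sales_data is a hypothetical shorthand for the SalesDataCreate construction shown above:

async def _import_csv_sketch(self, tenant_id: str, csv_content: str, repository: SalesRepository) -> int:
    rows = list(csv.DictReader(io.StringIO(csv_content)))
    column_mapping = self._detect_columns(list(rows[0].keys()))
    parsed = []
    for i, row in enumerate(rows):
        data = await self._parse_row_data(row, column_mapping, i + 1)
        if not data.get("skip"):
            parsed.append(data)
    unique = {(p["product_name"], p.get("product_category", "general")) for p in parsed}
    result = await self.inventory_client.resolve_or_create_products_batch(
        [{"name": n, "category": c} for n, c in unique], tenant_id
    )
    self.product_cache.update(result["product_mappings"])
    batch = [self._to_sales_data(p, tenant_id)  # hypothetical helper
             for p in parsed if p["product_name"] in self.product_cache]
    return await repository.create_sales_records_bulk(batch, tenant_id)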
async def _process_json_data(
@@ -633,66 +649,95 @@ class DataImportService:
raise
async def _process_dataframe(
self,
tenant_id: str,
df: pd.DataFrame,
self,
tenant_id: str,
df: pd.DataFrame,
repository: SalesRepository,
source: str,
filename: Optional[str] = None
) -> Dict[str, Any]:
"""Enhanced DataFrame processing with better error handling"""
"""Optimized DataFrame processing with batch operations"""
try:
# Enhanced column mapping
column_mapping = self._detect_columns(df.columns.tolist())
if not column_mapping.get('date') or not column_mapping.get('product'):
required_missing = []
if not column_mapping.get('date'):
required_missing.append("date")
if not column_mapping.get('product'):
required_missing.append("product")
raise ValueError(f"Required columns missing: {', '.join(required_missing)}")
records_created = 0
unique_products = set()
parsed_rows = []
errors = []
warnings = []
logger.info(f"Enhanced processing of {len(df)} records from {source}")
logger.info(f"Processing {len(df)} records from {source}")
for index, row in df.iterrows():
try:
# Convert pandas row to dict
row_dict = {}
for col in df.columns:
val = row[col]
# Handle pandas NaN values
if pd.isna(val):
row_dict[col] = None
else:
row_dict[col] = val
# Enhanced data parsing
parsed_data = await self._parse_row_data(row_dict, column_mapping, index + 1)
if parsed_data.get("skip"):
if not parsed_data.get("skip"):
unique_products.add((
parsed_data["product_name"],
parsed_data.get("product_category", "general")
))
parsed_rows.append((index, parsed_data))
else:
errors.extend(parsed_data.get("errors", []))
warnings.extend(parsed_data.get("warnings", []))
continue
# Resolve product name to inventory_product_id
inventory_product_id = await self._resolve_product_to_inventory_id(
parsed_data["product_name"],
parsed_data.get("product_category"),
tenant_id
)
if not inventory_product_id:
error_msg = f"Row {index + 1}: Could not resolve product '{parsed_data['product_name']}' to inventory ID"
errors.append(error_msg)
logger.warning("Product resolution failed", error=error_msg)
continue
# Create enhanced sales record
except Exception as e:
errors.append(f"Row {index + 1}: {str(e)}")
continue
logger.info(f"Batch resolving {len(unique_products)} unique products")
products_batch = [
{"name": name, "category": category}
for name, category in unique_products
]
batch_result = await self.inventory_client.resolve_or_create_products_batch(
products_batch,
tenant_id
)
if batch_result and 'product_mappings' in batch_result:
self.product_cache.update(batch_result['product_mappings'])
logger.info(f"Resolved {len(batch_result['product_mappings'])} products in batch")
else:
return {
"success": False,
"total_rows": len(df),
"records_created": 0,
"errors": ["Failed to resolve products"],
"warnings": warnings
}
sales_records_batch = []
for index, parsed_data in parsed_rows:
product_name = parsed_data["product_name"]
if product_name not in self.product_cache:
errors.append(f"Row {index + 1}: Product '{product_name}' not in cache")
continue
try:
from uuid import UUID
inventory_product_id = UUID(self.product_cache[product_name])
sales_data = SalesDataCreate(
tenant_id=tenant_id,
date=parsed_data["date"],
@@ -703,34 +748,37 @@ class DataImportService:
location_id=parsed_data.get("location_id"),
source=source
)
created_record = await repository.create_sales_record(sales_data, tenant_id)
records_created += 1
# Progress logging for large datasets
if records_created % 100 == 0:
logger.info(f"Enhanced DataFrame processing: {records_created}/{len(df)} records completed...")
sales_records_batch.append(sales_data)
except Exception as e:
error_msg = f"Row {index + 1}: {str(e)}"
errors.append(error_msg)
logger.warning("Enhanced record processing failed", error=error_msg)
errors.append(f"Row {index + 1}: {str(e)}")
continue
if sales_records_batch:
logger.info(f"Bulk inserting {len(sales_records_batch)} sales records")
records_created = await repository.create_sales_records_bulk(
sales_records_batch,
tenant_id
)
else:
records_created = 0
success_rate = (records_created / len(df)) * 100 if len(df) > 0 else 0
return {
"success": records_created > 0,
"total_rows": len(df),
"records_created": records_created,
"success_rate": success_rate,
"errors": errors[:10], # Limit errors for performance
"warnings": warnings[:10] # Limit warnings
"errors": errors[:50],
"warnings": warnings[:50]
}
except ValueError:
raise
except Exception as e:
logger.error("Enhanced DataFrame processing failed", error=str(e))
logger.error("DataFrame processing failed", error=str(e))
raise
async def _parse_row_data(
@@ -983,194 +1031,6 @@ class DataImportService:
self.failed_products.clear()
logger.info("Import cache cleared for new session")
async def _resolve_product_to_inventory_id(self, product_name: str, product_category: Optional[str], tenant_id: UUID) -> Optional[UUID]:
"""Resolve a product name to an inventory_product_id via the inventory service with improved error handling and fallback"""
# Check cache first
if product_name in self.product_cache:
logger.debug("Product resolved from cache", product_name=product_name, tenant_id=tenant_id)
return self.product_cache[product_name]
# Skip if this product already failed to resolve after all attempts
if product_name in self.failed_products:
logger.debug("Skipping previously failed product", product_name=product_name, tenant_id=tenant_id)
return None
max_retries = 5 # Increased retries
base_delay = 2.0 # Increased base delay
fallback_retry_delay = 10.0 # Longer delay for fallback attempts
for attempt in range(max_retries):
try:
# Add progressive delay to avoid rate limiting
if attempt > 0:
# Use longer delays for later attempts
if attempt >= 3:
delay = fallback_retry_delay # Use fallback delay for later attempts
else:
delay = base_delay * (2 ** (attempt - 1)) # Exponential backoff
logger.info(f"Retrying product resolution after {delay}s delay",
product_name=product_name, attempt=attempt, tenant_id=tenant_id)
await asyncio.sleep(delay)
# First try to search for existing product by name
try:
products = await self.inventory_client.search_products(product_name, tenant_id)
if products:
# Return the first matching product's ID
product_id = products[0].get('id')
if product_id:
uuid_id = UUID(str(product_id))
self.product_cache[product_name] = uuid_id # Cache for future use
logger.info("Resolved product to existing inventory ID",
product_name=product_name, product_id=product_id, tenant_id=tenant_id)
return uuid_id
except Exception as search_error:
logger.warning("Product search failed, trying direct creation",
product_name=product_name, error=str(search_error), tenant_id=tenant_id)
# Add delay before creation attempt to avoid hitting rate limits
await asyncio.sleep(1.0)
# If not found or search failed, create a new ingredient/product in inventory
ingredient_data = {
'name': product_name,
'type': 'finished_product', # Assuming sales are of finished products
'unit': 'unit', # Default unit
'current_stock': 0, # No stock initially
'reorder_point': 0,
'cost_per_unit': 0,
'category': product_category or 'general'
}
try:
created_product = await self.inventory_client.create_ingredient(ingredient_data, str(tenant_id))
if created_product and created_product.get('id'):
product_id = created_product['id']
uuid_id = UUID(str(product_id))
self.product_cache[product_name] = uuid_id # Cache for future use
logger.info("Created new inventory product for sales data",
product_name=product_name, product_id=product_id, tenant_id=tenant_id)
return uuid_id
except Exception as creation_error:
logger.warning("Product creation failed",
product_name=product_name, error=str(creation_error), tenant_id=tenant_id)
logger.warning("Failed to resolve or create product in inventory",
product_name=product_name, tenant_id=tenant_id, attempt=attempt)
except Exception as e:
error_str = str(e)
if "429" in error_str or "rate limit" in error_str.lower() or "too many requests" in error_str.lower():
logger.warning("Rate limit or service overload detected, retrying with longer delay",
product_name=product_name, attempt=attempt, error=error_str, tenant_id=tenant_id)
if attempt < max_retries - 1:
continue # Retry with exponential backoff
elif "503" in error_str or "502" in error_str or "service unavailable" in error_str.lower():
logger.warning("Service unavailable, retrying with backoff",
product_name=product_name, attempt=attempt, error=error_str, tenant_id=tenant_id)
if attempt < max_retries - 1:
continue # Retry for service unavailable errors
elif "timeout" in error_str.lower() or "connection" in error_str.lower():
logger.warning("Network issue detected, retrying",
product_name=product_name, attempt=attempt, error=error_str, tenant_id=tenant_id)
if attempt < max_retries - 1:
continue # Retry for network issues
else:
logger.error("Non-retryable error resolving product to inventory ID",
error=error_str, product_name=product_name, tenant_id=tenant_id)
if attempt < max_retries - 1:
# Still retry even for other errors, in case it's transient
continue
else:
break # Don't retry on final attempt
# If all retries failed, log detailed error but don't mark as permanently failed yet
# Instead, we'll implement a fallback mechanism
logger.error("Failed to resolve product after all retries, attempting fallback",
product_name=product_name, tenant_id=tenant_id)
# FALLBACK: Try to create a temporary product with minimal data
try:
# Use a simplified approach with minimal data
fallback_data = {
'name': product_name,
'type': 'finished_product',
'unit': 'unit',
'current_stock': 0,
'cost_per_unit': 0
}
logger.info("Attempting fallback product creation with minimal data",
product_name=product_name, tenant_id=tenant_id)
created_product = await self.inventory_client.create_ingredient(fallback_data, str(tenant_id))
if created_product and created_product.get('id'):
product_id = created_product['id']
uuid_id = UUID(str(product_id))
self.product_cache[product_name] = uuid_id
logger.info("SUCCESS: Fallback product creation succeeded",
product_name=product_name, product_id=product_id, tenant_id=tenant_id)
return uuid_id
except Exception as fallback_error:
logger.error("Fallback product creation also failed",
product_name=product_name, error=str(fallback_error), tenant_id=tenant_id)
# Only mark as permanently failed after all attempts including fallback
self.failed_products.add(product_name)
logger.error("CRITICAL: Permanently failed to resolve product - this will result in missing training data",
product_name=product_name, tenant_id=tenant_id)
return None
async def _batch_resolve_products(self, unique_products: set, tenant_id: str) -> None:
"""Batch resolve/create products to reduce API calls and improve success rate"""
if not unique_products:
return
logger.info(f"Starting batch product resolution for {len(unique_products)} unique products")
# Convert set to list for easier handling
products_list = list(unique_products)
batch_size = 5 # Process in smaller batches to avoid overwhelming the inventory service
for i in range(0, len(products_list), batch_size):
batch = products_list[i:i + batch_size]
logger.info(f"Processing batch {i//batch_size + 1}/{(len(products_list) + batch_size - 1)//batch_size}")
# Process each product in the batch with retry logic
for product_name, product_category in batch:
try:
# Skip if already in cache or failed list
if product_name in self.product_cache or product_name in self.failed_products:
continue
# Try to resolve the product
await self._resolve_product_to_inventory_id(product_name, product_category, tenant_id)
# Add small delay between products to be gentle on the API
await asyncio.sleep(0.5)
except Exception as e:
logger.warning(f"Failed to batch process product {product_name}: {e}")
continue
# Add delay between batches
if i + batch_size < len(products_list):
logger.info("Waiting between batches to avoid rate limiting...")
await asyncio.sleep(2.0)
successful_resolutions = len([p for p, _ in products_list if p in self.product_cache])
failed_resolutions = len([p for p, _ in products_list if p in self.failed_products])
logger.info(f"Batch product resolution completed: {successful_resolutions} successful, {failed_resolutions} failed")
if failed_resolutions > 0:
logger.warning(f"ATTENTION: {failed_resolutions} products failed to resolve - these will be missing from training data")
return
def _structure_messages(self, messages: List[Union[str, Dict]]) -> List[Dict[str, Any]]:
"""Convert string messages to structured format"""

View File

@@ -123,15 +123,35 @@ class InventoryServiceClient:
try:
result = await self._shared_client.create_ingredient(ingredient_data, tenant_id)
if result:
logger.info("Created ingredient in inventory service",
logger.info("Created ingredient in inventory service",
ingredient_name=ingredient_data.get('name'), tenant_id=tenant_id)
return result
except Exception as e:
logger.error("Error creating ingredient",
logger.error("Error creating ingredient",
error=str(e), ingredient_data=ingredient_data, tenant_id=tenant_id)
return None
async def resolve_or_create_products_batch(
self,
products: List[Dict[str, Any]],
tenant_id: str
) -> Optional[Dict[str, Any]]:
"""Resolve or create multiple products in a single batch operation"""
try:
result = await self._shared_client.resolve_or_create_products_batch(products, tenant_id)
if result:
logger.info("Batch product resolution complete",
created=result.get('created_count', 0),
resolved=result.get('resolved_count', 0),
tenant_id=tenant_id)
return result
except Exception as e:
logger.error("Error in batch product resolution",
error=str(e), products_count=len(products), tenant_id=tenant_id)
return None
# Dependency injection
async def get_inventory_client() -> InventoryServiceClient:
"""Get inventory service client instance"""

View File

@@ -165,14 +165,6 @@ async def start_training_job(
if metrics:
metrics.increment_counter("enhanced_training_jobs_created_total")
# Publish training.started event immediately so WebSocket clients
# have initial state when they connect
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=0 # Will be updated when actual training starts
)
# Calculate intelligent time estimate
# We don't know exact product count yet, so use historical average or estimate
try:
@@ -192,6 +184,19 @@ async def start_training_job(
error=str(est_error))
estimated_duration_minutes = 15 # Default fallback
# Calculate estimated completion time
estimated_completion_time = calculate_estimated_completion_time(estimated_duration_minutes)
# Publish training.started event immediately so WebSocket clients
# have initial state when they connect
await publish_training_started(
job_id=job_id,
tenant_id=tenant_id,
total_products=0, # Will be updated when actual training starts
estimated_duration_minutes=estimated_duration_minutes,
estimated_completion_time=estimated_completion_time.isoformat()
)
# Add enhanced background task
background_tasks.add_task(
execute_training_job_background,
@@ -362,15 +367,8 @@ async def execute_training_job_background(
requested_end=requested_end
)
# Update final status using repository pattern
await enhanced_training_service._update_job_status_repository(
job_id=job_id,
status="completed",
progress=100,
current_step="Enhanced training completed successfully",
results=result,
tenant_id=tenant_id
)
# Note: Final status is already updated by start_training_job() via complete_training_log()
# No need for redundant update here - it was causing duplicate log entries
# Completion event is published by the training service

View File

@@ -138,14 +138,14 @@ class DataClient:
self._fetch_sales_data_internal,
tenant_id, start_date, end_date, product_id, fetch_all
)
except CircuitBreakerError as e:
logger.error(f"Sales service circuit breaker open: {e}")
raise RuntimeError(f"Sales service unavailable: {str(e)}")
except CircuitBreakerError as exc:
logger.error("Sales service circuit breaker open", error_message=str(exc))
raise RuntimeError(f"Sales service unavailable: {str(exc)}")
except ValueError:
raise
except Exception as e:
logger.error(f"Error fetching sales data: {e}", tenant_id=tenant_id)
raise RuntimeError(f"Failed to fetch sales data: {str(e)}")
except Exception as exc:
logger.error("Error fetching sales data", tenant_id=tenant_id, error_message=str(exc))
raise RuntimeError(f"Failed to fetch sales data: {str(exc)}")
async def fetch_weather_data(
self,
@@ -176,8 +176,8 @@ class DataClient:
logger.warning("No weather data returned, will use synthetic data", tenant_id=tenant_id)
return []
except Exception as e:
logger.warning(f"Error fetching weather data, will use synthetic data: {e}", tenant_id=tenant_id)
except Exception as exc:
logger.warning("Error fetching weather data, will use synthetic data", tenant_id=tenant_id, error_message=str(exc))
return []
async def fetch_traffic_data_unified(
@@ -254,9 +254,9 @@ class DataClient:
logger.warning("No fresh traffic data available", tenant_id=tenant_id)
return []
except Exception as e:
logger.error(f"Error in unified traffic data fetch: {e}",
tenant_id=tenant_id, cache_key=cache_key)
except Exception as exc:
logger.error("Error in unified traffic data fetch",
tenant_id=tenant_id, cache_key=cache_key, error_message=str(exc))
return []
# Legacy methods for backward compatibility - now delegate to unified method
@@ -405,9 +405,9 @@ class DataClient:
return result
except Exception as e:
logger.error(f"Error validating data: {e}", tenant_id=tenant_id)
raise ValueError(f"Data validation failed: {str(e)}")
except Exception as exc:
logger.error("Error validating data", tenant_id=tenant_id, error_message=str(exc))
raise ValueError(f"Data validation failed: {str(exc)}")
# Global instance - same as before, but much simpler implementation
data_client = DataClient()

View File

@@ -6,8 +6,10 @@ Manages progress calculation for parallel product training (20-80% range)
import asyncio
import structlog
from typing import Optional
from datetime import datetime, timezone
from app.services.training_events import publish_product_training_completed
from app.utils.time_estimation import calculate_estimated_completion_time
logger = structlog.get_logger()
@@ -20,6 +22,7 @@ class ParallelProductProgressTracker:
- Each product completion contributes 60/N% to overall progress
- Progress range: 20% (after data analysis) to 80% (before completion)
- Thread-safe for concurrent product trainings
- Calculates time estimates based on elapsed time and progress
"""
def __init__(self, job_id: str, tenant_id: str, total_products: int):
@@ -28,6 +31,7 @@ class ParallelProductProgressTracker:
self.total_products = total_products
self.products_completed = 0
self._lock = asyncio.Lock()
self.start_time = datetime.now(timezone.utc)
# Calculate progress increment per product
# 60% of total progress (from 20% to 80%) divided by number of products
@@ -40,20 +44,40 @@ class ParallelProductProgressTracker:
async def mark_product_completed(self, product_name: str) -> int:
"""
Mark a product as completed and publish event.
Mark a product as completed and publish event with time estimates.
Returns the current overall progress percentage.
"""
async with self._lock:
self.products_completed += 1
current_progress = self.products_completed
# Publish product completion event
# Calculate time estimates based on elapsed time and progress
elapsed_seconds = (datetime.now(timezone.utc) - self.start_time).total_seconds()
products_remaining = self.total_products - current_progress
# Calculate estimated time remaining
# Avg time per product * remaining products
estimated_time_remaining_seconds = None
estimated_completion_time = None
if current_progress > 0 and products_remaining > 0:
avg_time_per_product = elapsed_seconds / current_progress
estimated_time_remaining_seconds = int(avg_time_per_product * products_remaining)
# Calculate estimated completion time
estimated_duration_minutes = estimated_time_remaining_seconds / 60
completion_datetime = calculate_estimated_completion_time(estimated_duration_minutes)
estimated_completion_time = completion_datetime.isoformat()
# Publish product completion event with time estimates
await publish_product_training_completed(
job_id=self.job_id,
tenant_id=self.tenant_id,
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products
total_products=self.total_products,
estimated_time_remaining_seconds=estimated_time_remaining_seconds,
estimated_completion_time=estimated_completion_time
)
# Calculate overall progress (20% base + progress from completed products)
@@ -65,7 +89,8 @@ class ParallelProductProgressTracker:
product_name=product_name,
products_completed=current_progress,
total_products=self.total_products,
overall_progress=overall_progress)
overall_progress=overall_progress,
estimated_time_remaining_seconds=estimated_time_remaining_seconds)
return overall_progress
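A worked example of the estimate, following the 20-80% scheme from the class docstring and using illustrative numbers (10 products, 4 completed after 120 seconds of elapsed training time):

elapsed_seconds = 120.0
completed, total = 4, 10
avg_time_per_product = elapsed_seconds / completed                                   # 30 s per product
estimated_time_remaining_seconds = int(avg_time_per_product * (total - completed))   # 180 s left
overall_progress = 20 + int(completed * (60 / total))                                # 20% base + 4 * 6% = 44%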

View File

@@ -91,7 +91,8 @@ async def publish_data_analysis(
job_id: str,
tenant_id: str,
analysis_details: Optional[str] = None,
estimated_time_remaining_seconds: Optional[int] = None
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 2: Data Analysis (20% progress)
@@ -101,6 +102,7 @@ async def publish_data_analysis(
tenant_id: Tenant identifier
analysis_details: Details about the analysis
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
@@ -112,7 +114,8 @@ async def publish_data_analysis(
"progress": 20,
"current_step": "Data Analysis",
"step_details": analysis_details or "Analyzing sales, weather, and traffic data",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}
@@ -138,7 +141,8 @@ async def publish_product_training_completed(
product_name: str,
products_completed: int,
total_products: int,
estimated_time_remaining_seconds: Optional[int] = None
estimated_time_remaining_seconds: Optional[int] = None,
estimated_completion_time: Optional[str] = None
) -> bool:
"""
Event 3: Product Training Completed (contributes to 20-80% progress)
@@ -154,6 +158,7 @@ async def publish_product_training_completed(
products_completed: Number of products completed so far
total_products: Total number of products
estimated_time_remaining_seconds: Estimated time remaining in seconds
estimated_completion_time: ISO timestamp of estimated completion
"""
event_data = {
"service_name": "training-service",
@@ -167,7 +172,8 @@ async def publish_product_training_completed(
"total_products": total_products,
"current_step": "Model Training",
"step_details": f"Completed training for {product_name} ({products_completed}/{total_products})",
"estimated_time_remaining_seconds": estimated_time_remaining_seconds
"estimated_time_remaining_seconds": estimated_time_remaining_seconds,
"estimated_completion_time": estimated_completion_time
}
}

View File

@@ -238,11 +238,19 @@ class EnhancedTrainingService:
)
# Step 4: Create performance metrics
await self.training_log_repo.update_log_progress(
job_id, 94, "storing_performance_metrics", "running"
)
await self._create_performance_metrics(
tenant_id, stored_models, training_results
)
# Step 4.5: Save training performance metrics for future estimations
await self._save_training_performance_metrics(
tenant_id, job_id, training_results, training_log
)
# Step 5: Complete training log
final_result = {
"job_id": job_id,
@@ -426,7 +434,7 @@ class EnhancedTrainingService:
model_result = training_results.get("models_trained", {}).get(str(model.inventory_product_id))
if model_result and model_result.get("metrics"):
metrics = model_result["metrics"]
metric_data = {
"model_id": str(model.id),
"tenant_id": tenant_id,
@@ -439,13 +447,84 @@ class EnhancedTrainingService:
"accuracy_percentage": metrics.get("accuracy_percentage", 100 - metrics.get("mape", 0)),
"evaluation_samples": model.training_samples
}
await self.performance_repo.create_performance_metric(metric_data)
except Exception as e:
logger.error("Failed to create performance metrics",
tenant_id=tenant_id,
error=str(e))
async def _save_training_performance_metrics(
self,
tenant_id: str,
job_id: str,
training_results: Dict[str, Any],
training_log
):
"""
Save aggregated training performance metrics for time estimation.
This data is used to predict future training durations.
"""
try:
from app.models.training import TrainingPerformanceMetrics
# Extract timing and success data
models_trained = training_results.get("models_trained", {})
total_products = len(models_trained)
successful_products = sum(1 for m in models_trained.values() if m.get("status") == "completed")
failed_products = total_products - successful_products
# Calculate total duration
if training_log.start_time and training_log.end_time:
total_duration_seconds = (training_log.end_time - training_log.start_time).total_seconds()
else:
# Fallback to elapsed time
total_duration_seconds = training_results.get("total_training_time", 0)
# Calculate average time per product
if successful_products > 0:
avg_time_per_product = total_duration_seconds / successful_products
else:
avg_time_per_product = 0
# Extract timing breakdown if available
data_analysis_time = training_results.get("data_analysis_time_seconds")
training_time = training_results.get("training_time_seconds")
finalization_time = training_results.get("finalization_time_seconds")
# Create performance metrics record
metric_data = {
"tenant_id": tenant_id,
"job_id": job_id,
"total_products": total_products,
"successful_products": successful_products,
"failed_products": failed_products,
"total_duration_seconds": total_duration_seconds,
"avg_time_per_product": avg_time_per_product,
"data_analysis_time_seconds": data_analysis_time,
"training_time_seconds": training_time,
"finalization_time_seconds": finalization_time,
"completed_at": datetime.now(timezone.utc)
}
# Persist the aggregated metrics record directly on the session
performance_metrics = TrainingPerformanceMetrics(**metric_data)
self.session.add(performance_metrics)
await self.session.commit()
logger.info("Saved training performance metrics for future estimations",
tenant_id=tenant_id,
job_id=job_id,
avg_time_per_product=avg_time_per_product,
total_products=total_products,
successful_products=successful_products)
except Exception as e:
logger.error("Failed to save training performance metrics",
tenant_id=tenant_id,
job_id=job_id,
error=str(e))
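These rows are presumably what the time-estimation step in start_training_job averages over. A sketch of such a lookup, where the query shape and session handling are assumptions rather than code from this commit:

from sqlalchemy import select, func

async def estimate_duration_minutes(session, tenant_id: str, product_count: int) -> float:
    # Average historical per-product training time for this tenant
    stmt = select(func.avg(TrainingPerformanceMetrics.avg_time_per_product)).where(
        TrainingPerformanceMetrics.tenant_id == tenant_id
    )
    avg_seconds = (await session.execute(stmt)).scalar()
    if not avg_seconds:
        return 15.0  # same default fallback used by start_training_job
    return (avg_seconds * product_count) / 60.0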
async def get_training_status(self, job_id: str) -> Dict[str, Any]:
"""Get training job status using repository"""