REFACTOR data service
@@ -1,40 +0,0 @@
# Add this stage at the top of each service Dockerfile
FROM python:3.11-slim as shared
WORKDIR /shared
COPY shared/ /shared/

# Then your main service stage
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements
COPY services/data/requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy shared libraries from the shared stage
COPY --from=shared /shared /app/shared

# Copy application code
COPY services/data/ .

# Add shared libraries to Python path
ENV PYTHONPATH="/app:/app/shared:$PYTHONPATH"

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run application
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
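Note: the COPY paths above (shared/ and services/data/) assume the image is built with the repository root as the build context, e.g. `docker build -f services/data/Dockerfile -t data-service .` run from the repo root. The Dockerfile location and image tag here are assumptions for illustration, not part of this commit.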
@@ -1,14 +0,0 @@
"""
Data Service API Layer
API endpoints for data operations
"""

from .sales import router as sales_router
from .traffic import router as traffic_router
from .weather import router as weather_router

__all__ = [
    "sales_router",
    "traffic_router",
    "weather_router"
]
@@ -1,500 +0,0 @@
"""
Enhanced Sales API Endpoints
Updated to use repository pattern and enhanced services with dependency injection
"""

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query, Response, Path
from fastapi.responses import StreamingResponse
from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog

from app.schemas.sales import (
    SalesDataCreate,
    SalesDataResponse,
    SalesDataQuery,
    SalesDataImport,
    SalesImportResult,
    SalesValidationResult,
    SalesValidationRequest,
    SalesExportRequest
)
from app.services.sales_service import SalesService
from app.services.data_import_service import EnhancedDataImportService
from app.services.messaging import (
    publish_sales_created,
    publish_data_imported,
    publish_export_completed
)
from shared.database.base import create_database_manager
from shared.auth.decorators import get_current_user_dep

router = APIRouter(tags=["enhanced-sales"])
logger = structlog.get_logger()


def get_sales_service():
    """Dependency injection for SalesService"""
    from app.core.config import settings
    database_manager = create_database_manager(settings.DATABASE_URL, "data-service")
    return SalesService(database_manager)


def get_import_service():
    """Dependency injection for EnhancedDataImportService"""
    from app.core.config import settings
    database_manager = create_database_manager(settings.DATABASE_URL, "data-service")
    return EnhancedDataImportService(database_manager)


@router.post("/tenants/{tenant_id}/sales", response_model=SalesDataResponse)
async def create_sales_record(
    sales_data: SalesDataCreate,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Create a new sales record using repository pattern"""
    try:
        logger.info("Creating sales record with repository pattern",
                    product=sales_data.product_name,
                    quantity=sales_data.quantity_sold,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        # Override tenant_id from URL path
        sales_data.tenant_id = tenant_id

        record = await sales_service.create_sales_record(sales_data, str(tenant_id))

        # Publish event (non-blocking)
        try:
            await publish_sales_created({
                "tenant_id": str(tenant_id),
                "product_name": sales_data.product_name,
                "quantity_sold": sales_data.quantity_sold,
                "revenue": sales_data.revenue,
                "source": sales_data.source,
                "created_by": current_user["user_id"],
                "timestamp": datetime.utcnow().isoformat()
            })
        except Exception as pub_error:
            logger.warning("Failed to publish sales created event", error=str(pub_error))

        logger.info("Successfully created sales record using repository",
                    record_id=record.id,
                    tenant_id=tenant_id)
        return record

    except Exception as e:
        logger.error("Failed to create sales record",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to create sales record: {str(e)}")


@router.get("/tenants/{tenant_id}/sales", response_model=List[SalesDataResponse])
async def get_sales_data(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Start date filter"),
    end_date: Optional[datetime] = Query(None, description="End date filter"),
    product_name: Optional[str] = Query(None, description="Product name filter"),
    limit: Optional[int] = Query(1000, le=5000, description="Maximum number of records to return"),
    offset: Optional[int] = Query(0, ge=0, description="Number of records to skip"),
    product_names: Optional[List[str]] = Query(None, description="Multiple product name filters"),
    location_ids: Optional[List[str]] = Query(None, description="Location ID filters"),
    sources: Optional[List[str]] = Query(None, description="Source filters"),
    min_quantity: Optional[int] = Query(None, description="Minimum quantity filter"),
    max_quantity: Optional[int] = Query(None, description="Maximum quantity filter"),
    min_revenue: Optional[float] = Query(None, description="Minimum revenue filter"),
    max_revenue: Optional[float] = Query(None, description="Maximum revenue filter"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Get sales data using repository pattern with enhanced filtering"""
    try:
        logger.debug("Querying sales data with repository pattern",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date,
                     limit=limit,
                     offset=offset)

        # Create enhanced query
        query = SalesDataQuery(
            tenant_id=tenant_id,
            start_date=start_date,
            end_date=end_date,
            product_names=[product_name] if product_name else product_names,
            location_ids=location_ids,
            sources=sources,
            min_quantity=min_quantity,
            max_quantity=max_quantity,
            min_revenue=min_revenue,
            max_revenue=max_revenue,
            limit=limit,
            offset=offset
        )

        records = await sales_service.get_sales_data(query)

        logger.debug("Successfully retrieved sales data using repository",
                     count=len(records),
                     tenant_id=tenant_id)
        return records

    except Exception as e:
        logger.error("Failed to query sales data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to query sales data: {str(e)}")


@router.get("/tenants/{tenant_id}/sales/analytics")
async def get_sales_analytics(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Start date"),
    end_date: Optional[datetime] = Query(None, description="End date"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Get sales analytics using repository pattern"""
    try:
        logger.debug("Getting sales analytics with repository pattern",
                     tenant_id=tenant_id,
                     start_date=start_date,
                     end_date=end_date)

        analytics = await sales_service.get_sales_analytics(
            str(tenant_id), start_date, end_date
        )

        logger.debug("Analytics generated successfully using repository", tenant_id=tenant_id)
        return analytics

    except Exception as e:
        logger.error("Failed to generate sales analytics",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to generate analytics: {str(e)}")


@router.get("/tenants/{tenant_id}/sales/aggregation")
async def get_sales_aggregation(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    start_date: Optional[datetime] = Query(None, description="Start date"),
    end_date: Optional[datetime] = Query(None, description="End date"),
    group_by: str = Query("daily", description="Aggregation period: daily, weekly, monthly"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Get sales aggregation data using repository pattern"""
    try:
        logger.debug("Getting sales aggregation with repository pattern",
                     tenant_id=tenant_id,
                     group_by=group_by)

        aggregation = await sales_service.get_sales_aggregation(
            str(tenant_id), start_date, end_date, group_by
        )

        logger.debug("Aggregation generated successfully using repository",
                     tenant_id=tenant_id,
                     group_by=group_by)
        return aggregation

    except Exception as e:
        logger.error("Failed to get sales aggregation",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get aggregation: {str(e)}")


@router.post("/tenants/{tenant_id}/sales/import", response_model=SalesImportResult)
async def import_sales_data(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    file: UploadFile = File(...),
    file_format: str = Form(...),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    import_service: EnhancedDataImportService = Depends(get_import_service)
):
    """Import sales data using enhanced repository pattern"""
    try:
        logger.info("Importing sales data with enhanced repository pattern",
                    tenant_id=tenant_id,
                    format=file_format,
                    filename=file.filename,
                    user_id=current_user["user_id"])

        # Read file content
        content = await file.read()
        file_content = content.decode('utf-8')

        # Process using enhanced import service
        result = await import_service.process_import(
            str(tenant_id),
            file_content,
            file_format,
            filename=file.filename
        )

        if result.success:
            # Publish event
            try:
                await publish_data_imported({
                    "tenant_id": str(tenant_id),
                    "type": "file_import",
                    "format": file_format,
                    "filename": file.filename,
                    "records_created": result.records_created,
                    "imported_by": current_user["user_id"],
                    "timestamp": datetime.utcnow().isoformat()
                })
            except Exception as pub_error:
                logger.warning("Failed to publish import event", error=str(pub_error))

        logger.info("Import completed with enhanced repository pattern",
                    success=result.success,
                    records_created=result.records_created,
                    tenant_id=tenant_id)
        return result

    except Exception as e:
        logger.error("Failed to import sales data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to import sales data: {str(e)}")


@router.post("/tenants/{tenant_id}/sales/import/validate", response_model=SalesValidationResult)
async def validate_import_data(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    file: UploadFile = File(..., description="File to validate"),
    file_format: str = Form(default="csv", description="File format: csv, json, excel"),
    validate_only: bool = Form(default=True, description="Only validate, don't import"),
    source: str = Form(default="onboarding_upload", description="Source of the upload"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    import_service: EnhancedDataImportService = Depends(get_import_service)
):
    """Validate import data using enhanced repository pattern"""
    try:
        logger.info("Validating import data with enhanced repository pattern",
                    tenant_id=tenant_id,
                    format=file_format,
                    filename=file.filename,
                    user_id=current_user["user_id"])

        # Read file content
        content = await file.read()
        file_content = content.decode('utf-8')

        # Create validation data structure
        validation_data = {
            "tenant_id": str(tenant_id),
            "data": file_content,
            "data_format": file_format,
            "source": source,
            "validate_only": validate_only
        }

        # Use enhanced validation service
        validation_result = await import_service.validate_import_data(validation_data)

        logger.info("Validation completed with enhanced repository pattern",
                    is_valid=validation_result.is_valid,
                    total_records=validation_result.total_records,
                    tenant_id=tenant_id)

        return validation_result

    except Exception as e:
        logger.error("Failed to validate import data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to validate import data: {str(e)}")


@router.post("/tenants/{tenant_id}/sales/import/validate-json", response_model=SalesValidationResult)
async def validate_import_data_json(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    request: SalesValidationRequest = ...,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    import_service: EnhancedDataImportService = Depends(get_import_service)
):
    """Validate import data from JSON request for onboarding flow"""
    try:
        logger.info("Starting JSON-based data validation",
                    tenant_id=str(tenant_id),
                    data_format=request.data_format,
                    data_length=len(request.data),
                    validate_only=request.validate_only)

        # Create validation data structure
        validation_data = {
            "tenant_id": str(tenant_id),
            "data": request.data,  # Fixed: use 'data' not 'content'
            "data_format": request.data_format,
            "filename": f"onboarding_data.{request.data_format}",
            "source": request.source,
            "validate_only": request.validate_only
        }

        # Use enhanced validation service
        validation_result = await import_service.validate_import_data(validation_data)

        logger.info("JSON validation completed",
                    is_valid=validation_result.is_valid,
                    total_records=validation_result.total_records,
                    tenant_id=tenant_id)

        return validation_result

    except Exception as e:
        logger.error("Failed to validate JSON import data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to validate import data: {str(e)}")


@router.post("/tenants/{tenant_id}/sales/export")
async def export_sales_data(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    export_format: str = Query("csv", description="Export format: csv, excel, json"),
    start_date: Optional[datetime] = Query(None, description="Start date"),
    end_date: Optional[datetime] = Query(None, description="End date"),
    products: Optional[List[str]] = Query(None, description="Filter by products"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Export sales data using repository pattern"""
    try:
        logger.info("Exporting sales data with repository pattern",
                    tenant_id=tenant_id,
                    format=export_format,
                    user_id=current_user["user_id"])

        export_result = await sales_service.export_sales_data(
            str(tenant_id), export_format, start_date, end_date, products
        )

        if not export_result:
            raise HTTPException(status_code=404, detail="No data found for export")

        # Publish export event
        try:
            await publish_export_completed({
                "tenant_id": str(tenant_id),
                "format": export_format,
                "exported_by": current_user["user_id"],
                "record_count": export_result.get("record_count", 0),
                "timestamp": datetime.utcnow().isoformat()
            })
        except Exception as pub_error:
            logger.warning("Failed to publish export event", error=str(pub_error))

        logger.info("Export completed successfully using repository",
                    tenant_id=tenant_id,
                    format=export_format)

        return StreamingResponse(
            iter([export_result["content"]]),
            media_type=export_result["media_type"],
            headers={"Content-Disposition": f"attachment; filename={export_result['filename']}"}
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to export sales data",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to export sales data: {str(e)}")


@router.delete("/tenants/{tenant_id}/sales/{record_id}")
async def delete_sales_record(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    record_id: str = Path(..., description="Sales record ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Delete a sales record using repository pattern"""
    try:
        logger.info("Deleting sales record with repository pattern",
                    record_id=record_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"])

        success = await sales_service.delete_sales_record(record_id, str(tenant_id))

        if not success:
            raise HTTPException(status_code=404, detail="Sales record not found")

        logger.info("Sales record deleted successfully using repository",
                    record_id=record_id,
                    tenant_id=tenant_id)
        return {"status": "success", "message": "Sales record deleted successfully"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to delete sales record",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}")


@router.get("/tenants/{tenant_id}/sales/products")
async def get_products_list(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Get list of products using repository pattern"""
    try:
        logger.debug("Getting products list with repository pattern", tenant_id=tenant_id)

        products = await sales_service.get_products_list(str(tenant_id))

        logger.debug("Products list retrieved using repository",
                     count=len(products),
                     tenant_id=tenant_id)
        return products

    except Exception as e:
        logger.error("Failed to get products list",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get products list: {str(e)}")


@router.get("/tenants/{tenant_id}/sales/statistics")
async def get_sales_statistics(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    sales_service: SalesService = Depends(get_sales_service)
):
    """Get comprehensive sales statistics using repository pattern"""
    try:
        logger.debug("Getting sales statistics with repository pattern", tenant_id=tenant_id)

        # Get analytics which includes comprehensive statistics
        analytics = await sales_service.get_sales_analytics(str(tenant_id))

        # Create enhanced statistics response
        statistics = {
            "tenant_id": str(tenant_id),
            "analytics": analytics,
            "generated_at": datetime.utcnow().isoformat()
        }

        logger.debug("Sales statistics retrieved using repository", tenant_id=tenant_id)
        return statistics

    except Exception as e:
        logger.error("Failed to get sales statistics",
                     error=str(e),
                     tenant_id=tenant_id)
        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
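Because every endpoint above receives its service through Depends(get_sales_service), the router can be exercised without a real database by overriding that dependency. A minimal sketch, assuming a hypothetical test module; the app.api.sales import path, the StubSalesService class, and the auth handling are assumptions for illustration, not part of this commit:

# Hedged sketch: driving the router with a stubbed service (test-only).
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.api.sales import router, get_sales_service  # module path assumed

class StubSalesService:
    async def get_products_list(self, tenant_id: str):
        return ["espresso", "latte"]  # canned data instead of repository calls

app = FastAPI()
app.include_router(router)
# Swap the real dependency for the stub; get_current_user_dep would need a
# similar override before sending requests, which is omitted here.
app.dependency_overrides[get_sales_service] = lambda: StubSalesService()

client = TestClient(app)
# client.get("/tenants/<tenant-uuid>/sales/products") would then hit the stub.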
@@ -1,196 +0,0 @@
"""
Database configuration for data service
Uses shared database infrastructure for consistency
"""

import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text

from shared.database.base import DatabaseManager, Base
from app.core.config import settings

logger = structlog.get_logger()

# Initialize database manager using shared infrastructure
database_manager = DatabaseManager(
    database_url=settings.DATABASE_URL,
    service_name="data",
    pool_size=15,
    max_overflow=25,
    echo=settings.DEBUG if hasattr(settings, 'DEBUG') else False
)

# Alias for convenience - matches the existing interface
get_db = database_manager.get_db

# Use the shared background session method
get_background_db_session = database_manager.get_background_session


async def get_db_health() -> bool:
    """Health check function for database connectivity"""
    try:
        async with database_manager.async_engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
        logger.debug("Database health check passed")
        return True

    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        return False


async def init_db():
    """Initialize database tables using shared infrastructure"""
    try:
        logger.info("Initializing data service database")

        # Import models to ensure they're registered
        from app.models.sales import SalesData
        from app.models.traffic import TrafficData
        from app.models.weather import WeatherData

        # Create tables using shared infrastructure
        await database_manager.create_tables()

        logger.info("Data service database initialized successfully")

    except Exception as e:
        logger.error("Failed to initialize data service database", error=str(e))
        raise


# Data service specific database utilities
class DataDatabaseUtils:
    """Data service specific database utilities"""

    @staticmethod
    async def cleanup_old_sales_data(days_old: int = 730):
        """Clean up old sales data (default 2 years)"""
        try:
            async with database_manager.get_background_session() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM sales_data "
                        "WHERE created_at < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    query = text(
                        "DELETE FROM sales_data "
                        "WHERE created_at < NOW() - INTERVAL :days_param"
                    )
                    params = {"days_param": f"{days_old} days"}

                result = await session.execute(query, params)
                deleted_count = result.rowcount

                logger.info("Cleaned up old sales data",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old sales data", error=str(e))
            return 0

    @staticmethod
    async def get_data_statistics(tenant_id: str = None) -> dict:
        """Get data service statistics"""
        try:
            async with database_manager.get_background_session() as session:
                # Get sales data statistics
                if tenant_id:
                    sales_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM sales_data "
                        "WHERE tenant_id = :tenant_id"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    sales_query = text("SELECT COUNT(*) as count FROM sales_data")
                    params = {}

                sales_result = await session.execute(sales_query, params)
                sales_count = sales_result.scalar() or 0

                # Get traffic data statistics (if the table exists)
                try:
                    traffic_query = text("SELECT COUNT(*) as count FROM traffic_data")
                    if tenant_id:
                        # Traffic data might not have tenant_id; check table structure
                        pass

                    traffic_result = await session.execute(traffic_query)
                    traffic_count = traffic_result.scalar() or 0
                except Exception:
                    traffic_count = 0

                # Get weather data statistics (if the table exists)
                try:
                    weather_query = text("SELECT COUNT(*) as count FROM weather_data")
                    weather_result = await session.execute(weather_query)
                    weather_count = weather_result.scalar() or 0
                except Exception:
                    weather_count = 0

                return {
                    "tenant_id": tenant_id,
                    "sales_records": sales_count,
                    "traffic_records": traffic_count,
                    "weather_records": weather_count,
                    "total_records": sales_count + traffic_count + weather_count
                }

        except Exception as e:
            logger.error("Failed to get data statistics", error=str(e))
            return {
                "tenant_id": tenant_id,
                "sales_records": 0,
                "traffic_records": 0,
                "weather_records": 0,
                "total_records": 0
            }


# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Enhanced database session dependency with better logging and error handling"""
    async with database_manager.async_session_local() as session:
        try:
            logger.debug("Database session created")
            yield session
        except Exception as e:
            logger.error("Database session error", error=str(e), exc_info=True)
            await session.rollback()
            raise
        finally:
            await session.close()
            logger.debug("Database session closed")


# Database cleanup for data service
async def cleanup_data_database():
    """Cleanup database connections for data service"""
    try:
        logger.info("Cleaning up data service database connections")

        # Close engine connections
        if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
            await database_manager.async_engine.dispose()

        logger.info("Data service database cleanup completed")

    except Exception as e:
        logger.error("Failed to cleanup data service database", error=str(e))


# Export the commonly used items to maintain compatibility
__all__ = [
    'Base',
    'database_manager',
    'get_db',
    'get_background_db_session',
    'get_db_session',
    'get_db_health',
    'DataDatabaseUtils',
    'init_db',
    'cleanup_data_database'
]
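One way to tie init_db, cleanup_data_database, and get_db_health into the service lifecycle is a FastAPI lifespan hook. A sketch under that assumption, in a hypothetical app/main.py; the /health route shape is illustrative (it matches the Dockerfile HEALTHCHECK above, which curls /health on port 8000), not taken from this commit:

# Hedged sketch: wiring the database helpers into app startup/shutdown.
from contextlib import asynccontextmanager
from fastapi import FastAPI

from app.core.database import init_db, cleanup_data_database, get_db_health

@asynccontextmanager
async def lifespan(app: FastAPI):
    await init_db()                    # create tables via the shared DatabaseManager
    yield
    await cleanup_data_database()      # dispose engine connections on shutdown

app = FastAPI(lifespan=lifespan)

@app.get("/health")
async def health():
    return {"database": "up" if await get_db_health() else "down"}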
@@ -1,312 +0,0 @@
# ================================================================
# services/data/app/core/performance.py
# ================================================================
"""
Performance optimization utilities for async operations
"""

import asyncio
import functools
from typing import Any, Callable, Dict, Optional, TypeVar
from datetime import datetime, timedelta, timezone
import hashlib
import json
import structlog

logger = structlog.get_logger()

T = TypeVar('T')


class AsyncCache:
    """Simple in-memory async cache with TTL"""

    def __init__(self, default_ttl: int = 300):
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.default_ttl = default_ttl

    def _generate_key(self, *args, **kwargs) -> str:
        """Generate cache key from arguments"""
        key_data = {
            'args': args,
            'kwargs': sorted(kwargs.items())
        }
        key_string = json.dumps(key_data, sort_keys=True, default=str)
        return hashlib.md5(key_string.encode()).hexdigest()

    def _is_expired(self, entry: Dict[str, Any]) -> bool:
        """Check if cache entry is expired"""
        expires_at = entry.get('expires_at')
        if not expires_at:
            return True
        return datetime.now(timezone.utc) > expires_at

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache"""
        if key in self.cache:
            entry = self.cache[key]
            if not self._is_expired(entry):
                logger.debug("Cache hit", cache_key=key)
                return entry['value']
            else:
                # Clean up expired entry
                del self.cache[key]
                logger.debug("Cache expired", cache_key=key)

        logger.debug("Cache miss", cache_key=key)
        return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set value in cache"""
        ttl = ttl or self.default_ttl
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)

        self.cache[key] = {
            'value': value,
            'expires_at': expires_at,
            'created_at': datetime.now(timezone.utc)
        }

        logger.debug("Cache set", cache_key=key, ttl=ttl)

    async def clear(self) -> None:
        """Clear all cache entries"""
        self.cache.clear()
        logger.info("Cache cleared")

    async def cleanup_expired(self) -> int:
        """Clean up expired entries"""
        expired_keys = [
            key for key, entry in self.cache.items()
            if self._is_expired(entry)
        ]

        for key in expired_keys:
            del self.cache[key]

        if expired_keys:
            logger.info("Cleaned up expired cache entries", count=len(expired_keys))

        return len(expired_keys)


def async_cache(ttl: int = 300, cache_instance: Optional[AsyncCache] = None):
    """Decorator for caching async function results"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _cache = cache_instance or AsyncCache(ttl)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Generate cache key
            cache_key = _cache._generate_key(func.__name__, *args, **kwargs)

            # Try to get from cache
            cached_result = await _cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Execute function and cache result
            result = await func(*args, **kwargs)
            await _cache.set(cache_key, result, ttl)

            return result

        # Add cache management methods
        wrapper.cache_clear = _cache.clear
        wrapper.cache_cleanup = _cache.cleanup_expired

        return wrapper

    return decorator


class ConnectionPool:
    """Simple connection pool for HTTP clients"""

    def __init__(self, max_connections: int = 10):
        self.max_connections = max_connections
        self.semaphore = asyncio.Semaphore(max_connections)
        self._active_connections = 0

    async def acquire(self):
        """Acquire a connection slot"""
        await self.semaphore.acquire()
        self._active_connections += 1
        logger.debug("Connection acquired", active=self._active_connections, max=self.max_connections)

    async def release(self):
        """Release a connection slot"""
        self.semaphore.release()
        self._active_connections = max(0, self._active_connections - 1)
        logger.debug("Connection released", active=self._active_connections, max=self.max_connections)

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()


def rate_limit(calls: int, period: int):
    """Rate limiting decorator"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        call_times = []
        lock = asyncio.Lock()

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            async with lock:
                now = datetime.now(timezone.utc)

                # Remove old call times
                cutoff = now - timedelta(seconds=period)
                call_times[:] = [t for t in call_times if t > cutoff]

                # Check rate limit
                if len(call_times) >= calls:
                    sleep_time = (call_times[0] + timedelta(seconds=period) - now).total_seconds()
                    if sleep_time > 0:
                        logger.warning("Rate limit reached, sleeping", sleep_time=sleep_time)
                        await asyncio.sleep(sleep_time)

                # Record this call
                call_times.append(now)

            return await func(*args, **kwargs)

        return wrapper

    return decorator


async def batch_process(
    items: list,
    process_func: Callable,
    batch_size: int = 10,
    max_concurrency: int = 5
) -> list:
    """Process items in batches with controlled concurrency"""

    results = []
    semaphore = asyncio.Semaphore(max_concurrency)

    async def process_batch(batch):
        async with semaphore:
            return await process_func(batch)

    # Create batches
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    logger.info("Processing items in batches",
                total_items=len(items),
                batches=len(batches),
                batch_size=batch_size,
                max_concurrency=max_concurrency)

    # Process batches concurrently
    batch_results = await asyncio.gather(
        *[process_batch(batch) for batch in batches],
        return_exceptions=True
    )

    # Flatten results
    for batch_result in batch_results:
        if isinstance(batch_result, Exception):
            logger.error("Batch processing error", error=str(batch_result))
            continue

        if isinstance(batch_result, list):
            results.extend(batch_result)
        else:
            results.append(batch_result)

    logger.info("Batch processing completed",
                processed_items=len(results),
                total_batches=len(batches))

    return results


class PerformanceMonitor:
    """Simple performance monitoring for async functions"""

    def __init__(self):
        self.metrics = {}

    def record_execution(self, func_name: str, duration: float, success: bool = True):
        """Record function execution metrics"""
        if func_name not in self.metrics:
            self.metrics[func_name] = {
                'call_count': 0,
                'success_count': 0,
                'error_count': 0,
                'total_duration': 0.0,
                'min_duration': float('inf'),
                'max_duration': 0.0
            }

        metric = self.metrics[func_name]
        metric['call_count'] += 1
        metric['total_duration'] += duration
        metric['min_duration'] = min(metric['min_duration'], duration)
        metric['max_duration'] = max(metric['max_duration'], duration)

        if success:
            metric['success_count'] += 1
        else:
            metric['error_count'] += 1

    def get_metrics(self, func_name: str = None) -> dict:
        """Get performance metrics"""
        if func_name:
            metric = self.metrics.get(func_name, {})
            if metric and metric['call_count'] > 0:
                metric['avg_duration'] = metric['total_duration'] / metric['call_count']
                metric['success_rate'] = metric['success_count'] / metric['call_count']
            return metric

        return self.metrics


def monitor_performance(monitor: Optional[PerformanceMonitor] = None):
    """Decorator to monitor function performance"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _monitor = monitor or PerformanceMonitor()

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = datetime.now(timezone.utc)
            success = True

            try:
                result = await func(*args, **kwargs)
                return result
            except Exception:
                success = False
                raise
            finally:
                end_time = datetime.now(timezone.utc)
                duration = (end_time - start_time).total_seconds()
                _monitor.record_execution(func.__name__, duration, success)

                logger.debug("Function performance",
                             function=func.__name__,
                             duration=duration,
                             success=success)

        # Add metrics access
        wrapper.get_metrics = lambda: _monitor.get_metrics(func.__name__)

        return wrapper

    return decorator


# Global instances
global_cache = AsyncCache(default_ttl=300)
global_connection_pool = ConnectionPool(max_connections=20)
global_performance_monitor = PerformanceMonitor()
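A short usage sketch for the helpers above. Only the decorators, batch_process, and global_performance_monitor come from this module (import path assumed from the header comment); fetch_weather and its data are made-up placeholders:

# Hedged sketch: combining the caching, rate-limiting and monitoring helpers.
import asyncio

from app.core.performance import (
    async_cache, rate_limit, monitor_performance, batch_process,
    global_performance_monitor,
)

@monitor_performance(global_performance_monitor)
@rate_limit(calls=10, period=60)     # at most 10 calls per minute
@async_cache(ttl=120)                # repeated args are served from cache
async def fetch_weather(city: str) -> dict:   # placeholder workload
    await asyncio.sleep(0.1)         # stand-in for an HTTP call
    return {"city": city, "temp_c": 21}

async def main():
    cities = ["Berlin", "Paris", "Oslo", "Berlin"]
    # batch_process awaits process_func(batch), so hand it a gather per batch.
    results = await batch_process(
        cities,
        lambda batch: asyncio.gather(*[fetch_weather(c) for c in batch]),
        batch_size=2,
    )
    print(results)
    print(global_performance_monitor.get_metrics("fetch_weather"))

# asyncio.run(main())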
services/data/app/external/__init__.py (vendored, 0 changes)
@@ -1,34 +0,0 @@
# ================================================================
# services/data/app/models/sales.py - MISSING FILE
# ================================================================
"""Sales data models"""

from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timezone

from app.core.database import Base


class SalesData(Base):
    __tablename__ = "sales_data"

    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    date = Column(DateTime(timezone=True), nullable=False, index=True)
    product_name = Column(String(255), nullable=False, index=True)
    quantity_sold = Column(Integer, nullable=False)
    revenue = Column(Float, nullable=False)
    location_id = Column(String(100), nullable=True, index=True)
    source = Column(String(50), nullable=False, default="manual")
    notes = Column(Text, nullable=True)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True),
                        default=lambda: datetime.now(timezone.utc),
                        onupdate=lambda: datetime.now(timezone.utc))

    __table_args__ = (
        Index('idx_sales_tenant_date', 'tenant_id', 'date'),
        Index('idx_sales_tenant_product', 'tenant_id', 'product_name'),
        Index('idx_sales_tenant_location', 'tenant_id', 'location_id'),
    )
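A minimal sketch of persisting one row of this model through the background-session helper from the database module above; the field values are illustrative, and whether the shared get_background_session commits automatically is an assumption, so the commit is made explicit here:

# Hedged sketch: inserting a SalesData row via the shared session helper.
from datetime import datetime, timezone

from app.core.database import database_manager
from app.models.sales import SalesData

async def add_demo_row(tenant_id):
    async with database_manager.get_background_session() as session:
        session.add(SalesData(
            tenant_id=tenant_id,              # UUID of the owning tenant
            date=datetime.now(timezone.utc),  # timezone-aware, matching the column
            product_name="espresso",
            quantity_sold=3,
            revenue=10.5,
            source="manual",
        ))
        await session.commit()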
@@ -1,12 +0,0 @@
"""
Data Service Repositories
Repository implementations for data service
"""

from .base import DataBaseRepository
from .sales_repository import SalesRepository

__all__ = [
    "DataBaseRepository",
    "SalesRepository"
]
@@ -1,167 +0,0 @@
"""
Base Repository for Data Service
Service-specific repository base class with data service utilities
"""

from typing import Optional, List, Dict, Any, Type, TypeVar, Generic
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timezone
import structlog

from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()

# Type variables for the data service repository
Model = TypeVar('Model')
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')


class DataBaseRepository(BaseRepository[Model, CreateSchema, UpdateSchema], Generic[Model, CreateSchema, UpdateSchema]):
    """Base repository for data service with common data operations"""

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records filtered by tenant_id"""
        return await self.get_multi(
            skip=skip,
            limit=limit,
            filters={"tenant_id": tenant_id}
        )

    async def get_by_date_range(
        self,
        tenant_id: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records filtered by tenant and date range"""
        try:
            filters = {"tenant_id": tenant_id}

            # Build date range filter
            if start_date or end_date:
                if not hasattr(self.model, 'date'):
                    raise ValidationError("Model does not have 'date' field for date filtering")

                # This would need a more complex implementation for date ranges
                # For now, we'll use the basic filter
                if start_date and end_date:
                    # Would need custom query building for date ranges
                    pass

            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters=filters,
                order_by="date",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get records by date range",
                         tenant_id=tenant_id,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")

    async def count_by_tenant(self, tenant_id: str) -> int:
        """Count records for a specific tenant"""
        return await self.count(filters={"tenant_id": tenant_id})

    async def validate_tenant_access(self, tenant_id: str, record_id: Any) -> bool:
        """Validate that a record belongs to the specified tenant"""
        try:
            record = await self.get_by_id(record_id)
            if not record:
                return False

            # Check if record has tenant_id field and matches
            if hasattr(record, 'tenant_id'):
                return str(record.tenant_id) == str(tenant_id)

            return True  # If no tenant_id field, allow access

        except Exception as e:
            logger.error("Failed to validate tenant access",
                         tenant_id=tenant_id,
                         record_id=record_id,
                         error=str(e))
            return False

    async def get_tenant_stats(self, tenant_id: str) -> Dict[str, Any]:
        """Get statistics for a specific tenant"""
        try:
            total_records = await self.count_by_tenant(tenant_id)

            # Get recent activity (if model has created_at)
            recent_records = 0
            if hasattr(self.model, 'created_at'):
                # This would need custom query for date filtering
                # For now, return basic stats
                pass

            return {
                "tenant_id": tenant_id,
                "total_records": total_records,
                "recent_records": recent_records,
                "model_type": self.model.__name__
            }

        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         tenant_id=tenant_id, error=str(e))
            return {
                "tenant_id": tenant_id,
                "total_records": 0,
                "recent_records": 0,
                "model_type": self.model.__name__,
                "error": str(e)
            }

    async def cleanup_old_records(
        self,
        tenant_id: str,
        days_old: int = 365,
        batch_size: int = 1000
    ) -> int:
        """Clean up old records for a tenant (if model has date/created_at field)"""
        try:
            if not hasattr(self.model, 'created_at') and not hasattr(self.model, 'date'):
                logger.warning(f"Model {self.model.__name__} has no date field for cleanup")
                return 0

            # This would need custom implementation with raw SQL
            # For now, return 0 to indicate no cleanup performed
            logger.info(f"Cleanup requested for {self.model.__name__} but not implemented")
            return 0

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         tenant_id=tenant_id,
                         days_old=days_old,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

    def _ensure_utc_datetime(self, dt: Optional[datetime]) -> Optional[datetime]:
        """Ensure datetime is UTC timezone aware"""
        if dt is None:
            return None

        if dt.tzinfo is None:
            # Assume naive datetime is UTC
            return dt.replace(tzinfo=timezone.utc)

        return dt.astimezone(timezone.utc)
@@ -1,517 +0,0 @@
|
||||
"""
|
||||
Sales Repository
|
||||
Repository for sales data operations with business-specific queries
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc, asc, text
|
||||
from datetime import datetime, timezone
|
||||
import structlog
|
||||
|
||||
from .base import DataBaseRepository
|
||||
from app.models.sales import SalesData
|
||||
from app.schemas.sales import SalesDataCreate, SalesDataResponse
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SalesRepository(DataBaseRepository[SalesData, SalesDataCreate, Dict]):
|
||||
"""Repository for sales data operations"""
|
||||
|
||||
def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_and_date_range(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
product_names: Optional[List[str]] = None,
|
||||
location_ids: Optional[List[str]] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[SalesData]:
|
||||
"""Get sales data filtered by tenant, date range, and optional filters"""
|
||||
try:
|
||||
query = select(self.model).where(self.model.tenant_id == tenant_id)
|
||||
|
||||
# Add date range filter
|
||||
if start_date:
|
||||
start_date = self._ensure_utc_datetime(start_date)
|
||||
query = query.where(self.model.date >= start_date)
|
||||
|
||||
if end_date:
|
||||
end_date = self._ensure_utc_datetime(end_date)
|
||||
query = query.where(self.model.date <= end_date)
|
||||
|
||||
# Add product filter
|
||||
if product_names:
|
||||
query = query.where(self.model.product_name.in_(product_names))
|
||||
|
||||
# Add location filter
|
||||
if location_ids:
|
||||
query = query.where(self.model.location_id.in_(location_ids))
|
||||
|
||||
# Order by date descending (most recent first)
|
||||
query = query.order_by(desc(self.model.date))
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales by tenant and date range",
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get sales data: {str(e)}")
|
||||
|
||||
async def get_sales_aggregation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
group_by: str = "daily",
|
||||
product_name: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get aggregated sales data for analytics"""
|
||||
try:
|
||||
# Determine date truncation based on group_by
|
||||
if group_by == "daily":
|
||||
date_trunc = "day"
|
||||
elif group_by == "weekly":
|
||||
date_trunc = "week"
|
||||
elif group_by == "monthly":
|
||||
date_trunc = "month"
|
||||
else:
|
||||
raise ValidationError(f"Invalid group_by value: {group_by}")
|
||||
|
||||
# Build base query
|
||||
if self.session.bind.dialect.name == 'postgresql':
|
||||
query = text("""
|
||||
SELECT
|
||||
DATE_TRUNC(:date_trunc, date) as period,
|
||||
product_name,
|
||||
COUNT(*) as record_count,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(quantity_sold) as average_quantity,
|
||||
AVG(revenue) as average_revenue
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
""")
|
||||
else:
|
||||
# SQLite fallback
|
||||
query = text("""
|
||||
SELECT
|
||||
DATE(date) as period,
|
||||
product_name,
|
||||
COUNT(*) as record_count,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(quantity_sold) as average_quantity,
|
||||
AVG(revenue) as average_revenue
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
""")
|
||||
|
||||
params = {
|
||||
"tenant_id": tenant_id,
|
||||
"date_trunc": date_trunc
|
||||
}
|
||||
|
||||
# Add date filters
|
||||
if start_date:
|
||||
query = text(str(query) + " AND date >= :start_date")
|
||||
params["start_date"] = self._ensure_utc_datetime(start_date)
|
||||
|
||||
if end_date:
|
||||
query = text(str(query) + " AND date <= :end_date")
|
||||
params["end_date"] = self._ensure_utc_datetime(end_date)
|
||||
|
||||
# Add product filter
|
||||
if product_name:
|
||||
query = text(str(query) + " AND product_name = :product_name")
|
||||
params["product_name"] = product_name
|
||||
|
||||
# Add GROUP BY and ORDER BY
|
||||
query = text(str(query) + " GROUP BY period, product_name ORDER BY period DESC")
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert to list of dictionaries
|
||||
aggregations = []
|
||||
for row in rows:
|
||||
aggregations.append({
|
||||
"period": group_by,
|
||||
"date": row.period,
|
||||
"product_name": row.product_name,
|
||||
"record_count": row.record_count,
|
||||
"total_quantity": row.total_quantity,
|
||||
"total_revenue": float(row.total_revenue),
|
||||
"average_quantity": float(row.average_quantity),
|
||||
"average_revenue": float(row.average_revenue)
|
||||
})
|
||||
|
||||
return aggregations
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales aggregation",
|
||||
tenant_id=tenant_id,
|
||||
group_by=group_by,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Sales aggregation failed: {str(e)}")
|
||||
|
||||
async def get_top_products(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: int = 10,
|
||||
by_metric: str = "revenue"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get top products by quantity or revenue"""
|
||||
try:
|
||||
if by_metric not in ["revenue", "quantity"]:
|
||||
raise ValidationError(f"Invalid metric: {by_metric}")
|
||||
|
||||
# Choose the aggregation column
|
||||
metric_column = "revenue" if by_metric == "revenue" else "quantity_sold"
|
||||
|
||||
query = text(f"""
|
||||
SELECT
|
||||
product_name,
|
||||
COUNT(*) as sale_count,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(revenue) as avg_revenue_per_sale
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
{('AND date >= :start_date' if start_date else '')}
|
||||
{('AND date <= :end_date' if end_date else '')}
|
||||
GROUP BY product_name
|
||||
ORDER BY SUM({metric_column}) DESC
|
||||
LIMIT :limit
|
||||
""")
|
||||
|
||||
params = {"tenant_id": tenant_id, "limit": limit}
|
||||
if start_date:
|
||||
params["start_date"] = self._ensure_utc_datetime(start_date)
|
||||
if end_date:
|
||||
params["end_date"] = self._ensure_utc_datetime(end_date)
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
products = []
|
||||
for row in rows:
|
||||
products.append({
|
||||
"product_name": row.product_name,
|
||||
"sale_count": row.sale_count,
|
||||
"total_quantity": row.total_quantity,
|
||||
"total_revenue": float(row.total_revenue),
|
||||
"avg_revenue_per_sale": float(row.avg_revenue_per_sale),
|
||||
"metric_used": by_metric
|
||||
})
|
||||
|
||||
return products
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get top products",
|
||||
tenant_id=tenant_id,
|
||||
by_metric=by_metric,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Top products query failed: {str(e)}")
|
||||
|
||||
async def get_sales_by_location(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get sales statistics by location"""
|
||||
try:
|
||||
query = text("""
|
||||
SELECT
|
||||
COALESCE(location_id, 'unknown') as location_id,
|
||||
COUNT(*) as sale_count,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(revenue) as avg_revenue_per_sale
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
{date_filters}
|
||||
GROUP BY location_id
|
||||
ORDER BY SUM(revenue) DESC
|
||||
""".format(
|
||||
date_filters=(
|
||||
"AND date >= :start_date" if start_date else ""
|
||||
) + (
|
||||
" AND date <= :end_date" if end_date else ""
|
||||
)
|
||||
))
|
||||
|
||||
params = {"tenant_id": tenant_id}
|
||||
if start_date:
|
||||
params["start_date"] = self._ensure_utc_datetime(start_date)
|
||||
if end_date:
|
||||
params["end_date"] = self._ensure_utc_datetime(end_date)
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
locations = []
|
||||
for row in rows:
|
||||
locations.append({
|
||||
"location_id": row.location_id,
|
||||
"sale_count": row.sale_count,
|
||||
"total_quantity": row.total_quantity,
|
||||
"total_revenue": float(row.total_revenue),
|
||||
"avg_revenue_per_sale": float(row.avg_revenue_per_sale)
|
||||
})
|
||||
|
||||
return locations
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales by location",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Sales by location query failed: {str(e)}")
|
||||
|
||||
async def create_bulk_sales(
|
||||
self,
|
||||
sales_records: List[Dict[str, Any]],
|
||||
tenant_id: str
|
||||
) -> List[SalesData]:
|
||||
"""Create multiple sales records in bulk"""
|
||||
try:
|
||||
# Ensure all records have tenant_id
|
||||
for record in sales_records:
|
||||
record["tenant_id"] = tenant_id
|
||||
# Ensure dates are timezone-aware
|
||||
if "date" in record and record["date"]:
|
||||
record["date"] = self._ensure_utc_datetime(record["date"])
|
||||
|
||||
return await self.bulk_create(sales_records)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create bulk sales",
|
||||
tenant_id=tenant_id,
|
||||
record_count=len(sales_records),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Bulk sales creation failed: {str(e)}")
|
||||
|
||||
async def search_sales(
|
||||
self,
|
||||
tenant_id: str,
|
||||
search_term: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[SalesData]:
|
||||
"""Search sales by product name or notes"""
|
||||
try:
|
||||
# Build the OR search conditions only for columns that exist on the model
search_conditions = [self.model.product_name.ilike(f"%{search_term}%")]
if hasattr(self.model, 'notes'):
search_conditions.append(self.model.notes.ilike(f"%{search_term}%"))
if hasattr(self.model, 'location_id'):
search_conditions.append(self.model.location_id.ilike(f"%{search_term}%"))

# Filter by tenant first, then apply the search conditions
query = select(self.model).where(
and_(
self.model.tenant_id == tenant_id,
or_(*search_conditions)
)
).order_by(desc(self.model.date)).offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to search sales",
|
||||
tenant_id=tenant_id,
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Sales search failed: {str(e)}")
|
||||
|
||||
async def get_sales_summary(
|
||||
self,
|
||||
tenant_id: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive sales summary for a tenant"""
|
||||
try:
|
||||
base_filters = {"tenant_id": tenant_id}
|
||||
|
||||
# Build date filter for count
|
||||
date_query = select(func.count(self.model.id)).where(self.model.tenant_id == tenant_id)
|
||||
|
||||
if start_date:
|
||||
date_query = date_query.where(self.model.date >= self._ensure_utc_datetime(start_date))
|
||||
if end_date:
|
||||
date_query = date_query.where(self.model.date <= self._ensure_utc_datetime(end_date))
|
||||
|
||||
# Get basic counts
|
||||
total_result = await self.session.execute(date_query)
|
||||
total_sales = total_result.scalar() or 0
|
||||
|
||||
# Get revenue and quantity totals
|
||||
summary_query = text("""
|
||||
SELECT
|
||||
COUNT(*) as total_records,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(revenue) as avg_revenue,
|
||||
MIN(date) as earliest_sale,
|
||||
MAX(date) as latest_sale,
|
||||
COUNT(DISTINCT product_name) as unique_products,
|
||||
COUNT(DISTINCT location_id) as unique_locations
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
{date_filters}
|
||||
""".format(
|
||||
date_filters=(
|
||||
"AND date >= :start_date" if start_date else ""
|
||||
) + (
|
||||
" AND date <= :end_date" if end_date else ""
|
||||
)
|
||||
))
|
||||
|
||||
params = {"tenant_id": tenant_id}
|
||||
if start_date:
|
||||
params["start_date"] = self._ensure_utc_datetime(start_date)
|
||||
if end_date:
|
||||
params["end_date"] = self._ensure_utc_datetime(end_date)
|
||||
|
||||
result = await self.session.execute(summary_query, params)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"period_start": start_date,
|
||||
"period_end": end_date,
|
||||
"total_sales": row.total_records or 0,
|
||||
"total_quantity": row.total_quantity or 0,
|
||||
"total_revenue": float(row.total_revenue or 0),
|
||||
"average_revenue": float(row.avg_revenue or 0),
|
||||
"earliest_sale": row.earliest_sale,
|
||||
"latest_sale": row.latest_sale,
|
||||
"unique_products": row.unique_products or 0,
|
||||
"unique_locations": row.unique_locations or 0
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"period_start": start_date,
|
||||
"period_end": end_date,
|
||||
"total_sales": 0,
|
||||
"total_quantity": 0,
|
||||
"total_revenue": 0.0,
|
||||
"average_revenue": 0.0,
|
||||
"earliest_sale": None,
|
||||
"latest_sale": None,
|
||||
"unique_products": 0,
|
||||
"unique_locations": 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales summary",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Sales summary failed: {str(e)}")
|
||||
|
||||
async def validate_sales_data(self, sales_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate sales data before insertion"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
# Check required fields
|
||||
required_fields = ["date", "product_name", "quantity_sold", "revenue"]
|
||||
for field in required_fields:
|
||||
if field not in sales_data or sales_data[field] is None:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate data types and ranges
|
||||
if "quantity_sold" in sales_data:
|
||||
if not isinstance(sales_data["quantity_sold"], (int, float)) or sales_data["quantity_sold"] <= 0:
|
||||
errors.append("quantity_sold must be a positive number")
|
||||
|
||||
if "revenue" in sales_data:
|
||||
if not isinstance(sales_data["revenue"], (int, float)) or sales_data["revenue"] <= 0:
|
||||
errors.append("revenue must be a positive number")
|
||||
|
||||
# Validate string lengths
|
||||
if "product_name" in sales_data and len(str(sales_data["product_name"])) > 255:
|
||||
errors.append("product_name exceeds maximum length of 255 characters")
|
||||
|
||||
# Check for suspicious data (guard against a zero or non-numeric quantity to avoid division errors)
if isinstance(sales_data.get("quantity_sold"), (int, float)) and isinstance(sales_data.get("revenue"), (int, float)) and sales_data["quantity_sold"] > 0:
unit_price = sales_data["revenue"] / sales_data["quantity_sold"]
if unit_price > 10000:  # Arbitrary high price threshold
warnings.append(f"Unusually high unit price: {unit_price:.2f}")
elif unit_price < 0.01:  # Very low price
warnings.append(f"Unusually low unit price: {unit_price:.2f}")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate sales data", error=str(e))
|
||||
return {
|
||||
"is_valid": False,
|
||||
"errors": [f"Validation error: {str(e)}"],
|
||||
"warnings": []
|
||||
}
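# Usage note (illustrative): the returned dict has the shape
# {"is_valid": bool, "errors": [str, ...], "warnings": [str, ...]};
# warnings (e.g. unusual unit prices) do not block insertion, errors do.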
|
||||
|
||||
async def get_product_statistics(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get product statistics for tenant"""
|
||||
try:
|
||||
query = text("""
|
||||
SELECT
|
||||
product_name,
|
||||
COUNT(*) as total_sales,
|
||||
SUM(quantity_sold) as total_quantity,
|
||||
SUM(revenue) as total_revenue,
|
||||
AVG(revenue) as avg_revenue,
|
||||
MIN(date) as first_sale,
|
||||
MAX(date) as last_sale
|
||||
FROM sales_data
|
||||
WHERE tenant_id = :tenant_id
|
||||
GROUP BY product_name
|
||||
ORDER BY SUM(revenue) DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(query, {"tenant_id": tenant_id})
|
||||
rows = result.fetchall()
|
||||
|
||||
products = []
|
||||
for row in rows:
|
||||
products.append({
|
||||
"product_name": row.product_name,
|
||||
"total_sales": int(row.total_sales or 0),
|
||||
"total_quantity": int(row.total_quantity or 0),
|
||||
"total_revenue": float(row.total_revenue or 0),
|
||||
"avg_revenue": float(row.avg_revenue or 0),
|
||||
"first_sale": row.first_sale.isoformat() if row.first_sale else None,
|
||||
"last_sale": row.last_sale.isoformat() if row.last_sale else None
|
||||
})
|
||||
|
||||
logger.debug(f"Found {len(products)} products for tenant {tenant_id}")
|
||||
return products
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting product statistics: {str(e)}", tenant_id=tenant_id)
|
||||
return []
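# ================================================================
# Illustrative usage sketch: pulling a tenant overview through the repository.
# `async_session_factory` is a placeholder for however the data service
# provides AsyncSession objects, and the constructor arguments follow the
# DataBaseRepository pattern used by the traffic repositories below.
# ================================================================
async def _example_sales_overview(async_session_factory, tenant_id: str) -> Dict[str, Any]:
    async with async_session_factory() as session:
        # Hypothetical wiring: adjust to the actual repository constructor
        repo = SalesRepository(SalesData, session)
        summary = await repo.get_sales_summary(tenant_id=tenant_id)
        products = await repo.get_product_statistics(tenant_id=tenant_id)
        return {"summary": summary, "top_products": products[:5]}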
|
||||
@@ -1,874 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/app/repositories/traffic_repository.py
|
||||
# ================================================================
|
||||
"""
|
||||
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
|
||||
Follows existing repository architecture while adding city-specific functionality
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
|
||||
from sqlalchemy.orm import selectinload
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import DataBaseRepository
|
||||
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
|
||||
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
|
||||
"""
|
||||
Enhanced repository for traffic data operations across multiple cities
|
||||
Provides city-aware queries and advanced traffic analytics
|
||||
"""
|
||||
|
||||
def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
# ================================================================
|
||||
# CORE TRAFFIC DATA OPERATIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_by_location_and_date_range(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
city: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[TrafficData]:
|
||||
"""Get traffic data by location and date range with city filtering"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
# Build base query
|
||||
query = select(self.model).where(self.model.location_id == location_id)
|
||||
|
||||
# Add city filter if specified
|
||||
if city:
|
||||
query = query.where(self.model.city == city)
|
||||
|
||||
# Add tenant filter if specified
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
# Add date range filters
|
||||
if start_date:
|
||||
start_date = self._ensure_utc_datetime(start_date)
|
||||
query = query.where(self.model.date >= start_date)
|
||||
|
||||
if end_date:
|
||||
end_date = self._ensure_utc_datetime(end_date)
|
||||
query = query.where(self.model.date <= end_date)
|
||||
|
||||
# Order by date descending (most recent first)
|
||||
query = query.order_by(desc(self.model.date))
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get traffic data by location and date range",
|
||||
latitude=latitude, longitude=longitude,
|
||||
city=city, error=str(e))
|
||||
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
|
||||
|
||||
async def get_by_city_and_date_range(
|
||||
self,
|
||||
city: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
district: Optional[str] = None,
|
||||
measurement_point_ids: Optional[List[str]] = None,
|
||||
include_synthetic: bool = True,
|
||||
tenant_id: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = 1000
|
||||
) -> List[TrafficData]:
|
||||
"""Get traffic data by city with advanced filtering options"""
|
||||
try:
|
||||
# Build base query
|
||||
query = select(self.model).where(self.model.city == city)
|
||||
|
||||
# Add tenant filter if specified
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
# Add date range filters
|
||||
if start_date:
|
||||
start_date = self._ensure_utc_datetime(start_date)
|
||||
query = query.where(self.model.date >= start_date)
|
||||
|
||||
if end_date:
|
||||
end_date = self._ensure_utc_datetime(end_date)
|
||||
query = query.where(self.model.date <= end_date)
|
||||
|
||||
# Add district filter
|
||||
if district:
|
||||
query = query.where(self.model.district == district)
|
||||
|
||||
# Add measurement point filter
|
||||
if measurement_point_ids:
|
||||
query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))
|
||||
|
||||
# Filter synthetic data if requested
|
||||
if not include_synthetic:
|
||||
query = query.where(self.model.is_synthetic == False)
|
||||
|
||||
# Order by date and measurement point
|
||||
query = query.order_by(desc(self.model.date), self.model.measurement_point_id)
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get traffic data by city",
|
||||
city=city, district=district, error=str(e))
|
||||
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
|
||||
|
||||
async def get_latest_by_measurement_points(
|
||||
self,
|
||||
measurement_point_ids: List[str],
|
||||
city: str,
|
||||
hours_back: int = 24
|
||||
) -> List[TrafficData]:
|
||||
"""Get latest traffic data for specific measurement points"""
|
||||
try:
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
|
||||
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.city == city,
|
||||
self.model.measurement_point_id.in_(measurement_point_ids),
|
||||
self.model.date >= cutoff_time
|
||||
)
|
||||
).order_by(
|
||||
self.model.measurement_point_id,
|
||||
desc(self.model.date)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
all_records = result.scalars().all()
|
||||
|
||||
# Get the latest record for each measurement point
|
||||
latest_records = {}
|
||||
for record in all_records:
|
||||
point_id = record.measurement_point_id
|
||||
if point_id not in latest_records:
|
||||
latest_records[point_id] = record
|
||||
|
||||
return list(latest_records.values())
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest traffic data by measurement points",
|
||||
city=city, points=len(measurement_point_ids), error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")
|
||||
|
||||
# ================================================================
|
||||
# ANALYTICS AND AGGREGATIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_traffic_statistics_by_city(
|
||||
self,
|
||||
city: str,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
group_by: str = "daily"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get aggregated traffic statistics by city"""
|
||||
try:
|
||||
# Determine date truncation based on group_by
|
||||
if group_by == "hourly":
|
||||
date_trunc = "hour"
|
||||
elif group_by == "daily":
|
||||
date_trunc = "day"
|
||||
elif group_by == "weekly":
|
||||
date_trunc = "week"
|
||||
elif group_by == "monthly":
|
||||
date_trunc = "month"
|
||||
else:
|
||||
raise ValidationError(f"Invalid group_by value: {group_by}")
|
||||
|
||||
# Build aggregation query
|
||||
if self.session.bind.dialect.name == 'postgresql':
|
||||
query = text("""
|
||||
SELECT
|
||||
DATE_TRUNC(:date_trunc, date) as period,
|
||||
city,
|
||||
district,
|
||||
COUNT(*) as record_count,
|
||||
AVG(traffic_volume) as avg_traffic_volume,
|
||||
MAX(traffic_volume) as max_traffic_volume,
|
||||
AVG(pedestrian_count) as avg_pedestrian_count,
|
||||
AVG(average_speed) as avg_speed,
|
||||
COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
|
||||
COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
|
||||
COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
|
||||
FROM traffic_data
|
||||
WHERE city = :city
|
||||
""")
|
||||
else:
|
||||
# SQLite fallback
|
||||
query = text("""
|
||||
SELECT
|
||||
DATE(date) as period,
|
||||
city,
|
||||
district,
|
||||
COUNT(*) as record_count,
|
||||
AVG(traffic_volume) as avg_traffic_volume,
|
||||
MAX(traffic_volume) as max_traffic_volume,
|
||||
AVG(pedestrian_count) as avg_pedestrian_count,
|
||||
AVG(average_speed) as avg_speed,
|
||||
SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
|
||||
SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
|
||||
SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
|
||||
FROM traffic_data
|
||||
WHERE city = :city
|
||||
""")
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"date_trunc": date_trunc
|
||||
}
|
||||
|
||||
# Append date filters to the raw SQL, then wrap it in text() once
sql = str(query)
if start_date:
sql += " AND date >= :start_date"
params["start_date"] = self._ensure_utc_datetime(start_date)

if end_date:
sql += " AND date <= :end_date"
params["end_date"] = self._ensure_utc_datetime(end_date)

# Add GROUP BY and ORDER BY
query = text(sql + " GROUP BY period, city, district ORDER BY period DESC")
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert to list of dictionaries
|
||||
statistics = []
|
||||
for row in rows:
|
||||
statistics.append({
|
||||
"period": group_by,
|
||||
"date": row.period,
|
||||
"city": row.city,
|
||||
"district": row.district,
|
||||
"record_count": row.record_count,
|
||||
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
|
||||
"max_traffic_volume": row.max_traffic_volume or 0,
|
||||
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
|
||||
"avg_speed": float(row.avg_speed or 0),
|
||||
"high_congestion_count": row.high_congestion_count or 0,
|
||||
"real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
|
||||
"pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
|
||||
})
|
||||
|
||||
return statistics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get traffic statistics by city",
|
||||
city=city, group_by=group_by, error=str(e))
|
||||
raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
|
||||
|
||||
async def get_congestion_heatmap_data(
|
||||
self,
|
||||
city: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
time_granularity: str = "hour"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get congestion data for heatmap visualization"""
|
||||
try:
|
||||
if time_granularity == "hour":
|
||||
time_extract = "EXTRACT(hour FROM date)"
|
||||
elif time_granularity == "day_of_week":
|
||||
time_extract = "EXTRACT(dow FROM date)"
|
||||
else:
|
||||
time_extract = "EXTRACT(hour FROM date)"
|
||||
|
||||
query = text(f"""
|
||||
SELECT
|
||||
{time_extract} as time_period,
|
||||
district,
|
||||
measurement_point_id,
|
||||
latitude,
|
||||
longitude,
|
||||
AVG(CASE
|
||||
WHEN congestion_level = 'low' THEN 1
|
||||
WHEN congestion_level = 'medium' THEN 2
|
||||
WHEN congestion_level = 'high' THEN 3
|
||||
WHEN congestion_level = 'blocked' THEN 4
|
||||
ELSE 1
|
||||
END) as avg_congestion_score,
|
||||
COUNT(*) as data_points,
|
||||
AVG(traffic_volume) as avg_traffic_volume,
|
||||
AVG(pedestrian_count) as avg_pedestrian_count
|
||||
FROM traffic_data
|
||||
WHERE city = :city
|
||||
AND date >= :start_date
|
||||
AND date <= :end_date
|
||||
AND latitude IS NOT NULL
|
||||
AND longitude IS NOT NULL
|
||||
GROUP BY time_period, district, measurement_point_id, latitude, longitude
|
||||
ORDER BY time_period, district, avg_congestion_score DESC
|
||||
""")
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"start_date": self._ensure_utc_datetime(start_date),
|
||||
"end_date": self._ensure_utc_datetime(end_date)
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
heatmap_data = []
|
||||
for row in rows:
|
||||
heatmap_data.append({
|
||||
"time_period": int(row.time_period or 0),
|
||||
"district": row.district,
|
||||
"measurement_point_id": row.measurement_point_id,
|
||||
"latitude": float(row.latitude),
|
||||
"longitude": float(row.longitude),
|
||||
"avg_congestion_score": float(row.avg_congestion_score),
|
||||
"data_points": row.data_points,
|
||||
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
|
||||
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
|
||||
})
|
||||
|
||||
return heatmap_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get congestion heatmap data",
|
||||
city=city, error=str(e))
|
||||
raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")
|
||||
|
||||
# ================================================================
|
||||
# BULK OPERATIONS AND DATA MANAGEMENT
|
||||
# ================================================================
|
||||
|
||||
async def create_bulk_traffic_data(
|
||||
self,
|
||||
traffic_records: List[Dict[str, Any]],
|
||||
city: str,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[TrafficData]:
|
||||
"""Create multiple traffic records in bulk with enhanced validation"""
|
||||
try:
|
||||
# Ensure all records have city and tenant_id
|
||||
for record in traffic_records:
|
||||
record["city"] = city
|
||||
if tenant_id:
|
||||
record["tenant_id"] = tenant_id
|
||||
# Ensure dates are timezone-aware
|
||||
if "date" in record and record["date"]:
|
||||
record["date"] = self._ensure_utc_datetime(record["date"])
|
||||
|
||||
# Enhanced validation
|
||||
validated_records = []
|
||||
for record in traffic_records:
|
||||
if self._validate_traffic_record(record):
|
||||
validated_records.append(record)
|
||||
else:
|
||||
logger.warning("Invalid traffic record skipped",
|
||||
city=city, record_keys=list(record.keys()))
|
||||
|
||||
if not validated_records:
|
||||
logger.warning("No valid traffic records to create", city=city)
|
||||
return []
|
||||
|
||||
# Use bulk create with deduplication
|
||||
created_records = await self.bulk_create_with_deduplication(validated_records)
|
||||
|
||||
logger.info("Bulk traffic data creation completed",
|
||||
city=city, requested=len(traffic_records),
|
||||
validated=len(validated_records), created=len(created_records))
|
||||
|
||||
return created_records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create bulk traffic data",
|
||||
city=city, record_count=len(traffic_records), error=str(e))
|
||||
raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")
|
||||
|
||||
async def bulk_create_with_deduplication(
|
||||
self,
|
||||
records: List[Dict[str, Any]]
|
||||
) -> List[TrafficData]:
|
||||
"""Bulk create with automatic deduplication based on location, city, and date"""
|
||||
try:
|
||||
if not records:
|
||||
return []
|
||||
|
||||
# Extract unique keys for deduplication check
|
||||
unique_keys = []
|
||||
for record in records:
|
||||
key = (
|
||||
record.get('location_id'),
|
||||
record.get('city'),
|
||||
record.get('date'),
|
||||
record.get('measurement_point_id')
|
||||
)
|
||||
unique_keys.append(key)
|
||||
|
||||
# Check for existing records
|
||||
location_ids = [key[0] for key in unique_keys if key[0]]
|
||||
cities = [key[1] for key in unique_keys if key[1]]
|
||||
dates = [key[2] for key in unique_keys if key[2]]
|
||||
|
||||
# For large datasets, use chunked deduplication to avoid memory issues
|
||||
if len(location_ids) > 1000:
|
||||
logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
|
||||
new_records = []
|
||||
chunk_size = 1000
|
||||
|
||||
for i in range(0, len(records), chunk_size):
|
||||
chunk_records = records[i:i + chunk_size]
|
||||
chunk_keys = unique_keys[i:i + chunk_size]
|
||||
|
||||
# Get unique values for this chunk
|
||||
chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
|
||||
chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
|
||||
chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))
|
||||
|
||||
if chunk_location_ids and chunk_cities and chunk_dates:
|
||||
existing_query = select(
|
||||
self.model.location_id,
|
||||
self.model.city,
|
||||
self.model.date,
|
||||
self.model.measurement_point_id
|
||||
).where(
|
||||
and_(
|
||||
self.model.location_id.in_(chunk_location_ids),
|
||||
self.model.city.in_(chunk_cities),
|
||||
self.model.date.in_(chunk_dates)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(existing_query)
|
||||
chunk_existing_keys = set(result.fetchall())
|
||||
|
||||
# Filter chunk duplicates
|
||||
for j, record in enumerate(chunk_records):
|
||||
key = chunk_keys[j]
|
||||
if key not in chunk_existing_keys:
|
||||
new_records.append(record)
|
||||
else:
|
||||
new_records.extend(chunk_records)
|
||||
|
||||
logger.debug("Chunked deduplication completed",
|
||||
total_records=len(records),
|
||||
new_records=len(new_records))
|
||||
records = new_records
|
||||
|
||||
elif location_ids and cities and dates:
|
||||
existing_query = select(
|
||||
self.model.location_id,
|
||||
self.model.city,
|
||||
self.model.date,
|
||||
self.model.measurement_point_id
|
||||
).where(
|
||||
and_(
|
||||
self.model.location_id.in_(location_ids),
|
||||
self.model.city.in_(cities),
|
||||
self.model.date.in_(dates)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(existing_query)
|
||||
existing_keys = set(result.fetchall())
|
||||
|
||||
# Filter out duplicates
|
||||
new_records = []
|
||||
for i, record in enumerate(records):
|
||||
key = unique_keys[i]
|
||||
if key not in existing_keys:
|
||||
new_records.append(record)
|
||||
|
||||
logger.debug("Standard deduplication completed",
|
||||
total_records=len(records),
|
||||
existing_records=len(existing_keys),
|
||||
new_records=len(new_records))
|
||||
|
||||
records = new_records
|
||||
|
||||
# Proceed with bulk creation
|
||||
return await self.bulk_create(records)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed bulk create with deduplication", error=str(e))
|
||||
raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")
|
||||
|
||||
def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
|
||||
"""Enhanced validation for traffic records"""
|
||||
required_fields = ['date', 'city']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if not record.get(field):
|
||||
return False
|
||||
|
||||
# Validate city
|
||||
city = record.get('city', '').lower()
|
||||
if city not in ['madrid', 'barcelona', 'valencia', 'test']: # Extendable list
|
||||
return False
|
||||
|
||||
# Validate data ranges
|
||||
traffic_volume = record.get('traffic_volume')
|
||||
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
|
||||
return False
|
||||
|
||||
pedestrian_count = record.get('pedestrian_count')
|
||||
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
|
||||
return False
|
||||
|
||||
average_speed = record.get('average_speed')
|
||||
if average_speed is not None and (average_speed < 0 or average_speed > 200):
|
||||
return False
|
||||
|
||||
congestion_level = record.get('congestion_level')
|
||||
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# ================================================================
|
||||
# TRAINING DATA SPECIFIC OPERATIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_training_data_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None,
|
||||
include_pedestrian_inference: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get optimized training data for ML models"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.location_id == location_id,
|
||||
self.model.date >= self._ensure_utc_datetime(start_date),
|
||||
self.model.date <= self._ensure_utc_datetime(end_date)
|
||||
)
|
||||
)
|
||||
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
if include_pedestrian_inference:
|
||||
# Prefer records with pedestrian inference
|
||||
query = query.order_by(
|
||||
desc(self.model.has_pedestrian_inference),
|
||||
desc(self.model.data_quality_score),
|
||||
self.model.date
|
||||
)
|
||||
else:
|
||||
query = query.order_by(
|
||||
desc(self.model.data_quality_score),
|
||||
self.model.date
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
# Convert to training format with enhanced features
|
||||
training_data = []
|
||||
for record in records:
|
||||
training_record = {
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume or 0,
|
||||
'pedestrian_count': record.pedestrian_count or 0,
|
||||
'congestion_level': record.congestion_level or 'medium',
|
||||
'average_speed': record.average_speed or 25.0,
|
||||
'city': record.city,
|
||||
'district': record.district,
|
||||
'measurement_point_id': record.measurement_point_id,
|
||||
'source': record.source,
|
||||
'is_synthetic': record.is_synthetic or False,
|
||||
'has_pedestrian_inference': record.has_pedestrian_inference or False,
|
||||
'data_quality_score': record.data_quality_score or 50.0,
|
||||
|
||||
# Enhanced features for training
|
||||
'hour_of_day': record.date.hour if record.date else 12,
|
||||
'day_of_week': record.date.weekday() if record.date else 0,
|
||||
'month': record.date.month if record.date else 1,
|
||||
|
||||
# City-specific features
|
||||
'city_specific_data': record.city_specific_data or {}
|
||||
}
|
||||
|
||||
training_data.append(training_record)
|
||||
|
||||
logger.info("Retrieved training data",
|
||||
location_id=location_id, records=len(training_data),
|
||||
with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))
|
||||
|
||||
return training_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training data",
|
||||
latitude=latitude, longitude=longitude, error=str(e))
|
||||
raise DatabaseError(f"Training data retrieval failed: {str(e)}")
|
||||
|
||||
async def get_historical_data_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[TrafficData]:
|
||||
"""Get historical traffic data for a specific location and date range"""
|
||||
return await self.get_by_location_and_date_range(
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
tenant_id=tenant_id,
|
||||
limit=1000000 # Large limit for historical data
|
||||
)
|
||||
|
||||
async def count_records_in_period(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
city: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count traffic records for a specific location and time period"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
query = select(func.count(self.model.id)).where(
|
||||
and_(
|
||||
self.model.location_id == location_id,
|
||||
self.model.date >= self._ensure_utc_datetime(start_date),
|
||||
self.model.date <= self._ensure_utc_datetime(end_date)
|
||||
)
|
||||
)
|
||||
|
||||
if city:
|
||||
query = query.where(self.model.city == city)
|
||||
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
count = result.scalar()
|
||||
|
||||
return count or 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to count records in period",
|
||||
latitude=latitude, longitude=longitude, error=str(e))
|
||||
raise DatabaseError(f"Record count failed: {str(e)}")
|
||||
|
||||
# ================================================================
|
||||
# DATA QUALITY AND MAINTENANCE
|
||||
# ================================================================
|
||||
|
||||
async def update_data_quality_scores(self, city: str) -> int:
|
||||
"""Update data quality scores based on various criteria"""
|
||||
try:
|
||||
# Calculate quality scores based on data completeness and consistency
|
||||
query = text("""
|
||||
UPDATE traffic_data
|
||||
SET data_quality_score = (
|
||||
CASE
|
||||
WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
|
||||
CASE
|
||||
WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
|
||||
CASE
|
||||
WHEN district IS NOT NULL THEN 10 ELSE 0 END +
|
||||
CASE
|
||||
WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
|
||||
),
|
||||
updated_at = :updated_at
|
||||
WHERE city = :city AND data_quality_score IS NULL
|
||||
""")
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"updated_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
updated_count = result.rowcount
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Updated data quality scores",
|
||||
city=city, updated_count=updated_count)
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update data quality scores",
|
||||
city=city, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Data quality update failed: {str(e)}")
|
||||
|
||||
async def cleanup_old_synthetic_data(
|
||||
self,
|
||||
city: str,
|
||||
days_to_keep: int = 90
|
||||
) -> int:
|
||||
"""Clean up old synthetic data while preserving real data"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(self.model).where(
|
||||
and_(
|
||||
self.model.city == city,
|
||||
self.model.is_synthetic == True,
|
||||
self.model.date < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
deleted_count = result.rowcount
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Cleaned up old synthetic data",
|
||||
city=city, deleted_count=deleted_count, days_kept=days_to_keep)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old synthetic data",
|
||||
city=city, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")
|
||||
|
||||
|
||||
class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
|
||||
"""Repository for traffic measurement points across cities"""
|
||||
|
||||
async def get_points_near_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
city: str,
|
||||
radius_km: float = 10.0,
|
||||
limit: int = 20
|
||||
) -> List[TrafficMeasurementPoint]:
|
||||
"""Get measurement points near a location using spatial query"""
|
||||
try:
|
||||
# Simple great-circle distance calculation (for more precise results, use PostGIS).
# The distance alias is computed in a subquery so it can be filtered with WHERE;
# a HAVING clause without GROUP BY cannot reference the alias in PostgreSQL.
query = text("""
SELECT * FROM (
SELECT *,
(6371 * acos(
cos(radians(:lat)) * cos(radians(latitude)) *
cos(radians(longitude) - radians(:lon)) +
sin(radians(:lat)) * sin(radians(latitude))
)) as distance_km
FROM traffic_measurement_points
WHERE city = :city
AND is_active = true
) AS candidate_points
WHERE distance_km <= :radius_km
ORDER BY distance_km
LIMIT :limit
""")
|
||||
|
||||
params = {
|
||||
"lat": latitude,
|
||||
"lon": longitude,
|
||||
"city": city,
|
||||
"radius_km": radius_km,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert rows to model instances
|
||||
points = []
|
||||
for row in rows:
|
||||
point = TrafficMeasurementPoint()
|
||||
for key, value in row._mapping.items():
|
||||
if hasattr(point, key) and key != 'distance_km':
|
||||
setattr(point, key, value)
|
||||
points.append(point)
|
||||
|
||||
return points
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get measurement points near location",
|
||||
latitude=latitude, longitude=longitude, city=city, error=str(e))
|
||||
raise DatabaseError(f"Measurement points query failed: {str(e)}")
|
||||
|
||||
|
||||
class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
|
||||
"""Repository for managing background traffic data jobs"""
|
||||
|
||||
async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
|
||||
"""Get pending background jobs for a specific city"""
|
||||
try:
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.city == city,
|
||||
self.model.status == 'pending'
|
||||
)
|
||||
).order_by(self.model.scheduled_at)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending jobs by city", city=city, error=str(e))
|
||||
raise DatabaseError(f"Background jobs query failed: {str(e)}")
|
||||
|
||||
async def update_job_progress(
|
||||
self,
|
||||
job_id: str,
|
||||
progress_percentage: float,
|
||||
records_processed: int = 0,
|
||||
records_stored: int = 0
|
||||
) -> bool:
|
||||
"""Update job progress"""
|
||||
try:
|
||||
query = update(self.model).where(
|
||||
self.model.id == job_id
|
||||
).values(
|
||||
progress_percentage=progress_percentage,
|
||||
records_processed=records_processed,
|
||||
records_stored=records_stored,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
await self.session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update job progress", job_id=job_id, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Job progress update failed: {str(e)}")
|
||||
@@ -1,62 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/app/schemas/external.py
|
||||
# ================================================================
|
||||
"""External API response schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
class WeatherDataResponse(BaseModel):
|
||||
date: datetime
|
||||
temperature: Optional[float]
|
||||
precipitation: Optional[float]
|
||||
humidity: Optional[float]
|
||||
wind_speed: Optional[float]
|
||||
pressure: Optional[float]
|
||||
description: Optional[str]
|
||||
source: str
|
||||
|
||||
class WeatherForecastResponse(BaseModel):
|
||||
forecast_date: datetime
|
||||
generated_at: datetime
|
||||
temperature: Optional[float]
|
||||
precipitation: Optional[float]
|
||||
humidity: Optional[float]
|
||||
wind_speed: Optional[float]
|
||||
description: Optional[str]
|
||||
source: str
|
||||
|
||||
class TrafficDataResponse(BaseModel):
|
||||
date: datetime
|
||||
traffic_volume: Optional[int]
|
||||
pedestrian_count: Optional[int]
|
||||
congestion_level: Optional[str]
|
||||
average_speed: Optional[float]
|
||||
source: str
|
||||
|
||||
class LocationRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
address: Optional[str] = None
|
||||
|
||||
class DateRangeRequest(BaseModel):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class HistoricalTrafficRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class HistoricalWeatherRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class WeatherForecastRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
days: int
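# Illustrative usage sketch: the request models are plain Pydantic models, so
# constructing one is enough to get validation and type coercion. The
# coordinates below are hypothetical (Madrid city centre).
_example_forecast_request = WeatherForecastRequest(
    latitude=40.4168,
    longitude=-3.7038,
    days=7,
)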
|
||||
@@ -1,171 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/app/schemas/sales.py - MISSING FILE
|
||||
# ================================================================
|
||||
"""Sales data schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
class SalesDataCreate(BaseModel):
|
||||
"""Schema for creating sales data - FIXED to work with gateway"""
|
||||
# ✅ FIX: Make tenant_id optional since it comes from URL path
|
||||
tenant_id: Optional[UUID] = Field(None, description="Tenant ID (auto-injected from URL path)")
|
||||
date: datetime
|
||||
product_name: str = Field(..., min_length=1, max_length=255)
|
||||
quantity_sold: int = Field(..., gt=0)
|
||||
revenue: float = Field(..., gt=0)
|
||||
location_id: Optional[str] = Field(None, max_length=100)
|
||||
source: str = Field(default="manual", max_length=50)
|
||||
notes: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
@field_validator('product_name')
|
||||
@classmethod
|
||||
def normalize_product_name(cls, v):
|
||||
return v.strip().lower()
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"date": "2024-01-15T10:00:00Z",
|
||||
"product_name": "Pan Integral",
|
||||
"quantity_sold": 25,
|
||||
"revenue": 37.50,
|
||||
"source": "manual"
|
||||
# Note: tenant_id is automatically injected from URL path by gateway
|
||||
}
|
||||
}
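# Illustrative usage sketch: tenant_id may be omitted because the gateway
# injects it from the URL path, and the field validator strips and
# lower-cases the product name.
_example_sale = SalesDataCreate(
    date="2024-01-15T10:00:00Z",
    product_name="  Pan Integral ",
    quantity_sold=25,
    revenue=37.50,
)
assert _example_sale.product_name == "pan integral"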
|
||||
|
||||
class SalesDataResponse(BaseModel):
|
||||
"""Schema for sales data response"""
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
date: datetime
|
||||
product_name: str
|
||||
quantity_sold: int
|
||||
revenue: float
|
||||
location_id: Optional[str]
|
||||
source: str
|
||||
notes: Optional[str]
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesDataQuery(BaseModel):
|
||||
"""Schema for querying sales data"""
|
||||
tenant_id: UUID
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
product_names: Optional[List[str]] = None
|
||||
location_ids: Optional[List[str]] = None
|
||||
sources: Optional[List[str]] = None
|
||||
min_quantity: Optional[int] = None
|
||||
max_quantity: Optional[int] = None
|
||||
min_revenue: Optional[float] = None
|
||||
max_revenue: Optional[float] = None
|
||||
limit: Optional[int] = Field(default=1000, le=5000)
|
||||
offset: Optional[int] = Field(default=0, ge=0)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesDataImport(BaseModel):
|
||||
"""Schema for importing sales data - FIXED to work with gateway"""
|
||||
# ✅ FIX: Make tenant_id optional since it comes from URL path
|
||||
tenant_id: Optional[UUID] = Field(None, description="Tenant ID (auto-injected from URL path)")
|
||||
data: str = Field(..., description="JSON string or CSV content")
|
||||
data_format: str = Field(..., pattern="^(csv|json|excel)$")
|
||||
source: str = Field(default="import", max_length=50)
|
||||
validate_only: bool = Field(default=False)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"data": "date,product,quantity,revenue\n2024-01-01,bread,10,25.50",
|
||||
"data_format": "csv",
|
||||
# Note: tenant_id is automatically injected from URL path by gateway
|
||||
}
|
||||
}
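# Illustrative usage sketch: a validate-only CSV payload matching the example
# content above.
_example_import = SalesDataImport(
    data="date,product,quantity,revenue\n2024-01-01,bread,10,25.50",
    data_format="csv",
    validate_only=True,
)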
|
||||
|
||||
class SalesDataBulkCreate(BaseModel):
|
||||
"""Schema for bulk creating sales data"""
|
||||
tenant_id: UUID
|
||||
records: List[Dict[str, Any]]
|
||||
source: str = Field(default="bulk_import", max_length=50)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesValidationResult(BaseModel):
|
||||
"""Schema for sales data validation result"""
|
||||
is_valid: bool
|
||||
total_records: int
|
||||
valid_records: int
|
||||
invalid_records: int
|
||||
errors: List[Dict[str, Any]]
|
||||
warnings: List[Dict[str, Any]]
|
||||
summary: Dict[str, Any]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesImportResult(BaseModel):
|
||||
"""Complete schema that includes all expected fields"""
|
||||
success: bool
|
||||
records_processed: int  # Total rows read from the import payload (total_rows)
records_created: int
records_updated: int = 0  # Defaults to 0 when updates are not tracked
records_failed: int  # Rows that failed validation or insertion (error_count)
|
||||
errors: List[Dict[str, Any]] # Structured error objects
|
||||
warnings: List[Dict[str, Any]] # Structured warning objects
|
||||
processing_time_seconds: float
|
||||
|
||||
# Optional additional fields
|
||||
source: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
success_rate: Optional[float] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
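# Illustrative helper: success_rate is optional, so a caller that wants it can
# derive it from the counters, for example:
def _derive_success_rate(result: SalesImportResult) -> float:
    if result.records_processed <= 0:
        return 0.0
    return round(result.records_created / result.records_processed * 100, 2)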
|
||||
|
||||
class SalesAggregation(BaseModel):
|
||||
"""Schema for sales aggregation results"""
|
||||
period: str # "daily", "weekly", "monthly"
|
||||
date: datetime
|
||||
product_name: Optional[str] = None
|
||||
total_quantity: int
|
||||
total_revenue: float
|
||||
average_quantity: float
|
||||
average_revenue: float
|
||||
record_count: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesExportRequest(BaseModel):
|
||||
"""Schema for sales export request"""
|
||||
tenant_id: UUID
|
||||
format: str = Field(..., pattern="^(csv|json|excel)$")
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
product_names: Optional[List[str]] = None
|
||||
location_ids: Optional[List[str]] = None
|
||||
include_metadata: bool = Field(default=True)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class SalesValidationRequest(BaseModel):
|
||||
"""Schema for JSON-based sales data validation request"""
|
||||
data: str = Field(..., description="Raw data content (CSV, JSON, etc.)")
|
||||
data_format: str = Field(..., pattern="^(csv|json|excel)$", description="Format of the data")
|
||||
validate_only: bool = Field(default=True, description="Only validate, don't import")
|
||||
source: str = Field(default="onboarding_upload", description="Source of the data")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -1,20 +0,0 @@
|
||||
"""
|
||||
Data Service Layer
|
||||
Business logic services for data operations
|
||||
"""
|
||||
|
||||
from .sales_service import SalesService
|
||||
from .data_import_service import DataImportService, EnhancedDataImportService
|
||||
from .traffic_service import TrafficService
|
||||
from .weather_service import WeatherService
|
||||
from .messaging import publish_sales_data_imported, publish_data_updated
|
||||
|
||||
__all__ = [
|
||||
"SalesService",
|
||||
"DataImportService",
|
||||
"EnhancedDataImportService",
|
||||
"TrafficService",
|
||||
"WeatherService",
|
||||
"publish_sales_data_imported",
|
||||
"publish_data_updated"
|
||||
]
|
||||
@@ -1,134 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/app/services/messaging.py - FIXED VERSION
|
||||
# ================================================================
|
||||
"""Fixed messaging service with proper error handling"""
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global instance
|
||||
data_publisher = RabbitMQClient(settings.RABBITMQ_URL, "data-service")
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging for data service"""
|
||||
try:
|
||||
success = await data_publisher.connect()
|
||||
if success:
|
||||
logger.info("Data service messaging initialized")
|
||||
else:
|
||||
logger.warning("Data service messaging failed to initialize")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.warning("Failed to setup messaging", error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging for data service"""
|
||||
try:
|
||||
await data_publisher.disconnect()
|
||||
logger.info("Data service messaging cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning("Error during messaging cleanup", error=str(e))
|
||||
|
||||
# Convenience functions for data-specific events with error handling
|
||||
async def publish_data_imported(data: dict) -> bool:
|
||||
"""Publish data imported event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("imported", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish data imported event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_weather_updated(data: dict) -> bool:
|
||||
"""Publish weather updated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("weather.updated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish weather updated event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_traffic_updated(data: dict) -> bool:
|
||||
"""Publish traffic updated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("traffic.updated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish traffic updated event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_sales_created(data: dict) -> bool:
|
||||
"""Publish sales created event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("sales.created", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish sales created event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_analytics_generated(data: dict) -> bool:
|
||||
"""Publish analytics generated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("analytics.generated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish analytics generated event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_export_completed(data: dict) -> bool:
|
||||
"""Publish export completed event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("export.completed", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish export completed event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_import_started(data: dict) -> bool:
|
||||
"""Publish import started event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("import.started", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish import started event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_import_completed(data: dict) -> bool:
|
||||
"""Publish import completed event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("import.completed", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish import completed event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_import_failed(data: dict) -> bool:
|
||||
"""Publish import failed event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("import.failed", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish import failed event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_sales_data_imported(data: dict) -> bool:
|
||||
"""Publish sales data imported event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("sales.imported", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish sales data imported event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_data_updated(data: dict) -> bool:
|
||||
"""Publish data updated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("data.updated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish data updated event", error=str(e))
|
||||
return False
|
||||
|
||||
# Health check for messaging
|
||||
async def check_messaging_health() -> dict:
|
||||
"""Check messaging system health"""
|
||||
try:
|
||||
if data_publisher.connected:
|
||||
return {"status": "healthy", "service": "rabbitmq", "connected": True}
|
||||
else:
|
||||
return {"status": "unhealthy", "service": "rabbitmq", "connected": False, "error": "Not connected"}
|
||||
except Exception as e:
|
||||
return {"status": "unhealthy", "service": "rabbitmq", "connected": False, "error": str(e)}
|
||||
@@ -1,292 +0,0 @@
|
||||
"""
|
||||
Sales Service with Repository Pattern
|
||||
Enhanced service using the new repository architecture for better separation of concerns
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
|
||||
from app.repositories.sales_repository import SalesRepository
|
||||
from app.models.sales import SalesData
|
||||
from app.schemas.sales import (
|
||||
SalesDataCreate,
|
||||
SalesDataResponse,
|
||||
SalesDataQuery,
|
||||
SalesAggregation,
|
||||
SalesImportResult,
|
||||
SalesValidationResult
|
||||
)
|
||||
from shared.database.unit_of_work import UnitOfWork
|
||||
from shared.database.transactions import transactional
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class SalesService:
|
||||
"""Enhanced Sales Service using Repository Pattern and Unit of Work"""
|
||||
|
||||
def __init__(self, database_manager):
|
||||
"""Initialize service with database manager for dependency injection"""
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def create_sales_record(self, sales_data: SalesDataCreate, tenant_id: str) -> SalesDataResponse:
|
||||
"""Create a new sales record using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
# Register sales repository
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Ensure tenant_id is set
|
||||
record_data = sales_data.model_dump()
|
||||
record_data["tenant_id"] = tenant_id
|
||||
|
||||
# Validate the data first
|
||||
validation_result = await sales_repo.validate_sales_data(record_data)
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid sales data: {validation_result['errors']}")
|
||||
|
||||
# Create the record
|
||||
db_record = await sales_repo.create(record_data)
|
||||
|
||||
# Commit transaction
|
||||
await uow.commit()
|
||||
|
||||
logger.debug("Sales record created",
|
||||
record_id=db_record.id,
|
||||
product=db_record.product_name,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return SalesDataResponse.model_validate(db_record)
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create sales record",
|
||||
tenant_id=tenant_id,
|
||||
product=sales_data.product_name,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create sales record: {str(e)}")
|
||||
|
||||
async def get_sales_data(self, query: SalesDataQuery) -> List[SalesDataResponse]:
|
||||
"""Get sales data based on query parameters using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Use repository's advanced query method
|
||||
records = await sales_repo.get_by_tenant_and_date_range(
|
||||
tenant_id=str(query.tenant_id),
|
||||
start_date=query.start_date,
|
||||
end_date=query.end_date,
|
||||
product_names=query.product_names,
|
||||
location_ids=query.location_ids,
|
||||
skip=query.offset or 0,
|
||||
limit=query.limit or 100
|
||||
)
|
||||
|
||||
logger.debug("Sales data retrieved",
|
||||
count=len(records),
|
||||
tenant_id=query.tenant_id)
|
||||
|
||||
return [SalesDataResponse.model_validate(record) for record in records]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to retrieve sales data",
|
||||
tenant_id=query.tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to retrieve sales data: {str(e)}")
|
||||
|
||||
async def get_sales_analytics(self, tenant_id: str, start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive sales analytics using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Get summary data
|
||||
summary = await sales_repo.get_sales_summary(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# Get top products
|
||||
top_products = await sales_repo.get_top_products(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=5
|
||||
)
|
||||
|
||||
# Get aggregated data by day
|
||||
daily_aggregation = await sales_repo.get_sales_aggregation(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
group_by="daily"
|
||||
)
|
||||
|
||||
analytics = {
|
||||
**summary,
|
||||
"top_products": top_products,
|
||||
"daily_sales": daily_aggregation[:30], # Last 30 days
|
||||
"average_order_value": (
|
||||
summary["total_revenue"] / max(summary["total_sales"], 1)
|
||||
if summary["total_sales"] > 0 else 0.0
|
||||
)
|
||||
}
|
||||
|
||||
logger.debug("Sales analytics generated",
|
||||
tenant_id=tenant_id,
|
||||
total_records=analytics["total_sales"])
|
||||
|
||||
return analytics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate sales analytics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to generate analytics: {str(e)}")
|
||||
|
||||
async def get_sales_aggregation(self, tenant_id: str, start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None, group_by: str = "daily") -> List[SalesAggregation]:
|
||||
"""Get sales aggregation data"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
aggregations = await sales_repo.get_sales_aggregation(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
group_by=group_by
|
||||
)
|
||||
|
||||
return [
|
||||
SalesAggregation(
|
||||
period=agg["period"],
|
||||
date=agg["date"],
|
||||
product_name=agg["product_name"],
|
||||
total_quantity=agg["total_quantity"],
|
||||
total_revenue=agg["total_revenue"],
|
||||
average_quantity=agg["average_quantity"],
|
||||
average_revenue=agg["average_revenue"],
|
||||
record_count=agg["record_count"]
|
||||
)
|
||||
for agg in aggregations
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales aggregation",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get aggregation: {str(e)}")
|
||||
|
||||
async def export_sales_data(self, tenant_id: str, export_format: str, start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None, products: Optional[List[str]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Export sales data in specified format using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Get sales data based on filters
|
||||
records = await sales_repo.get_by_tenant_and_date_range(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_names=products,
|
||||
skip=0,
|
||||
limit=10000 # Large limit for export
|
||||
)
|
||||
|
||||
if not records:
|
||||
return None
|
||||
|
||||
# Simple CSV export; csv.writer is used so fields containing commas or quotes
# (e.g. product names) are escaped correctly
if export_format.lower() == "csv":
    import csv
    import io

    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow(["date", "product_name", "quantity_sold",
                     "revenue", "location_id", "source"])
    for record in records:
        writer.writerow([record.date, record.product_name, record.quantity_sold,
                         record.revenue, record.location_id or "", record.source])
|
||||
|
||||
logger.info("Sales data exported",
|
||||
tenant_id=tenant_id,
|
||||
format=export_format,
|
||||
record_count=len(records))
|
||||
|
||||
return {
|
||||
"content": output.getvalue(),
|
||||
"media_type": "text/csv",
|
||||
"filename": f"sales_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to export sales data",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to export sales data: {str(e)}")
|
||||
|
||||
async def delete_sales_record(self, record_id: str, tenant_id: str) -> bool:
|
||||
"""Delete a sales record using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# First verify the record exists and belongs to the tenant
|
||||
record = await sales_repo.get_by_id(record_id)
|
||||
if not record:
|
||||
return False
|
||||
|
||||
if str(record.tenant_id) != tenant_id:
|
||||
raise ValidationError("Record does not belong to the specified tenant")
|
||||
|
||||
# Delete the record
|
||||
success = await sales_repo.delete(record_id)
|
||||
|
||||
if success:
|
||||
logger.info("Sales record deleted",
|
||||
record_id=record_id,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
return success
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete sales record",
|
||||
record_id=record_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to delete sales record: {str(e)}")
|
||||
|
||||
async def get_products_list(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get list of all products with sales data for tenant using repository pattern"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
async with UnitOfWork(session) as uow:
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Use repository method for product statistics
|
||||
products = await sales_repo.get_product_statistics(tenant_id)
|
||||
|
||||
logger.debug("Products list retrieved successfully",
|
||||
tenant_id=tenant_id,
|
||||
product_count=len(products))
|
||||
|
||||
return products
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get products list",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise DatabaseError(f"Failed to get products list: {str(e)}")
|
||||
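# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): how SalesService
# can be driven outside of FastAPI, e.g. from a one-off script or test. The
# SalesDataCreate field names below are placeholders; the real schema lives in
# app.schemas.sales.
#
#   async def demo(database_manager):
#       service = SalesService(database_manager)
#       record = await service.create_sales_record(
#           SalesDataCreate(product_name="Espresso", quantity_sold=3,
#                           revenue=7.50, date=datetime.now()),
#           tenant_id="tenant-123",
#       )
#       analytics = await service.get_sales_analytics("tenant-123")
#       return record, analytics
# ----------------------------------------------------------------------------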
@@ -1,468 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/app/services/traffic_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Abstracted Traffic Service - Universal interface for traffic data across multiple cities
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
import structlog
|
||||
|
||||
from app.external.apis.traffic import UniversalTrafficClient
|
||||
from app.models.traffic import TrafficData
|
||||
from app.core.performance import (
|
||||
async_cache,
|
||||
monitor_performance,
|
||||
global_connection_pool,
|
||||
global_performance_monitor,
|
||||
batch_process
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrafficService:
|
||||
"""
|
||||
Abstracted traffic service providing unified interface for traffic data
|
||||
Routes requests to appropriate city-specific clients automatically
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.universal_client = UniversalTrafficClient()
|
||||
self.logger = structlog.get_logger(__name__)
|
||||
|
||||
@async_cache(ttl=300) # Cache for 5 minutes
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_current_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current traffic data for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
tenant_id: Optional tenant identifier for logging/analytics
|
||||
|
||||
Returns:
|
||||
Dict with current traffic data or None if not available
|
||||
"""
|
||||
try:
|
||||
self.logger.info("Getting current traffic data",
|
||||
lat=latitude, lon=longitude, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
traffic_data = await self.universal_client.get_current_traffic(latitude, longitude)
|
||||
|
||||
if traffic_data:
|
||||
# Add service metadata
|
||||
traffic_data['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude}
|
||||
}
|
||||
|
||||
self.logger.info("Successfully retrieved current traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
source=traffic_data.get('source', 'unknown'))
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
self.logger.warning("No current traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting current traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return None
|
||||
|
||||
@async_cache(ttl=1800) # Cache for 30 minutes (historical data changes less frequently)
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_historical_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None,
|
||||
db: Optional[AsyncSession] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get historical traffic data for any supported location with database storage
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
start_date: Start date for historical data
|
||||
end_date: End date for historical data
|
||||
tenant_id: Optional tenant identifier
|
||||
db: Optional database session for storage
|
||||
|
||||
Returns:
|
||||
List of historical traffic data dictionaries
|
||||
"""
|
||||
try:
|
||||
self.logger.info("Getting historical traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
start=start_date, end=end_date, tenant_id=tenant_id)
|
||||
|
||||
# Validate date range
|
||||
if start_date >= end_date:
|
||||
self.logger.warning("Invalid date range", start=start_date, end=end_date)
|
||||
return []
|
||||
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
# Check database first if session provided
|
||||
if db:
|
||||
stmt = select(TrafficData).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date >= start_date,
|
||||
TrafficData.date <= end_date
|
||||
)
|
||||
).order_by(TrafficData.date)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
db_records = result.scalars().all()
|
||||
|
||||
if db_records:
|
||||
self.logger.info("Historical traffic data found in database",
|
||||
count=len(db_records))
|
||||
return [self._convert_db_record_to_dict(record) for record in db_records]
|
||||
|
||||
# Delegate to universal client
|
||||
traffic_data = await self.universal_client.get_historical_traffic(
|
||||
latitude, longitude, start_date, end_date
|
||||
)
|
||||
|
||||
if traffic_data:
|
||||
# Add service metadata to each record
|
||||
for record in traffic_data:
|
||||
record['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'date_range': {
|
||||
'start': start_date.isoformat(),
|
||||
'end': end_date.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Store in database if session provided
|
||||
if db:
|
||||
stored_count = await self._store_traffic_data_batch(
|
||||
traffic_data, location_id, db
|
||||
)
|
||||
self.logger.info("Traffic data stored for re-training",
|
||||
fetched=len(traffic_data), stored=stored_count,
|
||||
location=location_id)
|
||||
|
||||
self.logger.info("Successfully retrieved historical traffic data",
|
||||
lat=latitude, lon=longitude, records=len(traffic_data))
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
self.logger.info("No historical traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting historical traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def _convert_db_record_to_dict(self, record: TrafficData) -> Dict[str, Any]:
|
||||
"""Convert database record to dictionary format"""
|
||||
return {
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume,
|
||||
'pedestrian_count': record.pedestrian_count,
|
||||
'congestion_level': record.congestion_level,
|
||||
'average_speed': record.average_speed,
|
||||
'source': record.source,
|
||||
'location_id': record.location_id,
|
||||
'raw_data': record.raw_data
|
||||
}
|
||||
|
||||
async def get_traffic_events(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 5.0,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get traffic events and incidents for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
radius_km: Search radius in kilometers
|
||||
tenant_id: Optional tenant identifier
|
||||
|
||||
Returns:
|
||||
List of traffic events
|
||||
"""
|
||||
try:
|
||||
self.logger.info("Getting traffic events",
|
||||
lat=latitude, lon=longitude, radius=radius_km, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
events = await self.universal_client.get_events(latitude, longitude, radius_km)
|
||||
|
||||
# Add metadata to events
|
||||
for event in events:
|
||||
event['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'search_radius_km': radius_km
|
||||
}
|
||||
|
||||
self.logger.info("Retrieved traffic events",
|
||||
lat=latitude, lon=longitude, events=len(events))
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting traffic events",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about traffic data availability for location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
|
||||
Returns:
|
||||
Dict with location support information
|
||||
"""
|
||||
try:
|
||||
info = self.universal_client.get_location_info(latitude, longitude)
|
||||
|
||||
# Add service layer information
|
||||
info['service_layer'] = {
|
||||
'version': '2.0',
|
||||
'abstraction_level': 'universal',
|
||||
'supported_operations': [
|
||||
'current_traffic',
|
||||
'historical_traffic',
|
||||
'traffic_events',
|
||||
'bulk_requests'
|
||||
]
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting location info",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return {
|
||||
'supported': False,
|
||||
'error': str(e),
|
||||
'service_layer': {'version': '2.0'}
|
||||
}
|
||||
|
||||
async def store_traffic_data(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
traffic_data: Dict[str, Any],
|
||||
db: AsyncSession) -> bool:
|
||||
"""Store single traffic data record to database"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
traffic_record = TrafficData(
|
||||
location_id=location_id,
|
||||
date=traffic_data.get("date", datetime.now()),
|
||||
traffic_volume=traffic_data.get("traffic_volume"),
|
||||
pedestrian_count=traffic_data.get("pedestrian_count"),
|
||||
congestion_level=traffic_data.get("congestion_level"),
|
||||
average_speed=traffic_data.get("average_speed"),
|
||||
source=traffic_data.get("source", "madrid_opendata"),
|
||||
raw_data=str(traffic_data) if traffic_data else None
|
||||
)
|
||||
|
||||
db.add(traffic_record)
|
||||
await db.commit()
|
||||
|
||||
logger.debug("Traffic data stored successfully", location_id=location_id)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to store traffic data", error=str(e))
|
||||
await db.rollback()
|
||||
return False
|
||||
|
||||
async def _store_traffic_data_batch(self,
|
||||
traffic_data: List[Dict[str, Any]],
|
||||
location_id: str,
|
||||
db: AsyncSession) -> int:
|
||||
"""Store batch of traffic data with enhanced validation and duplicate handling"""
|
||||
stored_count = 0
|
||||
|
||||
try:
|
||||
# Check for existing records to avoid duplicates
|
||||
if traffic_data:
|
||||
dates = [data.get('date') for data in traffic_data if data.get('date')]
|
||||
if dates:
|
||||
# Query existing records for this location and date range
|
||||
existing_stmt = select(TrafficData.date).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date.in_(dates)
|
||||
)
|
||||
)
|
||||
result = await db.execute(existing_stmt)
|
||||
existing_dates = {row[0] for row in result.fetchall()}
|
||||
|
||||
logger.debug(f"Found {len(existing_dates)} existing records for location {location_id}")
|
||||
else:
|
||||
existing_dates = set()
|
||||
else:
|
||||
existing_dates = set()
|
||||
|
||||
# Prepare batch of new records for bulk insert
|
||||
batch_records = []
|
||||
for data in traffic_data:
|
||||
try:
|
||||
record_date = data.get('date')
|
||||
if not record_date or record_date in existing_dates:
|
||||
continue # Skip duplicates
|
||||
|
||||
# Validate required fields
|
||||
if not self._validate_traffic_data(data):
|
||||
logger.warning("Invalid traffic data, skipping", data=data)
|
||||
continue
|
||||
|
||||
# Prepare record data for bulk insert
|
||||
record_data = {
|
||||
'location_id': location_id,
|
||||
'date': record_date,
|
||||
'traffic_volume': data.get('traffic_volume'),
|
||||
'pedestrian_count': data.get('pedestrian_count'),
|
||||
'congestion_level': data.get('congestion_level'),
|
||||
'average_speed': data.get('average_speed'),
|
||||
'source': data.get('source', 'madrid_opendata'),
|
||||
'raw_data': str(data)
|
||||
}
|
||||
batch_records.append(record_data)
|
||||
|
||||
except Exception as record_error:
|
||||
logger.warning("Failed to prepare traffic record",
|
||||
error=str(record_error), data=data)
|
||||
continue
|
||||
|
||||
# Use efficient bulk insert instead of individual records
|
||||
if batch_records:
|
||||
# Process in chunks to avoid memory issues
|
||||
chunk_size = 5000
|
||||
for i in range(0, len(batch_records), chunk_size):
|
||||
chunk = batch_records[i:i + chunk_size]
|
||||
|
||||
# Use SQLAlchemy bulk insert for maximum performance
|
||||
await db.execute(
|
||||
TrafficData.__table__.insert(),
|
||||
chunk
|
||||
)
|
||||
await db.commit()
|
||||
stored_count += len(chunk)
|
||||
|
||||
logger.debug(f"Bulk inserted {len(chunk)} records (total: {stored_count})")
|
||||
|
||||
logger.info(f"Successfully stored {stored_count} traffic records for location {location_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to store traffic data batch",
|
||||
error=str(e), location_id=location_id)
|
||||
await db.rollback()
|
||||
|
||||
return stored_count
|
||||
|
||||
def _validate_traffic_data(self, data: Dict[str, Any]) -> bool:
|
||||
"""Validate traffic data before storage"""
|
||||
required_fields = ['date']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if not data.get(field):
|
||||
return False
|
||||
|
||||
# Validate data types and ranges
|
||||
traffic_volume = data.get('traffic_volume')
|
||||
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 10000):
|
||||
return False
|
||||
|
||||
pedestrian_count = data.get('pedestrian_count')
|
||||
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
|
||||
return False
|
||||
|
||||
average_speed = data.get('average_speed')
|
||||
if average_speed is not None and (average_speed < 0 or average_speed > 200):
|
||||
return False
|
||||
|
||||
congestion_level = data.get('congestion_level')
|
||||
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def get_stored_traffic_for_training(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
db: AsyncSession) -> List[Dict[str, Any]]:
|
||||
"""Retrieve stored traffic data specifically for training purposes"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
stmt = select(TrafficData).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date >= start_date,
|
||||
TrafficData.date <= end_date
|
||||
)
|
||||
).order_by(TrafficData.date)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
records = result.scalars().all()
|
||||
|
||||
# Convert to training format
|
||||
training_data = []
|
||||
for record in records:
|
||||
training_data.append({
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume,
|
||||
'pedestrian_count': record.pedestrian_count,
|
||||
'congestion_level': record.congestion_level,
|
||||
'average_speed': record.average_speed,
|
||||
'location_id': record.location_id,
|
||||
'source': record.source,
|
||||
'measurement_point_id': record.raw_data  # raw payload string; kept because it may embed measurement-point metadata
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(training_data)} traffic records for training",
|
||||
location_id=location_id, start=start_date, end=end_date)
|
||||
|
||||
return training_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to retrieve traffic data for training",
|
||||
error=str(e), location_id=location_id)
|
||||
return []
|
||||
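# ----------------------------------------------------------------------------
# Illustrative sketch only: the @async_cache / @monitor_performance decorators
# used above come from app.core.performance, which is not shown in this diff.
# The helper below is a guess at the core idea of async_cache -- memoise an
# async call by its arguments and expire entries after `ttl` seconds -- and is
# not the project's actual implementation.
import functools
import time


def _sketch_async_cache(ttl: float):
    def decorator(func):
        cache: dict = {}  # maps call arguments -> (expires_at, cached value)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            key = (args, tuple(sorted(kwargs.items())))  # assumes hashable args
            hit = cache.get(key)
            if hit is not None and hit[0] > time.monotonic():
                return hit[1]  # entry still fresh: return the cached value
            value = await func(*args, **kwargs)
            cache[key] = (time.monotonic() + ttl, value)
            return value

        return wrapper

    return decorator
# ----------------------------------------------------------------------------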
@@ -1,117 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/alembic.ini
|
||||
# ================================================================
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version number format. This value is passed to the Python
|
||||
# datetime.datetime.strftime() method for formatting the creation date.
|
||||
# For UTC time zone add 'utc' prefix (ex: utc%Y_%m_%d_%H%M )
|
||||
version_num_format = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses
|
||||
# os.pathsep. If this key is omitted entirely, it falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
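# Typical commands run against this configuration (for reference):
#   alembic revision --autogenerate -m "describe the change"
#   alembic upgrade head
#   alembic downgrade -1
#   alembic history --verbose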
@@ -1,68 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/migrations/env.py
|
||||
# ================================================================
|
||||
"""Alembic environment configuration"""
|
||||
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
from alembic import context
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import Base
|
||||
from app.models import sales, weather, traffic
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Set database URL
|
||||
# Keep the async driver in the URL: run_async_migrations() below builds the
# engine with async_engine_from_config, which requires an asyncio-capable dialect
config.set_main_option("sqlalchemy.url", settings.DATABASE_URL)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""Run migrations in 'online' mode with async engine."""
|
||||
connectable = async_engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
|
||||
await connectable.dispose()
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
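# For reference: the offline path above is exercised when Alembic is invoked
# with --sql (e.g. "alembic upgrade head --sql"), which renders the migration
# as SQL statements instead of executing them against the database.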
@@ -1,29 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/migrations/script.py.mako
|
||||
# ================================================================
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Create traffic_data table for storing traffic data for re-training
|
||||
|
||||
Revision ID: 001_traffic_data
|
||||
Revises:
|
||||
Create Date: 2025-01-08 12:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '001_traffic_data'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Create traffic_data table"""
|
||||
op.create_table('traffic_data',
|
||||
sa.Column('id', UUID(as_uuid=True), nullable=False, primary_key=True),
|
||||
sa.Column('location_id', sa.String(100), nullable=False, index=True),
|
||||
sa.Column('date', sa.DateTime(timezone=True), nullable=False, index=True),
|
||||
sa.Column('traffic_volume', sa.Integer, nullable=True),
|
||||
sa.Column('pedestrian_count', sa.Integer, nullable=True),
|
||||
sa.Column('congestion_level', sa.String(20), nullable=True),
|
||||
sa.Column('average_speed', sa.Float, nullable=True),
|
||||
sa.Column('source', sa.String(50), nullable=False, server_default='madrid_opendata'),
|
||||
sa.Column('raw_data', sa.Text, nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
)
|
||||
|
||||
# Create index for efficient querying by location and date
|
||||
op.create_index(
|
||||
'idx_traffic_location_date',
|
||||
'traffic_data',
|
||||
['location_id', 'date']
|
||||
)
|
||||
|
||||
# Create index for date range queries
|
||||
op.create_index(
|
||||
'idx_traffic_date_range',
|
||||
'traffic_data',
|
||||
['date']
|
||||
)
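# For reference, the ('location_id', 'date') index created above backs the
# service-layer lookups that filter by location and date window, e.g.
# (illustrative):
#   select(TrafficData).where(
#       TrafficData.location_id == "40.4168,-3.7038",
#       TrafficData.date >= start_date,
#       TrafficData.date <= end_date,
#   ).order_by(TrafficData.date)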
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Drop traffic_data table"""
|
||||
op.drop_index('idx_traffic_date_range', table_name='traffic_data')
|
||||
op.drop_index('idx_traffic_location_date', table_name='traffic_data')
|
||||
op.drop_table('traffic_data')
|
||||
@@ -1,49 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/migrations/versions/20250727_add_timezone_to_datetime_columns.py
|
||||
# ================================================================
|
||||
"""Add timezone support to datetime columns
|
||||
|
||||
Revision ID: 20250727_193000
|
||||
Revises:
|
||||
Create Date: 2025-07-27 19:30:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '20250727_193000'
|
||||
down_revision = None # Replace with actual previous revision if exists
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Convert TIMESTAMP WITHOUT TIME ZONE to TIMESTAMP WITH TIME ZONE"""
|
||||
|
||||
# Weather data table
|
||||
op.execute("ALTER TABLE weather_data ALTER COLUMN date TYPE TIMESTAMP WITH TIME ZONE USING date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE weather_data ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE USING created_at AT TIME ZONE 'UTC'")
|
||||
|
||||
# Weather forecasts table
|
||||
op.execute("ALTER TABLE weather_forecasts ALTER COLUMN forecast_date TYPE TIMESTAMP WITH TIME ZONE USING forecast_date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE weather_forecasts ALTER COLUMN generated_at TYPE TIMESTAMP WITH TIME ZONE USING generated_at AT TIME ZONE 'UTC'")
|
||||
|
||||
# Traffic data table
|
||||
op.execute("ALTER TABLE traffic_data ALTER COLUMN date TYPE TIMESTAMP WITH TIME ZONE USING date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE traffic_data ALTER COLUMN created_at TYPE TIMESTAMP WITH TIME ZONE USING created_at AT TIME ZONE 'UTC'")
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Convert TIMESTAMP WITH TIME ZONE back to TIMESTAMP WITHOUT TIME ZONE"""
|
||||
|
||||
# Weather data table
|
||||
op.execute("ALTER TABLE weather_data ALTER COLUMN date TYPE TIMESTAMP WITHOUT TIME ZONE USING date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE weather_data ALTER COLUMN created_at TYPE TIMESTAMP WITHOUT TIME ZONE USING created_at AT TIME ZONE 'UTC'")
|
||||
|
||||
# Weather forecasts table
|
||||
op.execute("ALTER TABLE weather_forecasts ALTER COLUMN forecast_date TYPE TIMESTAMP WITHOUT TIME ZONE USING forecast_date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE weather_forecasts ALTER COLUMN generated_at TYPE TIMESTAMP WITHOUT TIME ZONE USING generated_at AT TIME ZONE 'UTC'")
|
||||
|
||||
# Traffic data table
|
||||
op.execute("ALTER TABLE traffic_data ALTER COLUMN date TYPE TIMESTAMP WITHOUT TIME ZONE USING date AT TIME ZONE 'UTC'")
|
||||
op.execute("ALTER TABLE traffic_data ALTER COLUMN created_at TYPE TIMESTAMP WITHOUT TIME ZONE USING created_at AT TIME ZONE 'UTC'")
|
||||
@@ -1,52 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/requirements.txt - UPDATED
|
||||
# ================================================================
|
||||
|
||||
# FastAPI and web framework
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
|
||||
# Database
|
||||
sqlalchemy[asyncio]==2.0.23
|
||||
asyncpg==0.29.0
|
||||
alembic==1.12.1
|
||||
|
||||
# Data validation
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# Cache and messaging
|
||||
redis==5.0.1
|
||||
aio-pika==9.3.1
|
||||
|
||||
# HTTP client
|
||||
httpx==0.25.2
|
||||
|
||||
# Data processing
|
||||
pandas==2.1.3
|
||||
numpy==1.25.2
|
||||
openpyxl==3.1.2 # For Excel (.xlsx) files
|
||||
xlrd==2.0.1 # For Excel (.xls) files
|
||||
python-multipart==0.0.6
|
||||
|
||||
# Monitoring and logging
|
||||
prometheus-client==0.19.0
|
||||
structlog==23.2.0
|
||||
python-logstash==0.4.8
|
||||
python-json-logger==2.0.4
|
||||
|
||||
# Security
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib[bcrypt]==1.7.4
|
||||
|
||||
# Testing
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-cov==4.1.0
|
||||
pytest-mock==3.12.0
|
||||
pytest-xdist==3.5.0
|
||||
pytest-timeout==2.2.0
|
||||
psutil==5.9.8
|
||||
|
||||
# Cartographic projections and coordinate transformations library
|
||||
pyproj==3.4.0
|
||||
@@ -1,653 +0,0 @@
|
||||
# ================================================================
|
||||
# services/data/tests/conftest.py - AEMET Test Configuration
|
||||
# ================================================================
|
||||
"""
|
||||
Test configuration and fixtures for AEMET weather API client tests
|
||||
Provides shared fixtures, mock data, and test utilities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from typing import Dict, List, Any, Generator
|
||||
import os
|
||||
|
||||
# Import the classes we're testing
|
||||
from app.external.aemet import (
|
||||
AEMETClient,
|
||||
WeatherDataParser,
|
||||
SyntheticWeatherGenerator,
|
||||
LocationService,
|
||||
WeatherSource
|
||||
)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# PYTEST CONFIGURATION
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# CLIENT AND SERVICE FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def aemet_client():
|
||||
"""Create AEMET client instance for testing"""
|
||||
return AEMETClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def weather_parser():
|
||||
"""Create WeatherDataParser instance for testing"""
|
||||
return WeatherDataParser()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def synthetic_generator():
|
||||
"""Create SyntheticWeatherGenerator instance for testing"""
|
||||
return SyntheticWeatherGenerator()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def location_service():
|
||||
"""Create LocationService instance for testing"""
|
||||
return LocationService()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# COORDINATE AND LOCATION FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def madrid_coords():
|
||||
"""Standard Madrid coordinates for testing"""
|
||||
return (40.4168, -3.7038) # Madrid city center
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def madrid_coords_variants():
|
||||
"""Various Madrid area coordinates for testing"""
|
||||
return {
|
||||
"center": (40.4168, -3.7038), # Madrid center
|
||||
"north": (40.4677, -3.5552), # Madrid north (near station)
|
||||
"south": (40.2987, -3.7216), # Madrid south (near station)
|
||||
"east": (40.4200, -3.6500), # Madrid east
|
||||
"west": (40.4100, -3.7500), # Madrid west
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_coords():
|
||||
"""Invalid coordinates for error testing"""
|
||||
return [
|
||||
(200, 200), # Out of range
|
||||
(-200, -200), # Out of range
|
||||
(0, 0), # Not in Madrid area
|
||||
(50, 10), # Europe but not Madrid
|
||||
(None, None), # None values
|
||||
]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# DATE AND TIME FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def test_dates():
|
||||
"""Standard date ranges for testing"""
|
||||
now = datetime.now()
|
||||
return {
|
||||
"now": now,
|
||||
"yesterday": now - timedelta(days=1),
|
||||
"last_week": now - timedelta(days=7),
|
||||
"last_month": now - timedelta(days=30),
|
||||
"last_quarter": now - timedelta(days=90),
|
||||
"one_year_ago": now - timedelta(days=365),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def historical_date_ranges():
|
||||
"""Historical date ranges for testing"""
|
||||
end_date = datetime.now()
|
||||
return {
|
||||
"one_day": {
|
||||
"start": end_date - timedelta(days=1),
|
||||
"end": end_date,
|
||||
"expected_days": 1
|
||||
},
|
||||
"one_week": {
|
||||
"start": end_date - timedelta(days=7),
|
||||
"end": end_date,
|
||||
"expected_days": 7
|
||||
},
|
||||
"one_month": {
|
||||
"start": end_date - timedelta(days=30),
|
||||
"end": end_date,
|
||||
"expected_days": 30
|
||||
},
|
||||
"large_range": {
|
||||
"start": end_date - timedelta(days=65),
|
||||
"end": end_date,
|
||||
"expected_days": 65
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ================================================================
|
||||
# MOCK API RESPONSE FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_aemet_api_response():
|
||||
"""Mock AEMET API initial response structure"""
|
||||
return {
|
||||
"datos": "https://opendata.aemet.es/opendata/sh/12345abcdef",
|
||||
"metadatos": "https://opendata.aemet.es/opendata/sh/metadata123"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_aemet_error_response():
|
||||
"""Mock AEMET API error response"""
|
||||
return {
|
||||
"descripcion": "Error en la petición",
|
||||
"estado": 404
|
||||
}
|
||||
|
||||
|
||||
# ================================================================
|
||||
# WEATHER DATA FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_weather_data():
|
||||
"""Mock current weather data from AEMET API"""
|
||||
return {
|
||||
"idema": "3195", # Station ID
|
||||
"ubi": "MADRID", # Location
|
||||
"fint": "2025-07-24T14:00:00", # Observation time
|
||||
"ta": 18.5, # Temperature (°C)
|
||||
"tamin": 12.3, # Min temperature
|
||||
"tamax": 25.7, # Max temperature
|
||||
"hr": 65.0, # Humidity (%)
|
||||
"prec": 0.0, # Precipitation (mm)
|
||||
"vv": 12.0, # Wind speed (km/h)
|
||||
"dv": 180, # Wind direction (degrees)
|
||||
"pres": 1015.2, # Pressure (hPa)
|
||||
"presMax": 1018.5, # Max pressure
|
||||
"presMin": 1012.1, # Min pressure
|
||||
"descripcion": "Despejado" # Description
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forecast_data():
|
||||
"""Mock forecast data from AEMET API"""
|
||||
return [{
|
||||
"origen": {
|
||||
"productor": "Agencia Estatal de Meteorología - AEMET"
|
||||
},
|
||||
"elaborado": "2025-07-24T12:00:00UTC",
|
||||
"nombre": "Madrid",
|
||||
"provincia": "Madrid",
|
||||
"prediccion": {
|
||||
"dia": [
|
||||
{
|
||||
"fecha": "2025-07-25T00:00:00",
|
||||
"temperatura": {
|
||||
"maxima": 28,
|
||||
"minima": 15,
|
||||
"dato": [
|
||||
{"value": 15, "hora": 6},
|
||||
{"value": 28, "hora": 15}
|
||||
]
|
||||
},
|
||||
"sensTermica": {
|
||||
"maxima": 30,
|
||||
"minima": 16
|
||||
},
|
||||
"humedadRelativa": {
|
||||
"maxima": 85,
|
||||
"minima": 45,
|
||||
"dato": [
|
||||
{"value": 85, "hora": 6},
|
||||
{"value": 45, "hora": 15}
|
||||
]
|
||||
},
|
||||
"probPrecipitacion": [
|
||||
{"value": 10, "periodo": "00-24"}
|
||||
],
|
||||
"viento": [
|
||||
{
|
||||
"direccion": ["N"],
|
||||
"velocidad": [15],
|
||||
"periodo": "00-24"
|
||||
}
|
||||
],
|
||||
"estadoCielo": [
|
||||
{
|
||||
"value": "11",
|
||||
"descripcion": "Despejado",
|
||||
"periodo": "00-24"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"fecha": "2025-07-26T00:00:00",
|
||||
"temperatura": {
|
||||
"maxima": 30,
|
||||
"minima": 17
|
||||
},
|
||||
"probPrecipitacion": [
|
||||
{"value": 5, "periodo": "00-24"}
|
||||
],
|
||||
"viento": [
|
||||
{
|
||||
"direccion": ["NE"],
|
||||
"velocidad": [10],
|
||||
"periodo": "00-24"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_historical_data():
|
||||
"""Mock historical weather data from AEMET API"""
|
||||
return [
|
||||
{
|
||||
"indicativo": "3195",
|
||||
"nombre": "MADRID",
|
||||
"fecha": "2025-07-20",
|
||||
"tmax": 25.2,
|
||||
"horatmax": "1530",
|
||||
"tmin": 14.8,
|
||||
"horatmin": "0630",
|
||||
"tmed": 20.0,
|
||||
"prec": 0.0,
|
||||
"racha": 25.0,
|
||||
"horaracha": "1445",
|
||||
"sol": 8.5,
|
||||
"presMax": 1018.5,
|
||||
"horaPresMax": "1000",
|
||||
"presMin": 1012.3,
|
||||
"horaPresMin": "1700",
|
||||
"hr": 58,
|
||||
"velmedia": 8.5,
|
||||
"dir": "180"
|
||||
},
|
||||
{
|
||||
"indicativo": "3195",
|
||||
"nombre": "MADRID",
|
||||
"fecha": "2025-07-21",
|
||||
"tmax": 27.1,
|
||||
"horatmax": "1615",
|
||||
"tmin": 16.2,
|
||||
"horatmin": "0700",
|
||||
"tmed": 21.6,
|
||||
"prec": 2.5,
|
||||
"racha": 30.0,
|
||||
"horaracha": "1330",
|
||||
"sol": 6.2,
|
||||
"presMax": 1015.8,
|
||||
"horaPresMax": "0930",
|
||||
"presMin": 1010.1,
|
||||
"horaPresMin": "1800",
|
||||
"hr": 72,
|
||||
"velmedia": 12.0,
|
||||
"dir": "225"
|
||||
},
|
||||
{
|
||||
"indicativo": "3195",
|
||||
"nombre": "MADRID",
|
||||
"fecha": "2025-07-22",
|
||||
"tmax": 23.8,
|
||||
"horatmax": "1500",
|
||||
"tmin": 13.5,
|
||||
"horatmin": "0615",
|
||||
"tmed": 18.7,
|
||||
"prec": 0.2,
|
||||
"racha": 22.0,
|
||||
"horaracha": "1200",
|
||||
"sol": 7.8,
|
||||
"presMax": 1020.2,
|
||||
"horaPresMax": "1100",
|
||||
"presMin": 1014.7,
|
||||
"horaPresMin": "1900",
|
||||
"hr": 63,
|
||||
"velmedia": 9.2,
|
||||
"dir": "270"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# ================================================================
|
||||
# EXPECTED RESULT FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def expected_current_weather_structure():
|
||||
"""Expected structure for current weather results"""
|
||||
return {
|
||||
"required_fields": [
|
||||
"date", "temperature", "precipitation", "humidity",
|
||||
"wind_speed", "pressure", "description", "source"
|
||||
],
|
||||
"field_types": {
|
||||
"date": datetime,
|
||||
"temperature": (int, float),
|
||||
"precipitation": (int, float),
|
||||
"humidity": (int, float),
|
||||
"wind_speed": (int, float),
|
||||
"pressure": (int, float),
|
||||
"description": str,
|
||||
"source": str
|
||||
},
|
||||
"valid_ranges": {
|
||||
"temperature": (-30, 50),
|
||||
"precipitation": (0, 200),
|
||||
"humidity": (0, 100),
|
||||
"wind_speed": (0, 200),
|
||||
"pressure": (900, 1100)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_forecast_structure():
|
||||
"""Expected structure for forecast results"""
|
||||
return {
|
||||
"required_fields": [
|
||||
"forecast_date", "generated_at", "temperature", "precipitation",
|
||||
"humidity", "wind_speed", "description", "source"
|
||||
],
|
||||
"field_types": {
|
||||
"forecast_date": datetime,
|
||||
"generated_at": datetime,
|
||||
"temperature": (int, float),
|
||||
"precipitation": (int, float),
|
||||
"humidity": (int, float),
|
||||
"wind_speed": (int, float),
|
||||
"description": str,
|
||||
"source": str
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_historical_structure():
|
||||
"""Expected structure for historical weather results"""
|
||||
return {
|
||||
"required_fields": [
|
||||
"date", "temperature", "precipitation", "humidity",
|
||||
"wind_speed", "pressure", "description", "source"
|
||||
],
|
||||
"field_types": {
|
||||
"date": datetime,
|
||||
"temperature": (int, float, type(None)),
|
||||
"precipitation": (int, float),
|
||||
"humidity": (int, float, type(None)),
|
||||
"wind_speed": (int, float, type(None)),
|
||||
"pressure": (int, float, type(None)),
|
||||
"description": str,
|
||||
"source": str
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ================================================================
|
||||
# MOCK AND PATCH FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_successful_api_calls():
|
||||
"""Mock successful AEMET API calls"""
|
||||
def _mock_api_calls(client, response_data, fetch_data):
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get, \
|
||||
patch.object(client, '_fetch_from_url', new_callable=AsyncMock) as mock_fetch:
|
||||
|
||||
mock_get.return_value = response_data
|
||||
mock_fetch.return_value = fetch_data
|
||||
|
||||
return mock_get, mock_fetch
|
||||
|
||||
return _mock_api_calls
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_failed_api_calls():
|
||||
"""Mock failed AEMET API calls"""
|
||||
def _mock_failed_calls(client, error_type="network"):
|
||||
if error_type == "network":
|
||||
return patch.object(client, '_get', side_effect=Exception("Network error"))
|
||||
elif error_type == "timeout":
|
||||
return patch.object(client, '_get', side_effect=asyncio.TimeoutError("Request timeout"))
|
||||
elif error_type == "invalid_response":
|
||||
return patch.object(client, '_get', new_callable=AsyncMock, return_value=None)
|
||||
else:
|
||||
return patch.object(client, '_get', new_callable=AsyncMock, return_value={"error": "API error"})
|
||||
|
||||
return _mock_failed_calls
|
||||
|
||||
|
||||
# ================================================================
|
||||
# VALIDATION HELPER FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def weather_data_validator():
|
||||
"""Weather data validation helper functions"""
|
||||
|
||||
def validate_weather_record(record: Dict[str, Any], expected_structure: Dict[str, Any]) -> None:
|
||||
"""Validate a weather record against expected structure"""
|
||||
# Check required fields
|
||||
for field in expected_structure["required_fields"]:
|
||||
assert field in record, f"Missing required field: {field}"
|
||||
|
||||
# Check field types
|
||||
for field, expected_type in expected_structure["field_types"].items():
|
||||
if field in record and record[field] is not None:
|
||||
assert isinstance(record[field], expected_type), f"Field {field} has wrong type: {type(record[field])}"
|
||||
|
||||
# Check valid ranges where applicable
|
||||
if "valid_ranges" in expected_structure:
|
||||
for field, (min_val, max_val) in expected_structure["valid_ranges"].items():
|
||||
if field in record and record[field] is not None:
|
||||
value = record[field]
|
||||
assert min_val <= value <= max_val, f"Field {field} value {value} outside valid range [{min_val}, {max_val}]"
|
||||
|
||||
def validate_weather_list(records: List[Dict[str, Any]], expected_structure: Dict[str, Any]) -> None:
|
||||
"""Validate a list of weather records"""
|
||||
assert isinstance(records, list), "Records should be a list"
|
||||
|
||||
for i, record in enumerate(records):
|
||||
try:
|
||||
validate_weather_record(record, expected_structure)
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"Record {i} validation failed: {e}")
|
||||
|
||||
def validate_date_sequence(records: List[Dict[str, Any]], date_field: str = "date") -> None:
|
||||
"""Validate that dates in records are in chronological order"""
|
||||
dates = [r[date_field] for r in records if date_field in r and r[date_field] is not None]
|
||||
|
||||
if len(dates) > 1:
|
||||
assert dates == sorted(dates), "Dates should be in chronological order"
|
||||
|
||||
return {
|
||||
"validate_record": validate_weather_record,
|
||||
"validate_list": validate_weather_list,
|
||||
"validate_dates": validate_date_sequence
|
||||
}
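# Example (illustrative) of how a test can combine this fixture with the
# expected-structure fixtures; the record below is hand-built, not produced by
# the real parser:
#   def test_validator_accepts_well_formed_record(weather_data_validator,
#                                                 expected_current_weather_structure):
#       record = {"date": datetime.now(), "temperature": 21.5, "precipitation": 0.0,
#                 "humidity": 55.0, "wind_speed": 10.0, "pressure": 1013.0,
#                 "description": "Despejado", "source": "aemet"}
#       weather_data_validator["validate_record"](record, expected_current_weather_structure)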
|
||||
|
||||
|
||||
# ================================================================
|
||||
# PERFORMANCE TESTING FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def performance_tracker():
|
||||
"""Performance tracking utilities for tests"""
|
||||
|
||||
class PerformanceTracker:
|
||||
def __init__(self):
|
||||
self.start_time = None
|
||||
self.measurements = {}
|
||||
|
||||
def start(self, operation_name: str = "default"):
|
||||
self.start_time = datetime.now()
|
||||
self.operation_name = operation_name
|
||||
|
||||
def stop(self) -> float:
|
||||
if self.start_time:
|
||||
duration = (datetime.now() - self.start_time).total_seconds() * 1000
|
||||
self.measurements[self.operation_name] = duration
|
||||
return duration
|
||||
return 0.0
|
||||
|
||||
def assert_performance(self, max_duration_ms: float, operation_name: str = "default"):
|
||||
duration = self.measurements.get(operation_name, float('inf'))
|
||||
assert duration <= max_duration_ms, f"Operation {operation_name} took {duration:.0f}ms, expected <= {max_duration_ms}ms"
|
||||
|
||||
return PerformanceTracker()
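# Example (illustrative) of how a test can use this fixture; the 5000 ms
# threshold mirrors the current_weather_ms value in integration_test_config:
#   def test_current_weather_latency(performance_tracker):
#       performance_tracker.start("current_weather")
#       ...exercise the client under test...
#       performance_tracker.stop()
#       performance_tracker.assert_performance(5000, "current_weather")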
|
||||
|
||||
|
||||
# ================================================================
|
||||
# INTEGRATION TEST FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def integration_test_config():
|
||||
"""Configuration for integration tests"""
|
||||
return {
|
||||
"api_timeout_ms": 5000,
|
||||
"max_retries": 3,
|
||||
"test_api_key": os.getenv("AEMET_API_KEY_TEST", ""),
|
||||
"skip_real_api_tests": os.getenv("SKIP_REAL_API_TESTS", "false").lower() == "true",
|
||||
"madrid_test_coords": (40.4168, -3.7038),
|
||||
"performance_thresholds": {
|
||||
"current_weather_ms": 5000,
|
||||
"forecast_ms": 5000,
|
||||
"historical_ms": 10000
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ================================================================
|
||||
# TEST REPORTING FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def test_reporter():
|
||||
"""Test reporting utilities"""
|
||||
|
||||
class TestReporter:
|
||||
def __init__(self):
|
||||
self.results = []
|
||||
|
||||
def log_success(self, test_name: str, details: str = ""):
|
||||
message = f"✅ {test_name}"
|
||||
if details:
|
||||
message += f" - {details}"
|
||||
print(message)
|
||||
self.results.append({"test": test_name, "status": "PASS", "details": details})
|
||||
|
||||
def log_failure(self, test_name: str, error: str = ""):
|
||||
message = f"❌ {test_name}"
|
||||
if error:
|
||||
message += f" - {error}"
|
||||
print(message)
|
||||
self.results.append({"test": test_name, "status": "FAIL", "error": error})
|
||||
|
||||
def log_info(self, test_name: str, info: str = ""):
|
||||
message = f"ℹ️ {test_name}"
|
||||
if info:
|
||||
message += f" - {info}"
|
||||
print(message)
|
||||
self.results.append({"test": test_name, "status": "INFO", "info": info})
|
||||
|
||||
def summary(self):
|
||||
passed = len([r for r in self.results if r["status"] == "PASS"])
|
||||
failed = len([r for r in self.results if r["status"] == "FAIL"])
|
||||
print(f"\n📊 Test Summary: {passed} passed, {failed} failed")
|
||||
return passed, failed
|
||||
|
||||
return TestReporter()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# CLEANUP FIXTURES
|
||||
# ================================================================
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_after_test():
|
||||
"""Automatic cleanup after each test"""
|
||||
yield
|
||||
# Add any cleanup logic here
|
||||
# For example, clearing caches, resetting global state, etc.
|
||||
pass
|
||||
|
||||
|
||||
# ================================================================
# HELPER FUNCTIONS
# ================================================================

def assert_weather_data_structure(data: Dict[str, Any], data_type: str = "current"):
    """Assert that weather data has the correct structure"""
    if data_type == "current":
        required_fields = ["date", "temperature", "precipitation", "humidity", "wind_speed", "pressure", "description", "source"]
    elif data_type == "forecast":
        required_fields = ["forecast_date", "generated_at", "temperature", "precipitation", "humidity", "wind_speed", "description", "source"]
    elif data_type == "historical":
        required_fields = ["date", "temperature", "precipitation", "humidity", "wind_speed", "pressure", "description", "source"]
    else:
        raise ValueError(f"Unknown data type: {data_type}")

    for field in required_fields:
        assert field in data, f"Missing required field: {field}"

    # Validate source
    valid_sources = [WeatherSource.AEMET.value, WeatherSource.SYNTHETIC.value, WeatherSource.DEFAULT.value]
    assert data["source"] in valid_sources, f"Invalid source: {data['source']}"


def assert_forecast_list_structure(forecast_list: List[Dict[str, Any]], expected_days: int):
    """Assert that forecast list has correct structure"""
    assert isinstance(forecast_list, list), "Forecast should be a list"
    assert len(forecast_list) == expected_days, f"Expected {expected_days} forecast days, got {len(forecast_list)}"

    for i, day in enumerate(forecast_list):
        assert_weather_data_structure(day, "forecast")

    # Check date progression
    if len(forecast_list) > 1:
        for i in range(1, len(forecast_list)):
            prev_date = forecast_list[i-1]["forecast_date"]
            curr_date = forecast_list[i]["forecast_date"]
            date_diff = (curr_date - prev_date).days
            assert date_diff == 1, f"Forecast dates should be consecutive, got {date_diff} day difference"


def assert_historical_list_structure(historical_list: List[Dict[str, Any]]):
    """Assert that historical list has correct structure"""
    assert isinstance(historical_list, list), "Historical data should be a list"

    for i, record in enumerate(historical_list):
        assert_weather_data_structure(record, "historical")

    # Check date ordering
    dates = [r["date"] for r in historical_list if "date" in r]
    if len(dates) > 1:
        assert dates == sorted(dates), "Historical dates should be in chronological order"
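# --- Added for illustration (hypothetical example, not part of the original file) ----
# Quick demonstration of the structure helpers above against a hand-built record;
# assumes WeatherSource is imported here, as the helpers themselves already require.
from datetime import datetime

def _example_structure_check():
    sample = {
        "date": datetime.now(),
        "temperature": 21.5,
        "precipitation": 0.0,
        "humidity": 55.0,
        "wind_speed": 9.0,
        "pressure": 1013.0,
        "description": "Despejado",
        "source": WeatherSource.SYNTHETIC.value,
    }
    assert_weather_data_structure(sample, "current")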
@@ -1,44 +0,0 @@
[tool:pytest]
# pytest.ini - Configuration file for AEMET tests

# Minimum version requirements
minversion = 6.0

# Add options
addopts =
    -ra
    --strict-markers
    --strict-config
    --disable-warnings
    --tb=short
    -v

# Test discovery
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# Async support
asyncio_mode = auto

# Markers
markers =
    unit: Unit tests
    integration: Integration tests
    api: API tests
    performance: Performance tests
    slow: Slow tests
    asyncio: Async tests

# Logging
log_cli = true
log_cli_level = INFO
log_cli_format = %(asctime)s [%(levelname)8s] %(name)s: %(message)s
log_cli_date_format = %Y-%m-%d %H:%M:%S

# Filtering
filterwarnings =
    ignore::DeprecationWarning
    ignore::PendingDeprecationWarning
    ignore::PytestUnhandledCoroutineWarning
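# --- Added for illustration (hypothetical test module, not part of the repository) ---
# With --strict-markers above, only the declared markers may be applied, and with
# asyncio_mode = auto an async test needs no explicit asyncio marker:
import pytest

@pytest.mark.performance
async def test_marker_usage_example():
    assert True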
@@ -1,677 +0,0 @@
# ================================================================
# services/data/tests/test_aemet.py
# ================================================================
"""
Comprehensive test suite for AEMET weather API client
Following the same patterns as test_madrid_opendata.py
"""

import pytest
import asyncio
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, AsyncMock
import math
from typing import Dict, List, Any

from app.external.aemet import (
    AEMETClient,
    WeatherDataParser,
    SyntheticWeatherGenerator,
    LocationService,
    AEMETConstants,
    WeatherSource,
    WeatherStation,
    GeographicBounds
)

# Configure pytest-asyncio
pytestmark = pytest.mark.asyncio

class TestAEMETClient:
|
||||
"""Main test class for AEMET API client functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create AEMET client instance for testing"""
|
||||
return AEMETClient()
|
||||
|
||||
@pytest.fixture
|
||||
def madrid_coords(self):
|
||||
"""Standard Madrid coordinates for testing"""
|
||||
return (40.4168, -3.7038) # Madrid city center
|
||||
|
||||
@pytest.fixture
|
||||
def mock_aemet_response(self):
|
||||
"""Mock AEMET API response structure"""
|
||||
return {
|
||||
"datos": "https://opendata.aemet.es/opendata/sh/12345",
|
||||
"metadatos": "https://opendata.aemet.es/opendata/sh/metadata"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_weather_data(self):
|
||||
"""Mock current weather data from AEMET"""
|
||||
return {
|
||||
"ta": 18.5, # Temperature
|
||||
"prec": 0.0, # Precipitation
|
||||
"hr": 65.0, # Humidity
|
||||
"vv": 12.0, # Wind speed
|
||||
"pres": 1015.2, # Pressure
|
||||
"descripcion": "Despejado"
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forecast_data(self):
|
||||
"""Mock forecast data from AEMET"""
|
||||
return [{
|
||||
"prediccion": {
|
||||
"dia": [
|
||||
{
|
||||
"fecha": "2025-07-25T00:00:00",
|
||||
"temperatura": {
|
||||
"maxima": 28,
|
||||
"minima": 15
|
||||
},
|
||||
"probPrecipitacion": [
|
||||
{"value": 10, "periodo": "00-24"}
|
||||
],
|
||||
"viento": [
|
||||
{"velocidad": [15], "direccion": ["N"]}
|
||||
]
|
||||
},
|
||||
{
|
||||
"fecha": "2025-07-26T00:00:00",
|
||||
"temperatura": {
|
||||
"maxima": 30,
|
||||
"minima": 17
|
||||
},
|
||||
"probPrecipitacion": [
|
||||
{"value": 5, "periodo": "00-24"}
|
||||
],
|
||||
"viento": [
|
||||
{"velocidad": [10], "direccion": ["NE"]}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_historical_data(self):
|
||||
"""Mock historical weather data from AEMET"""
|
||||
return [
|
||||
{
|
||||
"fecha": "2025-07-20",
|
||||
"tmax": 25.2,
|
||||
"tmin": 14.8,
|
||||
"prec": 0.0,
|
||||
"hr": 58,
|
||||
"velmedia": 8.5,
|
||||
"presMax": 1018.5,
|
||||
"presMin": 1012.3
|
||||
},
|
||||
{
|
||||
"fecha": "2025-07-21",
|
||||
"tmax": 27.1,
|
||||
"tmin": 16.2,
|
||||
"prec": 2.5,
|
||||
"hr": 72,
|
||||
"velmedia": 12.0,
|
||||
"presMax": 1015.8,
|
||||
"presMin": 1010.1
|
||||
}
|
||||
]
|
||||
|
||||
# ================================================================
|
||||
# CURRENT WEATHER TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_get_current_weather_success(self, client, madrid_coords, mock_aemet_response, mock_weather_data):
|
||||
"""Test successful current weather retrieval"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get, \
|
||||
patch.object(client, '_fetch_from_url', new_callable=AsyncMock) as mock_fetch:
|
||||
|
||||
mock_get.return_value = mock_aemet_response
|
||||
mock_fetch.return_value = [mock_weather_data]
|
||||
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
|
||||
# Validate result structure
|
||||
assert result is not None, "Should return weather data"
|
||||
assert isinstance(result, dict), "Result should be a dictionary"
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['date', 'temperature', 'precipitation', 'humidity', 'wind_speed', 'pressure', 'description', 'source']
|
||||
for field in required_fields:
|
||||
assert field in result, f"Missing required field: {field}"
|
||||
|
||||
# Validate data types and ranges
|
||||
assert isinstance(result['temperature'], float), "Temperature should be float"
|
||||
assert -20 <= result['temperature'] <= 50, "Temperature should be reasonable"
|
||||
assert isinstance(result['precipitation'], float), "Precipitation should be float"
|
||||
assert result['precipitation'] >= 0, "Precipitation should be non-negative"
|
||||
assert 0 <= result['humidity'] <= 100, "Humidity should be percentage"
|
||||
assert result['wind_speed'] >= 0, "Wind speed should be non-negative"
|
||||
assert result['pressure'] > 900, "Pressure should be reasonable"
|
||||
assert result['source'] == WeatherSource.AEMET.value, "Source should be AEMET"
|
||||
|
||||
print(f"✅ Current weather test passed - Temp: {result['temperature']}°C, Source: {result['source']}")
|
||||
|
||||
async def test_get_current_weather_fallback_to_synthetic(self, client, madrid_coords):
|
||||
"""Test fallback to synthetic data when AEMET API fails"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = None # Simulate API failure
|
||||
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
|
||||
assert result is not None, "Should return synthetic data"
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should use synthetic source"
|
||||
assert isinstance(result['temperature'], float), "Temperature should be float"
|
||||
|
||||
print(f"✅ Synthetic fallback test passed - Source: {result['source']}")
|
||||
|
||||
async def test_get_current_weather_invalid_coordinates(self, client):
|
||||
"""Test current weather with invalid coordinates"""
|
||||
invalid_coords = [
|
||||
(200, 200), # Out of range
|
||||
(-200, -200), # Out of range
|
||||
(0, 0), # Not in Madrid area
|
||||
]
|
||||
|
||||
for lat, lon in invalid_coords:
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
|
||||
# Should still return data (synthetic)
|
||||
assert result is not None, f"Should handle invalid coords ({lat}, {lon})"
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should use synthetic for invalid coords"
|
||||
|
||||
print(f"✅ Invalid coordinates test passed")
|
||||
|
||||
# ================================================================
|
||||
# FORECAST TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_get_forecast_success(self, client, madrid_coords, mock_aemet_response, mock_forecast_data):
|
||||
"""Test successful weather forecast retrieval"""
|
||||
lat, lon = madrid_coords
|
||||
days = 7
|
||||
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get, \
|
||||
patch.object(client, '_fetch_from_url', new_callable=AsyncMock) as mock_fetch:
|
||||
|
||||
mock_get.return_value = mock_aemet_response
|
||||
mock_fetch.return_value = mock_forecast_data
|
||||
|
||||
result = await client.get_forecast(lat, lon, days)
|
||||
|
||||
# Validate result structure
|
||||
assert isinstance(result, list), "Result should be a list"
|
||||
assert len(result) == days, f"Should return {days} forecast days"
|
||||
|
||||
# Check first forecast day
|
||||
if result:
|
||||
forecast_day = result[0]
|
||||
|
||||
required_fields = ['forecast_date', 'generated_at', 'temperature', 'precipitation', 'humidity', 'wind_speed', 'description', 'source']
|
||||
for field in required_fields:
|
||||
assert field in forecast_day, f"Missing required field: {field}"
|
||||
|
||||
# Validate data types
|
||||
assert isinstance(forecast_day['forecast_date'], datetime), "Forecast date should be datetime"
|
||||
assert isinstance(forecast_day['temperature'], (int, float)), "Temperature should be numeric"
|
||||
assert isinstance(forecast_day['precipitation'], (int, float)), "Precipitation should be numeric"
|
||||
assert forecast_day['source'] in [WeatherSource.AEMET.value, WeatherSource.SYNTHETIC.value], "Valid source"
|
||||
|
||||
print(f"✅ Forecast test passed - {len(result)} days, Source: {forecast_day['source']}")
|
||||
|
||||
async def test_get_forecast_different_durations(self, client, madrid_coords):
|
||||
"""Test forecast for different time durations"""
|
||||
lat, lon = madrid_coords
|
||||
test_durations = [1, 3, 7, 14]
|
||||
|
||||
for days in test_durations:
|
||||
result = await client.get_forecast(lat, lon, days)
|
||||
|
||||
assert isinstance(result, list), f"Result should be list for {days} days"
|
||||
assert len(result) == days, f"Should return exactly {days} forecast days"
|
||||
|
||||
# Check date progression
|
||||
if len(result) > 1:
|
||||
for i in range(1, len(result)):
|
||||
date_diff = result[i]['forecast_date'] - result[i-1]['forecast_date']
|
||||
assert date_diff.days == 1, "Forecast dates should be consecutive days"
|
||||
|
||||
print(f"✅ Multiple duration forecast test passed")
|
||||
|
||||
async def test_get_forecast_fallback_to_synthetic(self, client, madrid_coords):
|
||||
"""Test forecast fallback to synthetic data"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
with patch.object(client.location_service, 'get_municipality_code') as mock_municipality:
|
||||
mock_municipality.return_value = None # No municipality found
|
||||
|
||||
result = await client.get_forecast(lat, lon, 7)
|
||||
|
||||
assert isinstance(result, list), "Should return synthetic forecast"
|
||||
assert len(result) == 7, "Should return 7 days"
|
||||
assert all(day['source'] == WeatherSource.SYNTHETIC.value for day in result), "All should be synthetic"
|
||||
|
||||
print(f"✅ Forecast synthetic fallback test passed")
|
||||
|
||||
# ================================================================
|
||||
# HISTORICAL WEATHER TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_get_historical_weather_success(self, client, madrid_coords, mock_aemet_response, mock_historical_data):
|
||||
"""Test successful historical weather retrieval"""
|
||||
lat, lon = madrid_coords
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=7)
|
||||
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get, \
|
||||
patch.object(client, '_fetch_from_url', new_callable=AsyncMock) as mock_fetch:
|
||||
|
||||
mock_get.return_value = mock_aemet_response
|
||||
mock_fetch.return_value = mock_historical_data
|
||||
|
||||
result = await client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
# Validate result structure
|
||||
assert isinstance(result, list), "Result should be a list"
|
||||
assert len(result) > 0, "Should return historical data"
|
||||
|
||||
# Check first historical record
|
||||
if result:
|
||||
record = result[0]
|
||||
|
||||
required_fields = ['date', 'temperature', 'precipitation', 'humidity', 'wind_speed', 'pressure', 'description', 'source']
|
||||
for field in required_fields:
|
||||
assert field in record, f"Missing required field: {field}"
|
||||
|
||||
# Validate data types and ranges
|
||||
assert isinstance(record['date'], datetime), "Date should be datetime"
|
||||
assert isinstance(record['temperature'], (int, float, type(None))), "Temperature should be numeric or None"
|
||||
if record['temperature']:
|
||||
assert -30 <= record['temperature'] <= 50, "Temperature should be reasonable"
|
||||
assert record['precipitation'] >= 0, "Precipitation should be non-negative"
|
||||
assert record['source'] == WeatherSource.AEMET.value, "Source should be AEMET"
|
||||
|
||||
print(f"✅ Historical weather test passed - {len(result)} records, Source: {record['source']}")
|
||||
|
||||
async def test_get_historical_weather_date_ranges(self, client, madrid_coords):
|
||||
"""Test historical weather with different date ranges"""
|
||||
lat, lon = madrid_coords
|
||||
end_date = datetime.now()
|
||||
|
||||
test_ranges = [
|
||||
1, # 1 day
|
||||
7, # 1 week
|
||||
30, # 1 month
|
||||
90, # 3 months
|
||||
]
|
||||
|
||||
for days in test_ranges:
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
result = await client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
assert isinstance(result, list), f"Result should be list for {days} days"
|
||||
# Note: Actual count may vary due to chunking and data availability
|
||||
assert len(result) >= 0, f"Should return non-negative count for {days} days"
|
||||
|
||||
if result:
|
||||
# Check date ordering
|
||||
dates = [r['date'] for r in result if 'date' in r]
|
||||
if len(dates) > 1:
|
||||
assert dates == sorted(dates), "Historical dates should be in chronological order"
|
||||
|
||||
print(f"✅ Historical date ranges test passed")
|
||||
|
||||
async def test_get_historical_weather_chunking(self, client, madrid_coords):
|
||||
"""Test historical weather data chunking for large date ranges"""
|
||||
lat, lon = madrid_coords
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=65) # More than 30 days to trigger chunking
|
||||
|
||||
with patch.object(client, '_fetch_historical_chunk', new_callable=AsyncMock) as mock_chunk:
|
||||
mock_chunk.return_value = [] # Empty chunks
|
||||
|
||||
result = await client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
# Should have called chunking at least twice (65 days > 30 day limit)
|
||||
assert mock_chunk.call_count >= 2, "Should chunk large date ranges"
|
||||
|
||||
print(f"✅ Historical chunking test passed - {mock_chunk.call_count} chunks")
|
||||
|
||||
# ================================================================
|
||||
# COMPONENT TESTS
|
||||
# ================================================================
|
||||
@pytest.mark.skip_asyncio
|
||||
def test_weather_data_parser(self):
|
||||
"""Test WeatherDataParser functionality"""
|
||||
parser = WeatherDataParser()
|
||||
|
||||
# Test safe_float
|
||||
assert parser.safe_float("15.5", 0.0) == 15.5
|
||||
assert parser.safe_float(None, 10.0) == 10.0
|
||||
assert parser.safe_float("invalid", 5.0) == 5.0
|
||||
assert parser.safe_float(20) == 20.0
|
||||
|
||||
# Test extract_temperature_value
|
||||
assert parser.extract_temperature_value(25.5) == 25.5
|
||||
assert parser.extract_temperature_value("20.0") == 20.0
|
||||
assert parser.extract_temperature_value({"valor": 18.5}) == 18.5
|
||||
assert parser.extract_temperature_value([{"valor": 22.0}]) == 22.0
|
||||
assert parser.extract_temperature_value(None) is None
|
||||
|
||||
# Test generate_weather_description
|
||||
assert "Lluvioso" in parser.generate_weather_description(20, 6.0, 60)
|
||||
assert "Nuboso con lluvia" in parser.generate_weather_description(20, 1.0, 60)
|
||||
assert "Nuboso" in parser.generate_weather_description(20, 0, 85)
|
||||
assert "Soleado y cálido" in parser.generate_weather_description(30, 0, 60)
|
||||
assert "Frío" in parser.generate_weather_description(2, 0, 60)
|
||||
|
||||
print(f"✅ WeatherDataParser tests passed")
|
||||
|
||||
@pytest.mark.skip_asyncio
|
||||
def test_synthetic_weather_generator(self):
|
||||
"""Test SyntheticWeatherGenerator functionality"""
|
||||
generator = SyntheticWeatherGenerator()
|
||||
|
||||
# Test current weather generation
|
||||
current = generator.generate_current_weather()
|
||||
|
||||
assert isinstance(current, dict), "Should return dictionary"
|
||||
assert 'temperature' in current, "Should have temperature"
|
||||
assert 'precipitation' in current, "Should have precipitation"
|
||||
assert current['source'] == WeatherSource.SYNTHETIC.value, "Should be synthetic source"
|
||||
assert isinstance(current['date'], datetime), "Should have datetime"
|
||||
|
||||
# Test forecast generation
|
||||
forecast = generator.generate_forecast_sync(5)
|
||||
|
||||
assert isinstance(forecast, list), "Should return list"
|
||||
assert len(forecast) == 5, "Should return requested days"
|
||||
assert all('forecast_date' in day for day in forecast), "All days should have forecast_date"
|
||||
assert all(day['source'] == WeatherSource.SYNTHETIC.value for day in forecast), "All should be synthetic"
|
||||
|
||||
# Test historical generation
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=7)
|
||||
historical = generator.generate_historical_data(start_date, end_date)
|
||||
|
||||
assert isinstance(historical, list), "Should return list"
|
||||
assert len(historical) == 8, "Should return 8 days (inclusive)"
|
||||
assert all('date' in day for day in historical), "All days should have date"
|
||||
assert all(day['source'] == WeatherSource.SYNTHETIC.value for day in historical), "All should be synthetic"
|
||||
|
||||
print(f"✅ SyntheticWeatherGenerator tests passed")
|
||||
|
||||
@pytest.mark.skip_asyncio
|
||||
def test_location_service(self):
|
||||
"""Test LocationService functionality"""
|
||||
# Test distance calculation
|
||||
madrid_center = (40.4168, -3.7038)
|
||||
madrid_north = (40.4677, -3.5552)
|
||||
|
||||
distance = LocationService.calculate_distance(
|
||||
madrid_center[0], madrid_center[1],
|
||||
madrid_north[0], madrid_north[1]
|
||||
)
|
||||
|
||||
assert isinstance(distance, float), "Distance should be float"
|
||||
assert 0 < distance < 50, "Distance should be reasonable for Madrid area"
|
||||
|
||||
# Test nearest station finding
|
||||
station_id = LocationService.find_nearest_station(madrid_center[0], madrid_center[1])
|
||||
|
||||
assert station_id is not None, "Should find a station"
|
||||
assert station_id in [station.id for station in AEMETConstants.MADRID_STATIONS], "Should be valid station"
|
||||
|
||||
# Test municipality code
|
||||
municipality = LocationService.get_municipality_code(madrid_center[0], madrid_center[1])
|
||||
assert municipality == AEMETConstants.MADRID_MUNICIPALITY_CODE, "Should return Madrid code"
|
||||
|
||||
# Test outside Madrid
|
||||
outside_madrid = LocationService.get_municipality_code(41.0, -4.0) # Outside bounds
|
||||
assert outside_madrid is None, "Should return None for outside Madrid"
|
||||
|
||||
print(f"✅ LocationService tests passed")
|
||||
|
||||
@pytest.mark.skip_asyncio
|
||||
def test_constants_and_enums(self):
|
||||
"""Test constants and enum definitions"""
|
||||
# Test WeatherSource enum
|
||||
assert WeatherSource.AEMET.value == "aemet"
|
||||
assert WeatherSource.SYNTHETIC.value == "synthetic"
|
||||
assert WeatherSource.DEFAULT.value == "default"
|
||||
|
||||
# Test GeographicBounds
|
||||
bounds = AEMETConstants.MADRID_BOUNDS
|
||||
assert bounds.contains(40.4168, -3.7038), "Should contain Madrid center"
|
||||
assert not bounds.contains(41.0, -4.0), "Should not contain coordinates outside Madrid"
|
||||
|
||||
# Test WeatherStation
|
||||
station = AEMETConstants.MADRID_STATIONS[0]
|
||||
assert isinstance(station, WeatherStation), "Should be WeatherStation instance"
|
||||
assert station.id is not None, "Station should have ID"
|
||||
assert station.name is not None, "Station should have name"
|
||||
|
||||
print(f"✅ Constants and enums tests passed")
|
||||
|
||||
# ================================================================
|
||||
# ERROR HANDLING TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_api_error_handling(self, client, madrid_coords):
|
||||
"""Test handling of various API errors"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Test network error
|
||||
with patch.object(client, '_get', side_effect=Exception("Network error")):
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback on network error"
|
||||
|
||||
# Test invalid API response
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get:
|
||||
mock_get.return_value = {"error": "Invalid API key"}
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback on API error"
|
||||
|
||||
# Test malformed data
|
||||
with patch.object(client, '_get', new_callable=AsyncMock) as mock_get, \
|
||||
patch.object(client, '_fetch_from_url', new_callable=AsyncMock) as mock_fetch:
|
||||
|
||||
mock_get.return_value = {"datos": "http://example.com"}
|
||||
mock_fetch.return_value = [{"invalid": "data"}] # Missing expected fields
|
||||
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
assert result is not None, "Should handle malformed data gracefully"
|
||||
|
||||
print(f"✅ API error handling tests passed")
|
||||
|
||||
async def test_timeout_handling(self, client, madrid_coords):
|
||||
"""Test timeout handling"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
with patch.object(client, '_get', side_effect=asyncio.TimeoutError("Request timeout")):
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback on timeout"
|
||||
|
||||
print(f"✅ Timeout handling test passed")
|
||||
|
||||
# ================================================================
|
||||
# PERFORMANCE TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_performance_current_weather(self, client, madrid_coords):
|
||||
"""Test current weather performance"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
start_time = datetime.now()
|
||||
result = await client.get_current_weather(lat, lon)
|
||||
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
assert result is not None, "Should return weather data"
|
||||
assert execution_time < 5000, "Should execute within 5 seconds"
|
||||
|
||||
print(f"✅ Current weather performance test passed - {execution_time:.0f}ms")
|
||||
|
||||
async def test_performance_forecast(self, client, madrid_coords):
|
||||
"""Test forecast performance"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
start_time = datetime.now()
|
||||
result = await client.get_forecast(lat, lon, 7)
|
||||
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
assert isinstance(result, list), "Should return forecast list"
|
||||
assert len(result) == 7, "Should return 7 days"
|
||||
assert execution_time < 5000, "Should execute within 5 seconds"
|
||||
|
||||
print(f"✅ Forecast performance test passed - {execution_time:.0f}ms")
|
||||
|
||||
async def test_performance_historical(self, client, madrid_coords):
|
||||
"""Test historical weather performance"""
|
||||
lat, lon = madrid_coords
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=7)
|
||||
|
||||
start_time = datetime.now()
|
||||
result = await client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
assert isinstance(result, list), "Should return historical list"
|
||||
assert execution_time < 10000, "Should execute within 10 seconds (allowing for API calls)"
|
||||
|
||||
print(f"✅ Historical performance test passed - {execution_time:.0f}ms")
|
||||
|
||||
# ================================================================
|
||||
# INTEGRATION TESTS
|
||||
# ================================================================
|
||||
|
||||
async def test_real_aemet_api_access(self, client, madrid_coords):
|
||||
"""Test actual AEMET API access (if API key is available)"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
try:
|
||||
# Test current weather
|
||||
current_result = await client.get_current_weather(lat, lon)
|
||||
assert current_result is not None, "Should get current weather"
|
||||
|
||||
if current_result['source'] == WeatherSource.AEMET.value:
|
||||
print(f"🎉 SUCCESS: Got real AEMET current weather data!")
|
||||
print(f" Temperature: {current_result['temperature']}°C")
|
||||
print(f" Description: {current_result['description']}")
|
||||
else:
|
||||
print(f"ℹ️ Got synthetic current weather (API key may not be configured)")
|
||||
|
||||
# Test forecast
|
||||
forecast_result = await client.get_forecast(lat, lon, 3)
|
||||
assert len(forecast_result) == 3, "Should get 3-day forecast"
|
||||
|
||||
if forecast_result[0]['source'] == WeatherSource.AEMET.value:
|
||||
print(f"🎉 SUCCESS: Got real AEMET forecast data!")
|
||||
print(f" Tomorrow: {forecast_result[1]['temperature']}°C - {forecast_result[1]['description']}")
|
||||
else:
|
||||
print(f"ℹ️ Got synthetic forecast (API key may not be configured)")
|
||||
|
||||
# Test historical (last week)
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=7)
|
||||
historical_result = await client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
assert isinstance(historical_result, list), "Should get historical data"
|
||||
|
||||
real_historical = [r for r in historical_result if r['source'] == WeatherSource.AEMET.value]
|
||||
if real_historical:
|
||||
print(f"🎉 SUCCESS: Got real AEMET historical data!")
|
||||
print(f" Records: {len(real_historical)} real + {len(historical_result) - len(real_historical)} synthetic")
|
||||
else:
|
||||
print(f"ℹ️ Got synthetic historical data (API limitations or key issues)")
|
||||
|
||||
print(f"✅ Real AEMET API integration test completed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ AEMET API integration test failed: {e}")
|
||||
# This is acceptable if API key is not configured
|
||||
|
||||
async def test_data_consistency(self, client, madrid_coords):
|
||||
"""Test data consistency across different methods"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get current weather
|
||||
current = await client.get_current_weather(lat, lon)
|
||||
|
||||
# Get today's forecast
|
||||
forecast = await client.get_forecast(lat, lon, 1)
|
||||
today_forecast = forecast[0] if forecast else None
|
||||
|
||||
if current and today_forecast:
|
||||
# Temperature should be somewhat consistent
|
||||
temp_diff = abs(current['temperature'] - today_forecast['temperature'])
|
||||
assert temp_diff < 15, "Current and forecast temperature should be reasonably consistent"
|
||||
|
||||
# Both should use same source type preference
|
||||
if current['source'] == WeatherSource.AEMET.value:
|
||||
assert today_forecast['source'] == WeatherSource.AEMET.value, "Should use consistent data sources"
|
||||
|
||||
print(f"✅ Data consistency test passed")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# STANDALONE TEST FUNCTIONS
|
||||
# ================================================================
|
||||
|
||||
async def run_manual_test():
|
||||
"""Manual test function that can be run directly"""
|
||||
print("="*60)
|
||||
print("AEMET WEATHER CLIENT TEST - JULY 2025")
|
||||
print("="*60)
|
||||
|
||||
client = AEMETClient()
|
||||
madrid_lat, madrid_lon = 40.4168, -3.7038 # Madrid center
|
||||
|
||||
print(f"\n=== Testing Madrid Weather ({madrid_lat}, {madrid_lon}) ===")
|
||||
|
||||
# Test current weather
|
||||
print(f"\n1. Testing Current Weather...")
|
||||
current = await client.get_current_weather(madrid_lat, madrid_lon)
|
||||
if current:
|
||||
print(f" Temperature: {current['temperature']}°C")
|
||||
print(f" Description: {current['description']}")
|
||||
print(f" Humidity: {current['humidity']}%")
|
||||
print(f" Wind: {current['wind_speed']} km/h")
|
||||
print(f" Source: {current['source']}")
|
||||
|
||||
# Test forecast
|
||||
print(f"\n2. Testing 7-Day Forecast...")
|
||||
forecast = await client.get_forecast(madrid_lat, madrid_lon, 7)
|
||||
if forecast:
|
||||
print(f" Forecast days: {len(forecast)}")
|
||||
print(f" Tomorrow: {forecast[1]['temperature']}°C - {forecast[1]['description']}")
|
||||
print(f" Source: {forecast[0]['source']}")
|
||||
|
||||
# Test historical
|
||||
print(f"\n3. Testing Historical Weather (last 7 days)...")
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=7)
|
||||
historical = await client.get_historical_weather(madrid_lat, madrid_lon, start_date, end_date)
|
||||
if historical:
|
||||
print(f" Historical records: {len(historical)}")
|
||||
if historical:
|
||||
real_count = len([r for r in historical if r['source'] == WeatherSource.AEMET.value])
|
||||
synthetic_count = len(historical) - real_count
|
||||
print(f" Real data: {real_count}, Synthetic: {synthetic_count}")
|
||||
|
||||
print(f"\n✅ Manual test completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# If run directly, execute manual test
|
||||
asyncio.run(run_manual_test())
|
||||
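# --- Added for illustration (not part of the original test file) ---------------------
# test_get_historical_weather_chunking above assumes that ranges longer than ~30 days
# are split into consecutive windows before hitting the API; a standalone sketch of
# that splitting logic (the real client implementation may differ):
from datetime import datetime, timedelta
from typing import Iterator, Tuple

def iter_date_chunks(start: datetime, end: datetime, max_days: int = 30) -> Iterator[Tuple[datetime, datetime]]:
    """Yield consecutive (chunk_start, chunk_end) windows covering [start, end]."""
    cursor = start
    while cursor < end:
        chunk_end = min(cursor + timedelta(days=max_days), end)
        yield cursor, chunk_end
        cursor = chunk_end
# A 65-day range therefore yields three chunks (30 + 30 + 5 days), matching the
# test's expectation of at least two chunked calls.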
@@ -1,594 +0,0 @@
# ================================================================
# services/data/tests/test_aemet_edge_cases.py
# ================================================================
"""
Edge cases and integration tests for AEMET weather API client
Covers boundary conditions, error scenarios, and complex integrations
"""

import pytest
import asyncio
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, AsyncMock
import json
from typing import Dict, List, Any

from app.external.aemet import (
    AEMETClient,
    WeatherDataParser,
    SyntheticWeatherGenerator,
    LocationService,
    AEMETConstants,
    WeatherSource
)

# Configure pytest-asyncio
pytestmark = pytest.mark.asyncio

class TestAEMETEdgeCases:
|
||||
"""Test edge cases and boundary conditions"""
|
||||
|
||||
async def test_extreme_coordinates(self, aemet_client):
|
||||
"""Test handling of extreme coordinate values"""
|
||||
extreme_coords = [
|
||||
(90, 180), # North pole, antimeridian
|
||||
(-90, -180), # South pole, antimeridian
|
||||
(0, 0), # Null island
|
||||
(40.5, -180), # Valid latitude, extreme longitude
|
||||
(90, -3.7), # Extreme latitude, Madrid longitude
|
||||
]
|
||||
|
||||
for lat, lon in extreme_coords:
|
||||
result = await aemet_client.get_current_weather(lat, lon)
|
||||
|
||||
assert result is not None, f"Should handle extreme coords ({lat}, {lon})"
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback to synthetic for extreme coords"
|
||||
assert isinstance(result['temperature'], (int, float)), "Should have valid temperature"
|
||||
|
||||
async def test_boundary_date_ranges(self, aemet_client, madrid_coords):
|
||||
"""Test boundary conditions for date ranges"""
|
||||
lat, lon = madrid_coords
|
||||
now = datetime.now()
|
||||
|
||||
# Test same start and end date
|
||||
result = await aemet_client.get_historical_weather(lat, lon, now, now)
|
||||
assert isinstance(result, list), "Should return list for same-day request"
|
||||
|
||||
# Test reverse date range (end before start)
|
||||
start_date = now
|
||||
end_date = now - timedelta(days=1)
|
||||
result = await aemet_client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
assert isinstance(result, list), "Should handle reverse date range gracefully"
|
||||
|
||||
# Test extremely large date range
|
||||
start_date = now - timedelta(days=1000)
|
||||
end_date = now
|
||||
result = await aemet_client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
assert isinstance(result, list), "Should handle very large date ranges"
|
||||
|
||||
async def test_forecast_edge_durations(self, aemet_client, madrid_coords):
|
||||
"""Test forecast with edge case durations"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
edge_durations = [0, 1, 30, 365, -1, 1000]
|
||||
|
||||
for days in edge_durations:
|
||||
try:
|
||||
result = await aemet_client.get_forecast(lat, lon, days)
|
||||
|
||||
if days <= 0:
|
||||
assert len(result) == 0 or result is None, f"Should handle non-positive days ({days})"
|
||||
elif days > 100:
|
||||
# Should handle gracefully, possibly with synthetic data
|
||||
assert isinstance(result, list), f"Should handle large day count ({days})"
|
||||
else:
|
||||
assert len(result) == days, f"Should return {days} forecast days"
|
||||
|
||||
except Exception as e:
|
||||
# Some edge cases might raise exceptions, which is acceptable
|
||||
print(f"ℹ️ Days={days} raised exception: {e}")
|
||||
|
||||
def test_parser_edge_cases(self, weather_parser):
|
||||
"""Test weather data parser with edge case inputs"""
|
||||
# Test with None values
|
||||
result = weather_parser.safe_float(None, 10.0)
|
||||
assert result == 10.0, "Should return default for None"
|
||||
|
||||
# Test with empty strings
|
||||
result = weather_parser.safe_float("", 5.0)
|
||||
assert result == 5.0, "Should return default for empty string"
|
||||
|
||||
# Test with extreme values
|
||||
result = weather_parser.safe_float("999999.99", 0.0)
|
||||
assert result == 999999.99, "Should handle large numbers"
|
||||
|
||||
result = weather_parser.safe_float("-999.99", 0.0)
|
||||
assert result == -999.99, "Should handle negative numbers"
|
||||
|
||||
# Test temperature extraction edge cases
|
||||
assert weather_parser.extract_temperature_value([]) is None, "Should handle empty list"
|
||||
assert weather_parser.extract_temperature_value({}) is None, "Should handle empty dict"
|
||||
assert weather_parser.extract_temperature_value("invalid") is None, "Should handle invalid string"
|
||||
|
||||
def test_synthetic_generator_edge_cases(self, synthetic_generator):
|
||||
"""Test synthetic weather generator edge cases"""
|
||||
# Test with extreme date ranges
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=1000)
|
||||
|
||||
result = synthetic_generator.generate_historical_data(start_date, end_date)
|
||||
assert isinstance(result, list), "Should handle large date ranges"
|
||||
assert len(result) == 1001, "Should generate correct number of days"
|
||||
|
||||
# Test forecast with zero days
|
||||
result = synthetic_generator.generate_forecast_sync(0)
|
||||
assert result == [], "Should return empty list for zero days"
|
||||
|
||||
# Test forecast with large number of days
|
||||
result = synthetic_generator.generate_forecast_sync(1000)
|
||||
assert len(result) == 1000, "Should handle large forecast ranges"
|
||||
|
||||
def test_location_service_edge_cases(self):
|
||||
"""Test location service edge cases"""
|
||||
# Test distance calculation with same points
|
||||
distance = LocationService.calculate_distance(40.4, -3.7, 40.4, -3.7)
|
||||
assert distance == 0.0, "Distance between same points should be zero"
|
||||
|
||||
# Test distance calculation with antipodal points
|
||||
distance = LocationService.calculate_distance(40.4, -3.7, -40.4, 176.3)
|
||||
assert distance > 15000, "Antipodal points should be far apart"
|
||||
|
||||
# Test station finding with no stations (if list were empty)
|
||||
with patch.object(AEMETConstants, 'MADRID_STATIONS', []):
|
||||
station = LocationService.find_nearest_station(40.4, -3.7)
|
||||
assert station is None, "Should return None when no stations available"
|
||||
|
||||
|
||||
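# --- Added for illustration (not part of the original test file) ---------------------
# The distance assertions above (zero for identical points, >15000 km for antipodal
# ones) are consistent with a haversine great-circle formula; a standalone sketch of
# such a calculation (the real LocationService implementation may differ):
import math

def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance in kilometres between two latitude/longitude points."""
    earth_radius_km = 6371.0
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * earth_radius_km * math.asin(math.sqrt(a))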
class TestAEMETDataIntegrity:
|
||||
"""Test data integrity and consistency"""
|
||||
|
||||
async def test_data_type_consistency(self, aemet_client, madrid_coords):
|
||||
"""Test that data types are consistent across calls"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get current weather multiple times
|
||||
results = []
|
||||
for _ in range(3):
|
||||
result = await aemet_client.get_current_weather(lat, lon)
|
||||
results.append(result)
|
||||
|
||||
# Check that field types are consistent
|
||||
if all(r is not None for r in results):
|
||||
for field in ['temperature', 'precipitation', 'humidity', 'wind_speed', 'pressure']:
|
||||
types = [type(r[field]) for r in results if field in r]
|
||||
if types:
|
||||
first_type = types[0]
|
||||
assert all(t == first_type for t in types), f"Inconsistent types for {field}: {types}"
|
||||
|
||||
async def test_temperature_consistency(self, aemet_client, madrid_coords):
|
||||
"""Test temperature consistency between different data sources"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get current weather and today's forecast
|
||||
current = await aemet_client.get_current_weather(lat, lon)
|
||||
forecast = await aemet_client.get_forecast(lat, lon, 1)
|
||||
|
||||
if current and forecast and len(forecast) > 0:
|
||||
current_temp = current['temperature']
|
||||
forecast_temp = forecast[0]['temperature']
|
||||
|
||||
# Temperatures should be reasonably close (within 15°C)
|
||||
temp_diff = abs(current_temp - forecast_temp)
|
||||
assert temp_diff < 15, f"Temperature difference too large: current={current_temp}°C, forecast={forecast_temp}°C"
|
||||
|
||||
async def test_source_consistency(self, aemet_client, madrid_coords):
|
||||
"""Test that data source is consistent within same time period"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get multiple current weather readings
|
||||
current1 = await aemet_client.get_current_weather(lat, lon)
|
||||
current2 = await aemet_client.get_current_weather(lat, lon)
|
||||
|
||||
if current1 and current2:
|
||||
# Should use same source type (both real or both synthetic)
|
||||
assert current1['source'] == current2['source'], "Should use consistent data source"
|
||||
|
||||
def test_historical_data_ordering(self, weather_parser, mock_historical_data):
|
||||
"""Test that historical data is properly ordered"""
|
||||
parsed_data = weather_parser.parse_historical_data(mock_historical_data)
|
||||
|
||||
if len(parsed_data) > 1:
|
||||
dates = [record['date'] for record in parsed_data]
|
||||
assert dates == sorted(dates), "Historical data should be chronologically ordered"
|
||||
|
||||
def test_forecast_date_progression(self, weather_parser, mock_forecast_data):
|
||||
"""Test that forecast dates progress correctly"""
|
||||
parsed_forecast = weather_parser.parse_forecast_data(mock_forecast_data, 7)
|
||||
|
||||
if len(parsed_forecast) > 1:
|
||||
for i in range(1, len(parsed_forecast)):
|
||||
prev_date = parsed_forecast[i-1]['forecast_date']
|
||||
curr_date = parsed_forecast[i]['forecast_date']
|
||||
diff = (curr_date - prev_date).days
|
||||
assert diff == 1, f"Forecast dates should be consecutive days, got {diff} day difference"
|
||||
|
||||
|
||||
class TestAEMETErrorRecovery:
|
||||
"""Test error recovery and resilience"""
|
||||
|
||||
async def test_network_interruption_recovery(self, aemet_client, madrid_coords):
|
||||
"""Test recovery from network interruptions"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Mock intermittent network failures
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_with_failures(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2: # Fail first two calls
|
||||
raise Exception("Network timeout")
|
||||
else:
|
||||
return {"datos": "http://example.com/data"}
|
||||
|
||||
with patch.object(aemet_client, '_get', side_effect=mock_get_with_failures):
|
||||
result = await aemet_client.get_current_weather(lat, lon)
|
||||
|
||||
# Should eventually succeed or fallback to synthetic
|
||||
assert result is not None, "Should recover from network failures"
|
||||
assert result['source'] in [WeatherSource.AEMET.value, WeatherSource.SYNTHETIC.value]
|
||||
|
||||
async def test_partial_data_recovery(self, aemet_client, madrid_coords, weather_parser):
|
||||
"""Test recovery from partial/corrupted data"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Mock corrupted historical data (some records missing fields)
|
||||
corrupted_data = [
|
||||
{"fecha": "2025-07-20", "tmax": 25.2}, # Missing tmin and other fields
|
||||
{"fecha": "2025-07-21"}, # Only has date
|
||||
{"tmax": 27.0, "tmin": 15.0}, # Missing date
|
||||
{"fecha": "2025-07-22", "tmax": 23.0, "tmin": 14.0, "prec": 0.0} # Complete record
|
||||
]
|
||||
|
||||
parsed_data = weather_parser.parse_historical_data(corrupted_data)
|
||||
|
||||
# Should only return valid records and handle corrupted ones gracefully
|
||||
assert isinstance(parsed_data, list), "Should return list even with corrupted data"
|
||||
valid_records = [r for r in parsed_data if 'date' in r and r['date'] is not None]
|
||||
assert len(valid_records) >= 1, "Should salvage at least some valid records"
|
||||
|
||||
async def test_malformed_json_recovery(self, aemet_client, madrid_coords):
|
||||
"""Test recovery from malformed JSON responses"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Mock malformed responses
|
||||
malformed_responses = [
|
||||
None,
|
||||
"",
|
||||
"invalid json",
|
||||
{"incomplete": "response"},
|
||||
{"datos": None},
|
||||
{"datos": ""},
|
||||
]
|
||||
|
||||
for response in malformed_responses:
|
||||
with patch.object(aemet_client, '_get', new_callable=AsyncMock, return_value=response):
|
||||
result = await aemet_client.get_current_weather(lat, lon)
|
||||
|
||||
assert result is not None, f"Should handle malformed response: {response}"
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback to synthetic"
|
||||
|
||||
async def test_api_rate_limiting_recovery(self, aemet_client, madrid_coords):
|
||||
"""Test recovery from API rate limiting"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Mock rate limiting responses
|
||||
rate_limit_response = {
|
||||
"descripcion": "Demasiadas peticiones",
|
||||
"estado": 429
|
||||
}
|
||||
|
||||
with patch.object(aemet_client, '_get', new_callable=AsyncMock, return_value=rate_limit_response):
|
||||
result = await aemet_client.get_current_weather(lat, lon)
|
||||
|
||||
assert result is not None, "Should handle rate limiting"
|
||||
assert result['source'] == WeatherSource.SYNTHETIC.value, "Should fallback to synthetic on rate limit"
|
||||
|
||||
|
||||
class TestAEMETPerformanceAndScaling:
|
||||
"""Test performance characteristics and scaling behavior"""
|
||||
|
||||
async def test_concurrent_requests_performance(self, aemet_client, madrid_coords):
|
||||
"""Test performance with concurrent requests"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = []
|
||||
for i in range(10):
|
||||
task = aemet_client.get_current_weather(lat, lon)
|
||||
tasks.append(task)
|
||||
|
||||
start_time = datetime.now()
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
execution_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
# Check that most requests succeeded
|
||||
successful_results = [r for r in results if isinstance(r, dict) and 'temperature' in r]
|
||||
assert len(successful_results) >= 8, "Most concurrent requests should succeed"
|
||||
|
||||
# Should complete in reasonable time (allowing for potential API rate limiting)
|
||||
assert execution_time < 15000, f"Concurrent requests took too long: {execution_time:.0f}ms"
|
||||
|
||||
print(f"✅ Concurrent requests test - {len(successful_results)}/10 succeeded in {execution_time:.0f}ms")
|
||||
|
||||
async def test_memory_usage_with_large_datasets(self, aemet_client, madrid_coords):
|
||||
"""Test memory usage with large historical datasets"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Request large historical dataset
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=90) # 3 months
|
||||
|
||||
import psutil
|
||||
import os
|
||||
|
||||
# Get initial memory usage
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
|
||||
result = await aemet_client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
# Get final memory usage
|
||||
final_memory = process.memory_info().rss / 1024 / 1024 # MB
|
||||
memory_increase = final_memory - initial_memory
|
||||
|
||||
assert isinstance(result, list), "Should return historical data"
|
||||
|
||||
# Memory increase should be reasonable (less than 100MB for 90 days)
|
||||
assert memory_increase < 100, f"Memory usage increased too much: {memory_increase:.1f}MB"
|
||||
|
||||
print(f"✅ Memory usage test - {len(result)} records, +{memory_increase:.1f}MB")
|
||||
|
||||
async def test_caching_behavior(self, aemet_client, madrid_coords):
|
||||
"""Test caching behavior and performance improvement"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# First request (cold)
|
||||
start_time = datetime.now()
|
||||
result1 = await aemet_client.get_current_weather(lat, lon)
|
||||
first_call_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
# Second request (potentially cached)
|
||||
start_time = datetime.now()
|
||||
result2 = await aemet_client.get_current_weather(lat, lon)
|
||||
second_call_time = (datetime.now() - start_time).total_seconds() * 1000
|
||||
|
||||
assert result1 is not None, "First call should succeed"
|
||||
assert result2 is not None, "Second call should succeed"
|
||||
|
||||
# Both should return valid data
|
||||
assert 'temperature' in result1, "First result should have temperature"
|
||||
assert 'temperature' in result2, "Second result should have temperature"
|
||||
|
||||
print(f"✅ Caching test - First call: {first_call_time:.0f}ms, Second call: {second_call_time:.0f}ms")
|
||||
|
||||
|
||||
class TestAEMETIntegrationScenarios:
|
||||
"""Test realistic integration scenarios"""
|
||||
|
||||
async def test_daily_weather_workflow(self, aemet_client, madrid_coords):
|
||||
"""Test a complete daily weather workflow"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Simulate a daily weather check workflow
|
||||
workflow_results = {}
|
||||
|
||||
# Step 1: Get current conditions
|
||||
current = await aemet_client.get_current_weather(lat, lon)
|
||||
workflow_results['current'] = current
|
||||
assert current is not None, "Should get current weather"
|
||||
|
||||
# Step 2: Get today's forecast
|
||||
forecast = await aemet_client.get_forecast(lat, lon, 1)
|
||||
workflow_results['forecast'] = forecast
|
||||
assert len(forecast) == 1, "Should get today's forecast"
|
||||
|
||||
# Step 3: Get week ahead forecast
|
||||
week_forecast = await aemet_client.get_forecast(lat, lon, 7)
|
||||
workflow_results['week_forecast'] = week_forecast
|
||||
assert len(week_forecast) == 7, "Should get 7-day forecast"
|
||||
|
||||
# Step 4: Get last week's actual weather for comparison
|
||||
end_date = datetime.now() - timedelta(days=1)
|
||||
start_date = end_date - timedelta(days=7)
|
||||
historical = await aemet_client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
workflow_results['historical'] = historical
|
||||
assert isinstance(historical, list), "Should get historical data"
|
||||
|
||||
# Validate workflow consistency
|
||||
all_sources = set()
|
||||
if current: all_sources.add(current['source'])
|
||||
if forecast: all_sources.add(forecast[0]['source'])
|
||||
if week_forecast: all_sources.add(week_forecast[0]['source'])
|
||||
if historical: all_sources.update([h['source'] for h in historical])
|
||||
|
||||
print(f"✅ Daily workflow test - Sources used: {', '.join(all_sources)}")
|
||||
|
||||
return workflow_results
|
||||
|
||||
async def test_weather_alerting_scenario(self, aemet_client, madrid_coords):
|
||||
"""Test weather alerting scenario"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get forecast for potential alerts
|
||||
forecast = await aemet_client.get_forecast(lat, lon, 3)
|
||||
|
||||
alerts = []
|
||||
for day in forecast:
|
||||
# Check for extreme temperatures
|
||||
if day['temperature'] > 35:
|
||||
alerts.append(f"High temperature alert: {day['temperature']}°C on {day['forecast_date'].date()}")
|
||||
elif day['temperature'] < -5:
|
||||
alerts.append(f"Low temperature alert: {day['temperature']}°C on {day['forecast_date'].date()}")
|
||||
|
||||
# Check for high precipitation
|
||||
if day['precipitation'] > 20:
|
||||
alerts.append(f"Heavy rain alert: {day['precipitation']}mm on {day['forecast_date'].date()}")
|
||||
|
||||
# Alerts should be properly formatted
|
||||
for alert in alerts:
|
||||
assert isinstance(alert, str), "Alert should be string"
|
||||
assert "alert" in alert.lower(), "Alert should contain 'alert'"
|
||||
|
||||
print(f"✅ Weather alerting test - {len(alerts)} alerts generated")
|
||||
|
||||
return alerts
|
||||
|
||||
async def test_historical_analysis_scenario(self, aemet_client, madrid_coords):
|
||||
"""Test historical weather analysis scenario"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get historical data for analysis
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=30)
|
||||
|
||||
historical = await aemet_client.get_historical_weather(lat, lon, start_date, end_date)
|
||||
|
||||
if historical:
|
||||
# Calculate statistics
|
||||
temperatures = [h['temperature'] for h in historical if h['temperature'] is not None]
|
||||
precipitations = [h['precipitation'] for h in historical if h['precipitation'] is not None]
|
||||
|
||||
if temperatures:
|
||||
avg_temp = sum(temperatures) / len(temperatures)
|
||||
max_temp = max(temperatures)
|
||||
min_temp = min(temperatures)
|
||||
|
||||
# Validate statistics
|
||||
assert min_temp <= avg_temp <= max_temp, "Temperature statistics should be logical"
|
||||
assert -20 <= min_temp <= 50, "Min temperature should be reasonable"
|
||||
assert -20 <= max_temp <= 50, "Max temperature should be reasonable"
|
||||
|
||||
if precipitations:
|
||||
total_precip = sum(precipitations)
|
||||
rainy_days = len([p for p in precipitations if p > 0.1])
|
||||
|
||||
# Validate precipitation statistics
|
||||
assert total_precip >= 0, "Total precipitation should be non-negative"
|
||||
assert 0 <= rainy_days <= len(precipitations), "Rainy days should be reasonable"
|
||||
|
||||
print(f"✅ Historical analysis test - {len(historical)} records analyzed")
|
||||
|
||||
return {
|
||||
'record_count': len(historical),
|
||||
'avg_temp': avg_temp if temperatures else None,
|
||||
'temp_range': (min_temp, max_temp) if temperatures else None,
|
||||
'total_precip': total_precip if precipitations else None,
|
||||
'rainy_days': rainy_days if precipitations else None
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
class TestAEMETRegressionTests:
|
||||
"""Regression tests for previously fixed issues"""
|
||||
|
||||
async def test_timezone_handling_regression(self, aemet_client, madrid_coords):
|
||||
"""Regression test for timezone handling issues"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
# Get current weather and forecast
|
||||
current = await aemet_client.get_current_weather(lat, lon)
|
||||
forecast = await aemet_client.get_forecast(lat, lon, 2)
|
||||
|
||||
if current:
|
||||
# Current weather date should be recent (within last hour)
|
||||
now = datetime.now()
|
||||
time_diff = abs((now - current['date']).total_seconds())
|
||||
assert time_diff < 3600, "Current weather timestamp should be recent"
|
||||
|
||||
if forecast:
|
||||
# Forecast dates should be in the future
|
||||
now = datetime.now().date()
|
||||
for day in forecast:
|
||||
forecast_date = day['forecast_date'].date()
|
||||
assert forecast_date >= now, f"Forecast date {forecast_date} should be today or future"
|
||||
|
||||
async def test_data_type_conversion_regression(self, weather_parser):
|
||||
"""Regression test for data type conversion issues"""
|
||||
# Test cases that previously caused issues
|
||||
test_cases = [
|
||||
("25.5", 25.5), # String to float
|
||||
(25, 25.0), # Int to float
|
||||
("", None), # Empty string
|
||||
("invalid", None), # Invalid string
|
||||
(None, None), # None input
|
||||
]
|
||||
|
||||
for input_val, expected in test_cases:
|
||||
result = weather_parser.safe_float(input_val, None)
|
||||
if expected is None:
|
||||
assert result is None, f"Expected None for input {input_val}, got {result}"
|
||||
else:
|
||||
assert result == expected, f"Expected {expected} for input {input_val}, got {result}"
|
||||
|
||||
def test_empty_data_handling_regression(self, weather_parser):
|
||||
"""Regression test for empty data handling"""
|
||||
# Empty lists and dictionaries should be handled gracefully
|
||||
empty_data_cases = [
|
||||
[],
|
||||
[{}],
|
||||
[{"invalid": "data"}],
|
||||
None,
|
||||
]
|
||||
|
||||
for empty_data in empty_data_cases:
|
||||
result = weather_parser.parse_historical_data(empty_data if empty_data is not None else [])
|
||||
assert isinstance(result, list), f"Should return list for empty data: {empty_data}"
|
||||
# May be empty or have some synthetic data, but should not crash
|
||||
|
||||
|
||||
# ================================================================
|
||||
# STANDALONE TEST RUNNER FOR EDGE CASES
|
||||
# ================================================================
|
||||
|
||||
async def run_edge_case_tests():
|
||||
"""Run edge case tests manually"""
|
||||
print("="*60)
|
||||
print("AEMET EDGE CASE TESTS")
|
||||
print("="*60)
|
||||
|
||||
client = AEMETClient()
|
||||
parser = WeatherDataParser()
|
||||
generator = SyntheticWeatherGenerator()
|
||||
|
||||
madrid_coords = (40.4168, -3.7038)
|
||||
|
||||
print(f"\n1. Testing extreme coordinates...")
|
||||
extreme_result = await client.get_current_weather(90, 180)
|
||||
print(f" Extreme coords result: {extreme_result['source']} source")
|
||||
|
||||
print(f"\n2. Testing parser edge cases...")
|
||||
parser_tests = [
|
||||
parser.safe_float(None, 10.0),
|
||||
parser.safe_float("invalid", 5.0),
|
||||
parser.extract_temperature_value([]),
|
||||
]
|
||||
print(f" Parser edge cases passed: {len(parser_tests)}")
|
||||
|
||||
print(f"\n3. Testing synthetic generator extremes...")
|
||||
large_forecast = generator.generate_forecast_sync(100)
|
||||
print(f" Generated {len(large_forecast)} forecast days")
|
||||
|
||||
print(f"\n4. Testing concurrent requests...")
|
||||
tasks = [client.get_current_weather(*madrid_coords) for _ in range(5)]
|
||||
concurrent_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
successful = len([r for r in concurrent_results if isinstance(r, dict)])
|
||||
print(f" Concurrent requests: {successful}/5 successful")
|
||||
|
||||
print(f"\n✅ Edge case tests completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_edge_case_tests())
|
||||
services/external/Dockerfile (vendored, new file, 34 lines)
@@ -0,0 +1,34 @@
# services/external/Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY services/external/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy shared modules first
COPY shared/ /app/shared/

# Copy application code
COPY services/external/app/ /app/app/

# Set Python path to include shared modules
ENV PYTHONPATH=/app

# Expose port
EXPOSE 8000

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/health', timeout=5)" || exit 1

# Run the application
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
services/external/app/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/__init__.py
services/external/app/api/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/api/__init__.py
@@ -1,6 +1,4 @@
|
||||
# ================================================================
|
||||
# services/data/app/api/traffic.py - FIXED VERSION
|
||||
# ================================================================
|
||||
# services/external/app/api/traffic.py
|
||||
"""Traffic data API endpoints with improved error handling"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||
@@ -12,10 +10,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.services.traffic_service import TrafficService
|
||||
from app.services.messaging import data_publisher, publish_traffic_updated
|
||||
from app.schemas.external import (
|
||||
from app.services.messaging import publish_traffic_updated
|
||||
from app.schemas.traffic import (
|
||||
TrafficDataResponse,
|
||||
HistoricalTrafficRequest
|
||||
HistoricalTrafficRequest,
|
||||
TrafficForecastRequest
|
||||
)
|
||||
|
||||
from shared.auth.decorators import (
|
||||
@@ -86,7 +85,7 @@ async def get_historical_traffic(
|
||||
raise HTTPException(status_code=400, detail="Date range cannot exceed 90 days")
|
||||
|
||||
historical_data = await traffic_service.get_historical_traffic(
|
||||
request.latitude, request.longitude, request.start_date, request.end_date, db
|
||||
request.latitude, request.longitude, request.start_date, request.end_date, str(tenant_id)
|
||||
)
|
||||
|
||||
# Publish event (with error handling)
|
||||
@@ -112,58 +111,74 @@ async def get_historical_traffic(
|
||||
logger.error("Unexpected error in historical traffic API", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/traffic/stored")
|
||||
async def get_stored_traffic_for_training(
|
||||
request: HistoricalTrafficRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@router.post("/tenants/{tenant_id}/traffic/forecast")
|
||||
async def get_traffic_forecast(
|
||||
request: TrafficForecastRequest,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
):
|
||||
"""Get stored traffic data specifically for training/re-training purposes"""
|
||||
"""Get traffic forecast for location"""
|
||||
try:
|
||||
# Validate date range
|
||||
if request.end_date <= request.start_date:
|
||||
raise HTTPException(status_code=400, detail="End date must be after start date")
|
||||
logger.debug("API: Getting traffic forecast",
|
||||
lat=request.latitude, lon=request.longitude, hours=request.hours)
|
||||
|
||||
# Allow longer date ranges for training (up to 3 years)
|
||||
if (request.end_date - request.start_date).days > 1095:
|
||||
raise HTTPException(status_code=400, detail="Date range cannot exceed 3 years for training data")
|
||||
# For now, return mock forecast data since we don't have a real traffic forecast service
|
||||
# In a real implementation, this would call a traffic forecasting service
|
||||
|
||||
logger.info("Retrieving stored traffic data for training",
|
||||
tenant_id=str(tenant_id),
|
||||
location=f"{request.latitude},{request.longitude}",
|
||||
date_range=f"{request.start_date} to {request.end_date}")
|
||||
# Generate mock forecast data for the requested hours
|
||||
forecast_data = []
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Use the dedicated method for training data retrieval
|
||||
stored_data = await traffic_service.get_stored_traffic_for_training(
|
||||
request.latitude, request.longitude, request.start_date, request.end_date, db
|
||||
)
|
||||
base_time = datetime.utcnow()
|
||||
for hour in range(request.hours):
|
||||
forecast_time = base_time + timedelta(hours=hour)
|
||||
|
||||
# Mock traffic pattern (higher during rush hours)
|
||||
hour_of_day = forecast_time.hour
|
||||
if 7 <= hour_of_day <= 9 or 17 <= hour_of_day <= 19: # Rush hours
|
||||
traffic_volume = 120
|
||||
pedestrian_count = 80
|
||||
congestion_level = "high"
|
||||
average_speed = 15
|
||||
elif 22 <= hour_of_day or hour_of_day <= 6: # Night hours
|
||||
traffic_volume = 20
|
||||
pedestrian_count = 10
|
||||
congestion_level = "low"
|
||||
average_speed = 50
|
||||
else: # Regular hours
|
||||
traffic_volume = 60
|
||||
pedestrian_count = 40
|
||||
congestion_level = "medium"
|
||||
average_speed = 35
|
||||
|
||||
# Use consistent TrafficDataResponse format
|
||||
forecast_data.append({
|
||||
"date": forecast_time.isoformat(),
|
||||
"traffic_volume": traffic_volume,
|
||||
"pedestrian_count": pedestrian_count,
|
||||
"congestion_level": congestion_level,
|
||||
"average_speed": average_speed,
|
||||
"source": "madrid_opendata_forecast"
|
||||
})
|
||||
|
||||
# Log retrieval for audit purposes
|
||||
logger.info("Stored traffic data retrieved for training",
|
||||
records_count=len(stored_data),
|
||||
tenant_id=str(tenant_id),
|
||||
purpose="model_training")
|
||||
|
||||
# Publish event for monitoring
|
||||
# Publish event (with error handling)
|
||||
try:
|
||||
await publish_traffic_updated({
|
||||
"type": "stored_data_retrieved_for_training",
|
||||
"type": "forecast_requested",
|
||||
"latitude": request.latitude,
|
||||
"longitude": request.longitude,
|
||||
"start_date": request.start_date.isoformat(),
|
||||
"end_date": request.end_date.isoformat(),
|
||||
"records_count": len(stored_data),
|
||||
"tenant_id": str(tenant_id),
|
||||
"hours": request.hours,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
except Exception as pub_error:
|
||||
logger.warning("Failed to publish stored traffic retrieval event", error=str(pub_error))
|
||||
logger.warning("Failed to publish traffic forecast event", error=str(pub_error))
|
||||
# Continue processing
|
||||
|
||||
return stored_data
|
||||
logger.debug("Successfully returning traffic forecast", records=len(forecast_data))
|
||||
return forecast_data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in stored traffic retrieval API", error=str(e))
|
||||
logger.error("Unexpected error in traffic forecast API", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
@@ -1,5 +1,7 @@
|
||||
# services/data/app/api/weather.py - UPDATED WITH UNIFIED AUTH
|
||||
"""Weather data API endpoints with unified authentication"""
|
||||
# services/external/app/api/weather.py
|
||||
"""
|
||||
Weather API Endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, BackgroundTasks, Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
@@ -7,18 +9,15 @@ from datetime import datetime, date
|
||||
import structlog
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.external import (
|
||||
from app.schemas.weather import (
|
||||
WeatherDataResponse,
|
||||
WeatherForecastResponse,
|
||||
WeatherForecastRequest
|
||||
WeatherForecastRequest,
|
||||
HistoricalWeatherRequest
|
||||
)
|
||||
from app.services.weather_service import WeatherService
|
||||
from app.services.messaging import publish_weather_updated
|
||||
|
||||
from app.schemas.external import (
|
||||
HistoricalWeatherRequest
|
||||
)
|
||||
|
||||
# Import unified authentication from shared library
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
@@ -73,6 +72,49 @@ async def get_current_weather(
|
||||
logger.error("Failed to get current weather", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/weather/historical")
|
||||
async def get_historical_weather(
|
||||
request: HistoricalWeatherRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
):
|
||||
"""Get historical weather data with date range in payload"""
|
||||
try:
|
||||
# Validate date range
|
||||
if request.end_date <= request.start_date:
|
||||
raise HTTPException(status_code=400, detail="End date must be after start date")
|
||||
|
||||
if (request.end_date - request.start_date).days > 1000:
|
||||
raise HTTPException(status_code=400, detail="Date range cannot exceed 90 days")
|
||||
|
||||
historical_data = await weather_service.get_historical_weather(
|
||||
request.latitude, request.longitude, request.start_date, request.end_date)
|
||||
|
||||
# Publish event (with error handling)
|
||||
try:
|
||||
await publish_weather_updated({
|
||||
"type": "historical_requested",
|
||||
"latitude": request.latitude,
|
||||
"longitude": request.longitude,
|
||||
"start_date": request.start_date.isoformat(),
|
||||
"end_date": request.end_date.isoformat(),
|
||||
"records_count": len(historical_data),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
except Exception as pub_error:
|
||||
logger.warning("Failed to publish historical weather event", error=str(pub_error))
|
||||
# Continue processing
|
||||
|
||||
return historical_data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in historical weather API", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/weather/forecast", response_model=List[WeatherForecastResponse])
|
||||
async def get_weather_forecast(
|
||||
request: WeatherForecastRequest,
|
||||
@@ -113,86 +155,3 @@ async def get_weather_forecast(
|
||||
except Exception as e:
|
||||
logger.error("Failed to get weather forecast", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/weather/historical")
|
||||
async def get_historical_weather(
|
||||
request: HistoricalWeatherRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
):
|
||||
"""Get historical weather data with date range in payload"""
|
||||
try:
|
||||
# Validate date range
|
||||
if request.end_date <= request.start_date:
|
||||
raise HTTPException(status_code=400, detail="End date must be after start date")
|
||||
|
||||
if (request.end_date - request.start_date).days > 1000:
|
||||
raise HTTPException(status_code=400, detail="Date range cannot exceed 90 days")
|
||||
|
||||
historical_data = await weather_service.get_historical_weather(
|
||||
request.latitude, request.longitude, request.start_date, request.end_date, db
|
||||
)
|
||||
|
||||
# Publish event (with error handling)
|
||||
try:
|
||||
await publish_weather_updated({
|
||||
"type": "historical_requested",
|
||||
"latitude": request.latitude,
|
||||
"longitude": request.longitude,
|
||||
"start_date": request.start_date.isoformat(),
|
||||
"end_date": request.end_date.isoformat(),
|
||||
"records_count": len(historical_data),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
except Exception as pub_error:
|
||||
logger.warning("Failed to publish historical weather event", error=str(pub_error))
|
||||
# Continue processing
|
||||
|
||||
return historical_data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in historical weather API", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}weather/sync")
|
||||
async def sync_weather_data(
|
||||
background_tasks: BackgroundTasks,
|
||||
force: bool = Query(False, description="Force sync even if recently synced"),
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
):
|
||||
"""Manually trigger weather data synchronization"""
|
||||
try:
|
||||
logger.info("Weather sync requested",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
force=force)
|
||||
|
||||
# Check if user has permission to sync (could be admin only)
|
||||
if current_user.get("role") not in ["admin", "manager"]:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Insufficient permissions to sync weather data"
|
||||
)
|
||||
|
||||
# Schedule background sync
|
||||
background_tasks.add_task(
|
||||
weather_service.sync_weather_data,
|
||||
tenant_id=tenant_id,
|
||||
force=force
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Weather sync initiated",
|
||||
"status": "processing",
|
||||
"initiated_by": current_user["user_id"]
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate weather sync", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
services/external/app/core/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/core/__init__.py
@@ -1,30 +1,26 @@
|
||||
# ================================================================
|
||||
# DATA SERVICE CONFIGURATION
|
||||
# services/data/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Data service configuration
|
||||
External data integration and management
|
||||
"""
|
||||
# services/external/app/core/config.py
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
from pydantic import Field
|
||||
|
||||
class DataSettings(BaseServiceSettings):
|
||||
"""Data service specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Data Service"
|
||||
SERVICE_NAME: str = "data-service"
|
||||
DESCRIPTION: str = "External data integration and management service"
|
||||
SERVICE_NAME: str = "external-service"
|
||||
VERSION: str = "1.0.0"
|
||||
APP_NAME: str = "Bakery External Data Service"
|
||||
DESCRIPTION: str = "External data collection service for weather and traffic data"
|
||||
|
||||
# Database Configuration
|
||||
DATABASE_URL: str = os.getenv("DATA_DATABASE_URL",
|
||||
"postgresql+asyncpg://data_user:data_pass123@data-db:5432/data_db")
|
||||
# API Configuration
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Redis Database (dedicated for external data cache)
|
||||
REDIS_DB: int = 3
|
||||
# Override database URL to use EXTERNAL_DATABASE_URL
|
||||
DATABASE_URL: str = Field(
|
||||
default="postgresql+asyncpg://external_user:external_pass123@external-db:5432/external_db",
|
||||
env="EXTERNAL_DATABASE_URL"
|
||||
)
|
||||
|
||||
# External API Configuration
|
||||
AEMET_API_KEY: str = os.getenv("AEMET_API_KEY", "")
|
||||
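For reference, a minimal sketch of how the EXTERNAL_DATABASE_URL override above is expected to behave. It assumes the settings base class behaves like a pydantic v1-style BaseSettings (so Field(env=...) maps the attribute to the named environment variable); the _DemoSettings class and the example URL are illustrative only, not part of the service.

# Hypothetical standalone check (not part of the service); assumes a
# pydantic v1-style BaseSettings so that Field(env=...) reads the
# EXTERNAL_DATABASE_URL environment variable when it is set.
import os
from pydantic import BaseSettings, Field  # pydantic v1-style settings


class _DemoSettings(BaseSettings):
    DATABASE_URL: str = Field(
        default="postgresql+asyncpg://external_user:external_pass123@external-db:5432/external_db",
        env="EXTERNAL_DATABASE_URL",
    )


if __name__ == "__main__":
    os.environ["EXTERNAL_DATABASE_URL"] = "postgresql+asyncpg://user:pass@localhost:5432/external_db"
    print(_DemoSettings().DATABASE_URL)  # value taken from EXTERNAL_DATABASE_URL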
services/external/app/core/database.py (vendored, new file, 81 lines)
@@ -0,0 +1,81 @@
|
||||
# services/external/app/core/database.py
|
||||
"""
|
||||
External Service Database Configuration using shared database manager
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Create database manager instance
|
||||
database_manager = DatabaseManager(
|
||||
database_url=settings.DATABASE_URL,
|
||||
service_name="external-service"
|
||||
)
|
||||
|
||||
async def get_db():
|
||||
"""
|
||||
Database dependency for FastAPI - using shared database manager
|
||||
"""
|
||||
async for session in database_manager.get_db():
|
||||
yield session
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database tables using shared database manager"""
|
||||
try:
|
||||
logger.info("Initializing External Service database...")
|
||||
|
||||
# Import all models to ensure they're registered
|
||||
from app.models import weather, traffic # noqa: F401
|
||||
|
||||
# Create all tables using database manager
|
||||
await database_manager.create_tables(Base.metadata)
|
||||
|
||||
logger.info("External Service database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize database", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""Close database connections using shared database manager"""
|
||||
try:
|
||||
await database_manager.close_connections()
|
||||
logger.info("Database connections closed")
|
||||
except Exception as e:
|
||||
logger.error("Error closing database connections", error=str(e))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_transaction():
|
||||
"""
|
||||
Context manager for database transactions using shared database manager
|
||||
"""
|
||||
async with database_manager.get_session() as session:
|
||||
try:
|
||||
async with session.begin():
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error("Transaction error", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_session():
|
||||
"""
|
||||
Context manager for background tasks using shared database manager
|
||||
"""
|
||||
async with database_manager.get_background_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def health_check():
|
||||
"""Database health check using shared database manager"""
|
||||
return await database_manager.health_check()
|
||||
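A short usage sketch for the helpers above. It is illustrative only: it assumes this module is importable as app.core.database and that the external database is reachable; the SELECT 1 statements are placeholders.

# Illustrative usage of the session helpers (not part of the file).
import asyncio
from sqlalchemy import text

from app.core.database import get_db_transaction, get_background_session, health_check


async def example_job():
    # Transactional unit of work: commits when the block exits cleanly,
    # rolls back (and re-raises) if anything inside raises.
    async with get_db_transaction() as session:
        await session.execute(text("SELECT 1"))

    # Background-task session outside the request/response cycle.
    async with get_background_session() as session:
        await session.execute(text("SELECT 1"))

    return await health_check()


if __name__ == "__main__":
    print(asyncio.run(example_job()))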
@@ -16,13 +16,6 @@ from ..clients.madrid_client import MadridTrafficAPIClient
|
||||
from ..processors.madrid_processor import MadridTrafficDataProcessor
|
||||
from ..processors.madrid_business_logic import MadridTrafficAnalyzer
|
||||
from ..models.madrid_models import TrafficRecord, CongestionLevel
|
||||
from app.core.performance import (
|
||||
rate_limit,
|
||||
async_cache,
|
||||
monitor_performance,
|
||||
global_performance_monitor
|
||||
)
|
||||
|
||||
|
||||
class MadridTrafficClient(BaseTrafficClient, BaseAPIClient):
|
||||
"""
|
||||
@@ -57,9 +50,6 @@ class MadridTrafficClient(BaseTrafficClient, BaseAPIClient):
|
||||
return (self.MADRID_BOUNDS['lat_min'] <= latitude <= self.MADRID_BOUNDS['lat_max'] and
|
||||
self.MADRID_BOUNDS['lon_min'] <= longitude <= self.MADRID_BOUNDS['lon_max'])
|
||||
|
||||
@rate_limit(calls=30, period=60)
|
||||
@async_cache(ttl=300)
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
|
||||
"""Get current traffic data with enhanced pedestrian inference"""
|
||||
try:
|
||||
@@ -98,8 +88,6 @@ class MadridTrafficClient(BaseTrafficClient, BaseAPIClient):
|
||||
self.logger.error("Error getting current traffic", error=str(e))
|
||||
return None
|
||||
|
||||
@rate_limit(calls=10, period=60)
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_historical_traffic(self, latitude: float, longitude: float,
|
||||
start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
|
||||
"""Get historical traffic data with pedestrian enhancement"""
|
||||
@@ -271,11 +259,24 @@ class MadridTrafficClient(BaseTrafficClient, BaseAPIClient):
|
||||
zip_content, zip_url, latitude, longitude, nearest_points
|
||||
)
|
||||
|
||||
# Filter by date range
|
||||
filtered_records = [
|
||||
record for record in month_records
|
||||
if start_date <= record.get('date', datetime.min.replace(tzinfo=timezone.utc)) <= end_date
|
||||
]
|
||||
# Filter by date range - ensure timezone consistency
|
||||
# Make sure start_date and end_date have timezone info for comparison
|
||||
start_tz = start_date if start_date.tzinfo else start_date.replace(tzinfo=timezone.utc)
|
||||
end_tz = end_date if end_date.tzinfo else end_date.replace(tzinfo=timezone.utc)
|
||||
|
||||
filtered_records = []
|
||||
for record in month_records:
|
||||
record_date = record.get('date')
|
||||
if not record_date:
|
||||
continue
|
||||
|
||||
# Ensure record date has timezone info
|
||||
if not record_date.tzinfo:
|
||||
record_date = record_date.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Now compare with consistent timezone info
|
||||
if start_tz <= record_date <= end_tz:
|
||||
filtered_records.append(record)
|
||||
|
||||
historical_records.extend(filtered_records)
|
||||
|
||||
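The normalization rule applied above (treat naive datetimes as UTC so naive and aware values are never compared directly) can be isolated into a small helper. This is only a sketch of that comparison logic; the _as_utc and in_range names are not part of the client.

# Sketch of the naive-vs-aware comparison rule used in the filtering above (illustrative only).
from datetime import datetime, timezone


def _as_utc(dt: datetime) -> datetime:
    """Treat naive datetimes as UTC so comparisons never mix naive and aware values."""
    return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)


def in_range(record_date: datetime, start_date: datetime, end_date: datetime) -> bool:
    return _as_utc(start_date) <= _as_utc(record_date) <= _as_utc(end_date)


# Example: a naive record date compared against aware bounds does not raise TypeError.
assert in_range(
    datetime(2024, 5, 1, 12, 0),
    datetime(2024, 5, 1, tzinfo=timezone.utc),
    datetime(2024, 5, 2, tzinfo=timezone.utc),
)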
@@ -54,19 +54,6 @@ class BaseAPIClient:
|
||||
logger.error("Unexpected error", error=str(e), url=url)
|
||||
return None
|
||||
|
||||
async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
|
||||
"""
|
||||
Public GET method for direct HTTP requests
|
||||
Returns the raw httpx Response object for maximum flexibility
|
||||
"""
|
||||
request_headers = headers or {}
|
||||
request_timeout = httpx.Timeout(timeout if timeout else 30.0)
|
||||
|
||||
async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True) as client:
|
||||
response = await client.get(url, headers=request_headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
async def _fetch_url_directly(self, url: str, headers: Optional[Dict] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch data directly from a full URL (for AEMET datos URLs)"""
|
||||
try:
|
||||
@@ -138,7 +125,7 @@ class BaseAPIClient:
|
||||
logger.error("Unexpected error", error=str(e), url=url)
|
||||
return None
|
||||
|
||||
async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
|
||||
async def get_direct(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
|
||||
"""
|
||||
Public GET method for direct HTTP requests
|
||||
Returns the raw httpx Response object for maximum flexibility
|
||||
@@ -17,7 +17,7 @@ from ..base_client import BaseAPIClient
|
||||
class MadridTrafficAPIClient(BaseAPIClient):
|
||||
"""Pure HTTP client for Madrid traffic APIs"""
|
||||
|
||||
TRAFFIC_ENDPOINT = "https://datos.madrid.es/egob/catalogo/202468-10-intensidad-trafico.xml"
|
||||
TRAFFIC_ENDPOINT = "https://informo.madrid.es/informo/tmadrid/pm.xml"
|
||||
MEASUREMENT_POINTS_URL = "https://datos.madrid.es/egob/catalogo/202468-263-intensidad-trafico.csv"
|
||||
|
||||
def __init__(self):
|
||||
@@ -46,12 +46,16 @@ class MadridTrafficAPIClient(BaseAPIClient):
|
||||
base_url = "https://datos.madrid.es/egob/catalogo/208627"
|
||||
|
||||
# URL numbering pattern (this may need adjustment based on actual URLs)
|
||||
# Note: Historical data is only available for past periods, not current/future
|
||||
if year == 2023:
|
||||
url_number = 116 + (month - 1) # 116-127 for 2023
|
||||
elif year == 2024:
|
||||
url_number = 128 + (month - 1) # 128-139 for 2024
|
||||
elif year == 2025:
|
||||
# For 2025, use the continuing numbering from 2024
|
||||
url_number = 140 + (month - 1) # Starting from 140 for January 2025
|
||||
else:
|
||||
url_number = 116 # Fallback
|
||||
url_number = 116 # Fallback to 2023 data
|
||||
|
||||
return f"{base_url}-{url_number}-transporte-ptomedida-historico.zip"
|
||||
|
||||
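The numbering scheme above maps (year, month) to a sequential dataset id on datos.madrid.es. A standalone version for quick verification follows; it only mirrors the assumption encoded in the client, which the code itself notes may need adjustment against the real catalogue.

# Standalone check of the (year, month) -> url_number pattern described above (illustrative).
def historical_zip_url(year: int, month: int) -> str:
    base_url = "https://datos.madrid.es/egob/catalogo/208627"
    if year == 2023:
        url_number = 116 + (month - 1)   # 116-127 for 2023
    elif year == 2024:
        url_number = 128 + (month - 1)   # 128-139 for 2024
    elif year == 2025:
        url_number = 140 + (month - 1)   # continuing from 140 for January 2025
    else:
        url_number = 116                 # fallback to the first 2023 dataset
    return f"{base_url}-{url_number}-transporte-ptomedida-historico.zip"


# January 2024 -> ...208627-128-transporte-ptomedida-historico.zip
print(historical_zip_url(2024, 1))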
@@ -69,7 +73,7 @@ class MadridTrafficAPIClient(BaseAPIClient):
|
||||
'Referer': 'https://datos.madrid.es/'
|
||||
}
|
||||
|
||||
response = await self.get(endpoint, headers=headers, timeout=30)
|
||||
response = await self.get_direct(endpoint, headers=headers, timeout=30)
|
||||
|
||||
if not response or response.status_code != 200:
|
||||
self.logger.warning("Failed to fetch XML data",
|
||||
services/external/app/main.py (vendored, new file, 186 lines)
@@ -0,0 +1,186 @@
|
||||
# services/external/app/main.py
|
||||
"""
|
||||
External Service Main Application
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db, close_db
|
||||
from shared.monitoring import setup_logging, HealthChecker
|
||||
from shared.monitoring.metrics import setup_metrics_early
|
||||
|
||||
# Setup logging first
|
||||
setup_logging("external-service", settings.LOG_LEVEL)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global variables for lifespan access
|
||||
metrics_collector = None
|
||||
health_checker = None
|
||||
|
||||
# Create FastAPI app FIRST
|
||||
app = FastAPI(
|
||||
title="Bakery External Data Service",
|
||||
description="External data collection service for weather, traffic, and events data",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Setup metrics BEFORE any middleware and BEFORE lifespan
|
||||
metrics_collector = setup_metrics_early(app, "external-service")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan events"""
|
||||
global health_checker
|
||||
|
||||
# Startup
|
||||
logger.info("Starting External Service...")
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
await init_db()
|
||||
logger.info("Database initialized")
|
||||
|
||||
# Register custom metrics
|
||||
metrics_collector.register_counter("weather_api_calls_total", "Total weather API calls")
|
||||
metrics_collector.register_counter("weather_api_success_total", "Successful weather API calls")
|
||||
metrics_collector.register_counter("weather_api_failures_total", "Failed weather API calls")
|
||||
|
||||
metrics_collector.register_counter("traffic_api_calls_total", "Total traffic API calls")
|
||||
metrics_collector.register_counter("traffic_api_success_total", "Successful traffic API calls")
|
||||
metrics_collector.register_counter("traffic_api_failures_total", "Failed traffic API calls")
|
||||
|
||||
metrics_collector.register_counter("data_collection_jobs_total", "Data collection jobs")
|
||||
metrics_collector.register_counter("data_records_stored_total", "Data records stored")
|
||||
metrics_collector.register_counter("data_quality_issues_total", "Data quality issues detected")
|
||||
|
||||
metrics_collector.register_histogram("weather_api_duration_seconds", "Weather API call duration")
|
||||
metrics_collector.register_histogram("traffic_api_duration_seconds", "Traffic API call duration")
|
||||
metrics_collector.register_histogram("data_collection_duration_seconds", "Data collection job duration")
|
||||
metrics_collector.register_histogram("data_processing_duration_seconds", "Data processing duration")
|
||||
|
||||
# Setup health checker
|
||||
health_checker = HealthChecker("external-service")
|
||||
|
||||
# Add database health check
|
||||
async def check_database():
|
||||
try:
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy import text
|
||||
async for db in get_db():
|
||||
await db.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
return f"Database error: {e}"
|
||||
|
||||
# Add external API health checks
|
||||
async def check_weather_api():
|
||||
try:
|
||||
# Simple connectivity check
|
||||
if settings.AEMET_API_KEY:
|
||||
return True
|
||||
else:
|
||||
return "AEMET API key not configured"
|
||||
except Exception as e:
|
||||
return f"Weather API error: {e}"
|
||||
|
||||
async def check_traffic_api():
|
||||
try:
|
||||
# Simple connectivity check
|
||||
if settings.MADRID_OPENDATA_API_KEY:
|
||||
return True
|
||||
else:
|
||||
return "Madrid Open Data API key not configured"
|
||||
except Exception as e:
|
||||
return f"Traffic API error: {e}"
|
||||
|
||||
health_checker.add_check("database", check_database, timeout=5.0, critical=True)
|
||||
health_checker.add_check("weather_api", check_weather_api, timeout=10.0, critical=False)
|
||||
health_checker.add_check("traffic_api", check_traffic_api, timeout=10.0, critical=False)
|
||||
|
||||
# Store health checker in app state
|
||||
app.state.health_checker = health_checker
|
||||
|
||||
logger.info("External Service started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start External Service: {e}")
|
||||
raise
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down External Service...")
|
||||
await close_db()
|
||||
|
||||
# Set lifespan AFTER metrics setup
|
||||
app.router.lifespan_context = lifespan
|
||||
|
||||
# CORS middleware (added after metrics setup)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
from app.api.weather import router as weather_router
|
||||
from app.api.traffic import router as traffic_router
|
||||
app.include_router(weather_router, prefix="/api/v1", tags=["weather"])
|
||||
app.include_router(traffic_router, prefix="/api/v1", tags=["traffic"])
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Comprehensive health check endpoint"""
|
||||
if health_checker:
|
||||
return await health_checker.check_health()
|
||||
else:
|
||||
return {
|
||||
"service": "external-service",
|
||||
"status": "healthy",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"service": "External Data Service",
|
||||
"version": "1.0.0",
|
||||
"status": "running",
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"docs": "/docs",
|
||||
"weather": "/api/v1/weather",
|
||||
"traffic": "/api/v1/traffic",
|
||||
"jobs": "/api/v1/jobs"
|
||||
},
|
||||
"data_sources": {
|
||||
"weather": "AEMET (Spanish Weather Service)",
|
||||
"traffic": "Madrid Open Data Portal",
|
||||
"coverage": "Madrid, Spain"
|
||||
}
|
||||
}
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
"""Global exception handler with metrics"""
|
||||
logger.error(f"Unhandled exception: {exc}", exc_info=True)
|
||||
|
||||
# Record error metric if available
|
||||
if metrics_collector:
|
||||
metrics_collector.increment_counter("errors_total", labels={"type": "unhandled"})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"detail": "Internal server error"}
|
||||
)
|
||||
services/external/app/models/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/models/__init__.py
@@ -3,8 +3,8 @@
|
||||
# ================================================================
|
||||
"""Weather data models"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -15,15 +15,36 @@ class WeatherData(Base):
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
location_id = Column(String(100), nullable=False, index=True)
|
||||
city = Column(String(50), nullable=False)
|
||||
station_name = Column(String(200), nullable=True)
|
||||
latitude = Column(Float, nullable=True)
|
||||
longitude = Column(Float, nullable=True)
|
||||
date = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
forecast_date = Column(DateTime(timezone=True), nullable=True)
|
||||
temperature = Column(Float, nullable=True) # Celsius
|
||||
temperature_min = Column(Float, nullable=True)
|
||||
temperature_max = Column(Float, nullable=True)
|
||||
feels_like = Column(Float, nullable=True)
|
||||
precipitation = Column(Float, nullable=True) # mm
|
||||
precipitation_probability = Column(Float, nullable=True)
|
||||
humidity = Column(Float, nullable=True) # percentage
|
||||
wind_speed = Column(Float, nullable=True) # km/h
|
||||
wind_direction = Column(Float, nullable=True)
|
||||
wind_gust = Column(Float, nullable=True)
|
||||
pressure = Column(Float, nullable=True) # hPa
|
||||
visibility = Column(Float, nullable=True)
|
||||
uv_index = Column(Float, nullable=True)
|
||||
cloud_cover = Column(Float, nullable=True)
|
||||
condition = Column(String(100), nullable=True)
|
||||
description = Column(String(200), nullable=True)
|
||||
weather_code = Column(String(20), nullable=True)
|
||||
source = Column(String(50), nullable=False, default="aemet")
|
||||
raw_data = Column(Text, nullable=True)
|
||||
data_type = Column(String(20), nullable=False)
|
||||
is_forecast = Column(Boolean, nullable=True)
|
||||
data_quality_score = Column(Float, nullable=True)
|
||||
raw_data = Column(JSON, nullable=True)
|
||||
processed_data = Column(JSON, nullable=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
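A minimal construction sketch for the expanded model above. The values are illustrative and it assumes the model is importable as app.models.weather.WeatherData; persisting the row still requires a session.

# Illustrative only: building a WeatherData row with the new JSON columns.
from datetime import datetime, timezone

from app.models.weather import WeatherData

record = WeatherData(
    location_id="40.4168,-3.7038",
    city="madrid",
    date=datetime.now(timezone.utc),
    temperature=21.5,
    humidity=48.0,
    source="aemet",
    data_type="current",
    is_forecast=False,
    raw_data={"ta": 21.5, "hr": 48},      # stored as JSON, not Text
    processed_data={"quality": "ok"},
)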
services/external/app/repositories/traffic_repository.py (vendored, new file, 191 lines)
@@ -0,0 +1,191 @@
|
||||
# ================================================================
|
||||
# services/data/app/repositories/traffic_repository.py
|
||||
# ================================================================
|
||||
"""
|
||||
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
|
||||
Follows existing repository architecture while adding city-specific functionality
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
|
||||
from sqlalchemy.orm import selectinload
|
||||
from datetime import datetime, timezone, timedelta
|
||||
import structlog
|
||||
|
||||
from app.models.traffic import TrafficData
|
||||
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrafficRepository:
|
||||
"""
|
||||
Enhanced repository for traffic data operations across multiple cities
|
||||
Provides city-aware queries and advanced traffic analytics
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
self.model = TrafficData
|
||||
|
||||
# ================================================================
|
||||
# CORE TRAFFIC DATA OPERATIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_by_location_and_date_range(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[TrafficData]:
|
||||
"""Get traffic data by location and date range"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
# Build base query
|
||||
query = select(self.model).where(self.model.location_id == location_id)
|
||||
|
||||
# Add tenant filter if specified
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
# Add date range filters
|
||||
if start_date:
|
||||
query = query.where(self.model.date >= start_date)
|
||||
|
||||
if end_date:
|
||||
query = query.where(self.model.date <= end_date)
|
||||
|
||||
# Order by date
|
||||
query = query.order_by(self.model.date)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get traffic data by location and date range",
|
||||
latitude=latitude, longitude=longitude,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
|
||||
|
||||
async def store_traffic_data_batch(
|
||||
self,
|
||||
traffic_data_list: List[Dict[str, Any]],
|
||||
location_id: str,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Store a batch of traffic data records with enhanced validation and duplicate handling."""
|
||||
stored_count = 0
|
||||
try:
|
||||
if not traffic_data_list:
|
||||
return 0
|
||||
|
||||
# Check for existing records to avoid duplicates
|
||||
dates = [data.get('date') for data in traffic_data_list if data.get('date')]
|
||||
existing_dates = set()
|
||||
if dates:
|
||||
existing_stmt = select(TrafficData.date).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date.in_(dates)
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(existing_stmt)
|
||||
existing_dates = {row[0] for row in result.fetchall()}
|
||||
logger.debug(f"Found {len(existing_dates)} existing records for location {location_id}")
|
||||
|
||||
batch_records = []
|
||||
for data in traffic_data_list:
|
||||
record_date = data.get('date')
|
||||
if not record_date or record_date in existing_dates:
|
||||
continue # Skip duplicates
|
||||
|
||||
# Validate data before preparing for insertion
|
||||
if self._validate_traffic_data(data):
|
||||
batch_records.append({
|
||||
'location_id': location_id,
|
||||
'city': data.get('city', 'madrid'), # Default to madrid for historical data
|
||||
'tenant_id': tenant_id, # Include tenant_id in batch insert
|
||||
'date': record_date,
|
||||
'traffic_volume': data.get('traffic_volume'),
|
||||
'pedestrian_count': data.get('pedestrian_count'),
|
||||
'congestion_level': data.get('congestion_level'),
|
||||
'average_speed': data.get('average_speed'),
|
||||
'source': data.get('source', 'unknown'),
|
||||
'raw_data': str(data)
|
||||
})
|
||||
|
||||
if batch_records:
|
||||
# Use bulk insert for performance
|
||||
await self.session.execute(
|
||||
TrafficData.__table__.insert(),
|
||||
batch_records
|
||||
)
|
||||
await self.session.commit()
|
||||
stored_count = len(batch_records)
|
||||
logger.info(f"Successfully stored {stored_count} traffic records for location {location_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to store traffic data batch",
|
||||
error=str(e), location_id=location_id)
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Batch store failed: {str(e)}")
|
||||
|
||||
return stored_count
|
||||
|
||||
def _validate_traffic_data(self, data: Dict[str, Any]) -> bool:
|
||||
"""Validate traffic data before storage"""
|
||||
required_fields = ['date']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if not data.get(field):
|
||||
return False
|
||||
|
||||
# Validate data types and ranges
|
||||
traffic_volume = data.get('traffic_volume')
|
||||
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 10000):
|
||||
return False
|
||||
|
||||
pedestrian_count = data.get('pedestrian_count')
|
||||
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
|
||||
return False
|
||||
|
||||
average_speed = data.get('average_speed')
|
||||
if average_speed is not None and (average_speed < 0 or average_speed > 200):
|
||||
return False
|
||||
|
||||
congestion_level = data.get('congestion_level')
|
||||
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def get_historical_traffic_for_training(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime) -> List[TrafficData]:
|
||||
"""Retrieve stored traffic data for training ML models."""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
stmt = select(TrafficData).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date >= start_date,
|
||||
TrafficData.date <= end_date
|
||||
)
|
||||
).order_by(TrafficData.date)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to retrieve traffic data for training",
|
||||
error=str(e), location_id=location_id)
|
||||
raise DatabaseError(f"Training data retrieval failed: {str(e)}")
|
||||
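A usage sketch for the repository above. It is illustrative only: it assumes an AsyncSession obtained from the shared database manager, and the sample values are made up.

# Illustrative usage of TrafficRepository.store_traffic_data_batch (not part of the file).
from datetime import datetime, timezone

from app.core.database import database_manager
from app.repositories.traffic_repository import TrafficRepository


async def store_sample_batch() -> int:
    sample = [{
        "date": datetime.now(timezone.utc),
        "traffic_volume": 75,
        "pedestrian_count": 40,
        "congestion_level": "medium",
        "average_speed": 32.0,
        "source": "madrid_opendata",
    }]
    async with database_manager.get_session() as session:
        repo = TrafficRepository(session)
        # Records failing _validate_traffic_data, or matching an existing
        # (location_id, date) pair, are silently skipped.
        return await repo.store_traffic_data_batch(sample, location_id="40.4168,-3.7038")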
services/external/app/repositories/weather_repository.py (vendored, new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
# services/external/app/repositories/weather_repository.py
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from app.models.weather import WeatherData
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class WeatherRepository:
|
||||
"""
|
||||
Repository for weather data operations, adapted for WeatherService.
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_historical_weather(self,
|
||||
location_id: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime) -> List[WeatherData]:
|
||||
"""
|
||||
Retrieves historical weather data for a specific location and date range.
|
||||
This method directly supports the data retrieval logic in WeatherService.
|
||||
"""
|
||||
try:
|
||||
stmt = select(WeatherData).where(
|
||||
and_(
|
||||
WeatherData.location_id == location_id,
|
||||
WeatherData.date >= start_date,
|
||||
WeatherData.date <= end_date
|
||||
)
|
||||
).order_by(WeatherData.date)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
records = result.scalars().all()
|
||||
logger.debug(f"Retrieved {len(records)} historical records for location {location_id}")
|
||||
return list(records)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to get historical weather from repository",
|
||||
error=str(e),
|
||||
location_id=location_id
|
||||
)
|
||||
raise
|
||||
|
||||
def _serialize_json_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize JSON fields (raw_data, processed_data) to ensure proper JSON storage
|
||||
"""
|
||||
serialized = data.copy()
|
||||
|
||||
# Serialize raw_data if present
|
||||
if 'raw_data' in serialized and serialized['raw_data'] is not None:
|
||||
if not isinstance(serialized['raw_data'], str):
|
||||
try:
|
||||
# Convert datetime objects to strings for JSON serialization
|
||||
raw_data = serialized['raw_data']
|
||||
if isinstance(raw_data, dict):
|
||||
# Handle datetime objects in the dict
|
||||
json_safe_data = {}
|
||||
for k, v in raw_data.items():
|
||||
if hasattr(v, 'isoformat'): # datetime-like object
|
||||
json_safe_data[k] = v.isoformat()
|
||||
else:
|
||||
json_safe_data[k] = v
|
||||
serialized['raw_data'] = json_safe_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not serialize raw_data, storing as string: {e}")
|
||||
serialized['raw_data'] = str(raw_data)
|
||||
|
||||
# Serialize processed_data if present
|
||||
if 'processed_data' in serialized and serialized['processed_data'] is not None:
|
||||
if not isinstance(serialized['processed_data'], str):
|
||||
try:
|
||||
processed_data = serialized['processed_data']
|
||||
if isinstance(processed_data, dict):
|
||||
json_safe_data = {}
|
||||
for k, v in processed_data.items():
|
||||
if hasattr(v, 'isoformat'): # datetime-like object
|
||||
json_safe_data[k] = v.isoformat()
|
||||
else:
|
||||
json_safe_data[k] = v
|
||||
serialized['processed_data'] = json_safe_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not serialize processed_data, storing as string: {e}")
|
||||
serialized['processed_data'] = str(processed_data)
|
||||
|
||||
return serialized
|
||||
|
||||
async def bulk_create_weather_data(self, weather_records: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Bulk inserts new weather records into the database.
|
||||
Used by WeatherService after fetching new historical data from an external API.
|
||||
"""
|
||||
try:
|
||||
if not weather_records:
|
||||
return
|
||||
|
||||
# Serialize JSON fields before creating model instances
|
||||
serialized_records = [self._serialize_json_fields(data) for data in weather_records]
|
||||
records = [WeatherData(**data) for data in serialized_records]
|
||||
self.session.add_all(records)
|
||||
await self.session.commit()
|
||||
logger.info(f"Successfully bulk inserted {len(records)} weather records")
|
||||
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(
|
||||
"Failed to bulk create weather records",
|
||||
error=str(e),
|
||||
count=len(weather_records)
|
||||
)
|
||||
raise
|
||||
|
||||
async def create_weather_data(self, data: Dict[str, Any]) -> WeatherData:
|
||||
"""
|
||||
Creates a single new weather data record.
|
||||
"""
|
||||
try:
|
||||
# Serialize JSON fields before creating model instance
|
||||
serialized_data = self._serialize_json_fields(data)
|
||||
new_record = WeatherData(**serialized_data)
|
||||
self.session.add(new_record)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(new_record)
|
||||
logger.info(f"Created new weather record with ID {new_record.id}")
|
||||
return new_record
|
||||
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error("Failed to create single weather record", error=str(e))
|
||||
raise
|
||||
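The JSON-serialization rule above (datetime-like values become ISO strings, anything that cannot be handled falls back to str) can be illustrated in isolation. This sketch only mirrors that behaviour; the to_json_safe helper is not part of the repository.

# Sketch of the datetime handling performed by _serialize_json_fields (illustrative).
from datetime import datetime, timezone


def to_json_safe(data: dict) -> dict:
    """Convert datetime-like values to ISO strings so the dict fits a JSON column."""
    return {k: (v.isoformat() if hasattr(v, "isoformat") else v) for k, v in data.items()}


raw = {"date": datetime(2024, 5, 1, tzinfo=timezone.utc), "temperature": 21.5}
print(to_json_safe(raw))
# {'date': '2024-05-01T00:00:00+00:00', 'temperature': 21.5}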
services/external/app/schemas/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/schemas/__init__.py
@@ -1,7 +1,7 @@
|
||||
# ================================================================
|
||||
# services/data/app/schemas/traffic.py
|
||||
# ================================================================
|
||||
"""Traffic data schemas"""
|
||||
# services/external/app/schemas/traffic.py
|
||||
"""
|
||||
Traffic Service Pydantic Schemas
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
@@ -69,4 +69,32 @@ class TrafficAnalytics(BaseModel):
|
||||
peak_traffic_hour: Optional[int] = None
|
||||
peak_pedestrian_hour: Optional[int] = None
|
||||
congestion_distribution: dict = Field(default_factory=dict)
|
||||
avg_speed: Optional[float] = None
|
||||
class TrafficDataResponse(BaseModel):
|
||||
date: datetime
|
||||
traffic_volume: Optional[int]
|
||||
pedestrian_count: Optional[int]
|
||||
congestion_level: Optional[str]
|
||||
average_speed: Optional[float]
|
||||
source: str
|
||||
|
||||
class LocationRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
address: Optional[str] = None
|
||||
|
||||
class DateRangeRequest(BaseModel):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class HistoricalTrafficRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class TrafficForecastRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
hours: int = 24
|
||||
@@ -1,6 +1,4 @@
|
||||
# ================================================================
|
||||
# services/data/app/schemas/weather.py
|
||||
# ================================================================
|
||||
# services/external/app/schemas/weather.py
|
||||
"""Weather data schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@@ -120,4 +118,44 @@ class WeatherAnalytics(BaseModel):
|
||||
avg_pressure: Optional[float] = None
|
||||
weather_conditions: dict = Field(default_factory=dict)
|
||||
rainy_days: int = 0
|
||||
sunny_days: int = 0
|
||||
class WeatherDataResponse(BaseModel):
|
||||
date: datetime
|
||||
temperature: Optional[float]
|
||||
precipitation: Optional[float]
|
||||
humidity: Optional[float]
|
||||
wind_speed: Optional[float]
|
||||
pressure: Optional[float]
|
||||
description: Optional[str]
|
||||
source: str
|
||||
|
||||
class WeatherForecastResponse(BaseModel):
|
||||
forecast_date: datetime
|
||||
generated_at: datetime
|
||||
temperature: Optional[float]
|
||||
precipitation: Optional[float]
|
||||
humidity: Optional[float]
|
||||
wind_speed: Optional[float]
|
||||
description: Optional[str]
|
||||
source: str
|
||||
|
||||
class LocationRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
address: Optional[str] = None
|
||||
|
||||
class DateRangeRequest(BaseModel):
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class HistoricalWeatherRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
start_date: datetime
|
||||
end_date: datetime
|
||||
|
||||
class WeatherForecastRequest(BaseModel):
|
||||
latitude: float
|
||||
longitude: float
|
||||
days: int
|
||||
services/external/app/services/__init__.py (vendored, new file, 1 line)
@@ -0,0 +1 @@
# services/external/app/services/__init__.py
services/external/app/services/messaging.py (vendored, new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
# services/external/app/services/messaging.py
|
||||
"""
|
||||
External Service Messaging - Event Publishing using shared messaging infrastructure
|
||||
"""
|
||||
|
||||
from shared.messaging.rabbitmq import RabbitMQClient
|
||||
from app.core.config import settings
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Single global instance
|
||||
data_publisher = RabbitMQClient(settings.RABBITMQ_URL, "data-service")
|
||||
|
||||
async def setup_messaging():
|
||||
"""Initialize messaging for data service"""
|
||||
try:
|
||||
success = await data_publisher.connect()
|
||||
if success:
|
||||
logger.info("Data service messaging initialized")
|
||||
else:
|
||||
logger.warning("Data service messaging failed to initialize")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.warning("Failed to setup messaging", error=str(e))
|
||||
return False
|
||||
|
||||
async def cleanup_messaging():
|
||||
"""Cleanup messaging for data service"""
|
||||
try:
|
||||
await data_publisher.disconnect()
|
||||
logger.info("Data service messaging cleaned up")
|
||||
except Exception as e:
|
||||
logger.warning("Error during messaging cleanup", error=str(e))
|
||||
|
||||
async def publish_weather_updated(data: dict) -> bool:
|
||||
"""Publish weather updated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("weather.updated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish weather updated event", error=str(e))
|
||||
return False
|
||||
|
||||
async def publish_traffic_updated(data: dict) -> bool:
|
||||
"""Publish traffic updated event"""
|
||||
try:
|
||||
return await data_publisher.publish_data_event("traffic.updated", data)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish traffic updated event", error=str(e))
|
||||
return False
|
||||
|
||||
|
||||
|
||||
# Health check for messaging
|
||||
async def check_messaging_health() -> dict:
|
||||
"""Check messaging system health"""
|
||||
try:
|
||||
if data_publisher.connected:
|
||||
return {"status": "healthy", "service": "rabbitmq", "connected": True}
|
||||
else:
|
||||
return {"status": "unhealthy", "service": "rabbitmq", "connected": False, "error": "Not connected"}
|
||||
except Exception as e:
|
||||
return {"status": "unhealthy", "service": "rabbitmq", "connected": False, "error": str(e)}
|
||||
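A short lifecycle sketch for the messaging helpers above. It is illustrative only: it assumes RabbitMQ is reachable at settings.RABBITMQ_URL, and the event payload values are made up.

# Illustrative wiring of the messaging helpers (not part of the file).
import asyncio
from datetime import datetime, timezone

from app.services.messaging import (
    setup_messaging,
    cleanup_messaging,
    publish_weather_updated,
    check_messaging_health,
)


async def demo():
    if await setup_messaging():
        await publish_weather_updated({
            "type": "current_requested",          # free-form payload, example values
            "latitude": 40.4168,
            "longitude": -3.7038,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
        print(await check_messaging_health())
    await cleanup_messaging()


if __name__ == "__main__":
    asyncio.run(demo())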
services/external/app/services/traffic_service.py (vendored, new file, 298 lines)
@@ -0,0 +1,298 @@
|
||||
# ================================================================
|
||||
# services/data/app/services/traffic_service.py
|
||||
# ================================================================
|
||||
"""
|
||||
Abstracted Traffic Service - Universal interface for traffic data across multiple cities
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
|
||||
from app.external.apis.traffic import UniversalTrafficClient
|
||||
from app.models.traffic import TrafficData
|
||||
from app.repositories.traffic_repository import TrafficRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
from app.core.database import database_manager
|
||||
|
||||
class TrafficService:
|
||||
"""
|
||||
Abstracted traffic service providing unified interface for traffic data
|
||||
Routes requests to appropriate city-specific clients automatically
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.universal_client = UniversalTrafficClient()
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def get_current_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current traffic data for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
tenant_id: Optional tenant identifier for logging/analytics
|
||||
|
||||
Returns:
|
||||
Dict with current traffic data or None if not available
|
||||
"""
|
||||
try:
|
||||
logger.info("Getting current traffic data",
|
||||
lat=latitude, lon=longitude, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
traffic_data = await self.universal_client.get_current_traffic(latitude, longitude)
|
||||
|
||||
if traffic_data:
|
||||
# Add service metadata
|
||||
traffic_data['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude}
|
||||
}
|
||||
|
||||
logger.info("Successfully retrieved current traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
source=traffic_data.get('source', 'unknown'))
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("No current traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting current traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return None
|
||||
|
||||
async def get_historical_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get historical traffic data for any supported location with database storage
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
start_date: Start date for historical data
|
||||
end_date: End date for historical data
|
||||
tenant_id: Optional tenant identifier
|
||||
|
||||
Returns:
|
||||
List of historical traffic data dictionaries
|
||||
"""
|
||||
try:
|
||||
logger.info("Getting historical traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
start=start_date, end=end_date, tenant_id=tenant_id)
|
||||
|
||||
# Validate date range
|
||||
if start_date >= end_date:
|
||||
logger.warning("Invalid date range", start=start_date, end=end_date)
|
||||
return []
|
||||
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
traffic_repo = TrafficRepository(session)
|
||||
# Check database first using the repository
|
||||
db_records = await traffic_repo.get_by_location_and_date_range(
|
||||
latitude, longitude, start_date, end_date, tenant_id
|
||||
)
|
||||
|
||||
if db_records:
|
||||
logger.info("Historical traffic data found in database",
|
||||
count=len(db_records))
|
||||
return [self._convert_db_record_to_dict(record) for record in db_records]
|
||||
|
||||
# Delegate to universal client if not in DB
|
||||
traffic_data = await self.universal_client.get_historical_traffic(
|
||||
latitude, longitude, start_date, end_date
|
||||
)
|
||||
|
||||
if traffic_data:
|
||||
# Add service metadata to each record
|
||||
for record in traffic_data:
|
||||
record['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'date_range': {
|
||||
'start': start_date.isoformat(),
|
||||
'end': end_date.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
traffic_repo = TrafficRepository(session)
|
||||
# Store in database using the repository
|
||||
stored_count = await traffic_repo.store_traffic_data_batch(
|
||||
traffic_data, location_id, tenant_id
|
||||
)
|
||||
logger.info("Traffic data stored for re-training",
|
||||
fetched=len(traffic_data), stored=stored_count,
|
||||
location=location_id)
|
||||
|
||||
logger.info("Successfully retrieved historical traffic data",
|
||||
lat=latitude, lon=longitude, records=len(traffic_data))
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
logger.info("No historical traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting historical traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def _convert_db_record_to_dict(self, record: TrafficData) -> Dict[str, Any]:
|
||||
"""Convert database record to dictionary format"""
|
||||
return {
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume,
|
||||
'pedestrian_count': record.pedestrian_count,
|
||||
'congestion_level': record.congestion_level,
|
||||
'average_speed': record.average_speed,
|
||||
'source': record.source,
|
||||
'location_id': record.location_id,
|
||||
'raw_data': record.raw_data
|
||||
}
|
||||
|
||||
async def get_traffic_events(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 5.0,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get traffic events and incidents for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
radius_km: Search radius in kilometers
|
||||
tenant_id: Optional tenant identifier
|
||||
|
||||
Returns:
|
||||
List of traffic events
|
||||
"""
|
||||
try:
|
||||
logger.info("Getting traffic events",
|
||||
lat=latitude, lon=longitude, radius=radius_km, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
events = await self.universal_client.get_events(latitude, longitude, radius_km)
|
||||
|
||||
# Add metadata to events
|
||||
for event in events:
|
||||
event['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'search_radius_km': radius_km
|
||||
}
|
||||
|
||||
logger.info("Retrieved traffic events",
|
||||
lat=latitude, lon=longitude, events=len(events))
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting traffic events",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about traffic data availability for location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
|
||||
Returns:
|
||||
Dict with location support information
|
||||
"""
|
||||
try:
|
||||
info = self.universal_client.get_location_info(latitude, longitude)
|
||||
|
||||
# Add service layer information
|
||||
info['service_layer'] = {
|
||||
'version': '2.0',
|
||||
'abstraction_level': 'universal',
|
||||
'supported_operations': [
|
||||
'current_traffic',
|
||||
'historical_traffic',
|
||||
'traffic_events',
|
||||
'bulk_requests'
|
||||
]
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting location info",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return {
|
||||
'supported': False,
|
||||
'error': str(e),
|
||||
'service_layer': {'version': '2.0'}
|
||||
}
|
||||
|
||||
async def get_stored_traffic_for_training(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime) -> List[Dict[str, Any]]:
|
||||
"""Retrieve stored traffic data specifically for training purposes"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
traffic_repo = TrafficRepository(session)
|
||||
records = await traffic_repo.get_historical_traffic_for_training(
|
||||
latitude, longitude, start_date, end_date
|
||||
)
|
||||
|
||||
# Convert to training format
|
||||
training_data = []
|
||||
for record in records:
|
||||
training_data.append({
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume,
|
||||
'pedestrian_count': record.pedestrian_count,
|
||||
'congestion_level': record.congestion_level,
|
||||
'average_speed': record.average_speed,
|
||||
'location_id': record.location_id,
|
||||
'source': record.source,
|
||||
'measurement_point_id': record.raw_data # Contains additional metadata
|
||||
})
|
||||
|
||||
logger.info(f"Retrieved {len(training_data)} traffic records for training",
|
||||
location_id=f"{latitude:.4f},{longitude:.4f}", start=start_date, end=end_date)
|
||||
|
||||
return training_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to retrieve traffic data for training",
|
||||
error=str(e), location_id=f"{latitude:.4f},{longitude:.4f}")
|
||||
return []
|
||||
@@ -1,24 +1,25 @@
|
||||
# ================================================================
|
||||
# services/data/app/services/weather_service.py - FIXED VERSION
|
||||
# ================================================================
|
||||
"""Weather data service with improved error handling"""
|
||||
# services/data/app/services/weather_service.py - REVISED VERSION
|
||||
|
||||
"""Weather data service with repository pattern"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
import structlog
|
||||
|
||||
from app.models.weather import WeatherData, WeatherForecast
|
||||
from app.external.aemet import AEMETClient
|
||||
from app.schemas.external import WeatherDataResponse, WeatherForecastResponse
|
||||
from app.schemas.weather import WeatherDataResponse, WeatherForecastResponse
|
||||
from app.repositories.weather_repository import WeatherRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
from app.core.database import database_manager
|
||||
|
||||
class WeatherService:
|
||||
|
||||
def __init__(self):
|
||||
self.aemet_client = AEMETClient()
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def get_current_weather(self, latitude: float, longitude: float) -> Optional[WeatherDataResponse]:
|
||||
"""Get current weather for location"""
|
||||
@@ -82,26 +83,23 @@ class WeatherService:
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
db: AsyncSession) -> List[WeatherDataResponse]:
|
||||
end_date: datetime) -> List[WeatherDataResponse]:
|
||||
"""Get historical weather data"""
|
||||
try:
|
||||
logger.debug("Getting historical weather",
|
||||
lat=latitude, lon=longitude,
|
||||
start=start_date, end=end_date)
|
||||
|
||||
# First check database
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
stmt = select(WeatherData).where(
|
||||
and_(
|
||||
WeatherData.location_id == location_id,
|
||||
WeatherData.date >= start_date,
|
||||
WeatherData.date <= end_date
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
weather_repository = WeatherRepository(session)
|
||||
# Use the repository to get data from the database
|
||||
db_records = await weather_repository.get_historical_weather(
|
||||
location_id,
|
||||
start_date,
|
||||
end_date
|
||||
)
|
||||
).order_by(WeatherData.date)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
db_records = result.scalars().all()
|
||||
|
||||
if db_records:
|
||||
logger.debug("Historical data found in database", count=len(db_records))
|
||||
@@ -123,28 +121,28 @@ class WeatherService:
|
||||
)
|
||||
|
||||
if weather_data:
|
||||
# Store in database for future use
|
||||
try:
|
||||
for data in weather_data:
|
||||
weather_record = WeatherData(
|
||||
location_id=location_id,
|
||||
date=data.get('date', datetime.now()),
|
||||
temperature=data.get('temperature'),
|
||||
precipitation=data.get('precipitation'),
|
||||
humidity=data.get('humidity'),
|
||||
wind_speed=data.get('wind_speed'),
|
||||
pressure=data.get('pressure'),
|
||||
description=data.get('description'),
|
||||
source="aemet",
|
||||
raw_data=str(data)
|
||||
)
|
||||
db.add(weather_record)
|
||||
|
||||
await db.commit()
|
||||
logger.debug("Historical data stored in database", count=len(weather_data))
|
||||
except Exception as db_error:
|
||||
logger.warning("Failed to store historical data in database", error=str(db_error))
|
||||
await db.rollback()
|
||||
# Use the repository to store the new data
|
||||
records_to_store = [{
|
||||
"location_id": location_id,
|
||||
"city": "Madrid", # Default city for AEMET data
|
||||
"date": data.get('date', datetime.now()),
|
||||
"temperature": data.get('temperature'),
|
||||
"precipitation": data.get('precipitation'),
|
||||
"humidity": data.get('humidity'),
|
||||
"wind_speed": data.get('wind_speed'),
|
||||
"pressure": data.get('pressure'),
|
||||
"description": data.get('description'),
|
||||
"source": "aemet",
|
||||
"data_type": "historical",
|
||||
"raw_data": data, # Pass as dict, not string
|
||||
"tenant_id": None
|
||||
} for data in weather_data]
|
||||
|
||||
async with self.database_manager.get_session() as session:
|
||||
weather_repository = WeatherRepository(session)
|
||||
await weather_repository.bulk_create_weather_data(records_to_store)
|
||||
|
||||
logger.debug("Historical data stored in database", count=len(weather_data))
|
||||
|
||||
return [WeatherDataResponse(**item) for item in weather_data]
|
||||
else:
|
||||
19
services/external/pytest.ini
vendored
Normal file
19
services/external/pytest.ini
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
[tool:pytest]
|
||||
testpaths = tests
|
||||
asyncio_mode = auto
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
--disable-warnings
|
||||
--cov=app
|
||||
--cov-report=term-missing
|
||||
--cov-report=html:htmlcov
|
||||
markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
slow: Slow running tests
|
||||
external: Tests requiring external services
|
||||
56
services/external/requirements.txt
vendored
Normal file
56
services/external/requirements.txt
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
# services/external/requirements.txt
|
||||
# FastAPI and web framework
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
|
||||
# Database
|
||||
sqlalchemy==2.0.23
|
||||
psycopg2-binary==2.9.9
|
||||
asyncpg==0.29.0
|
||||
aiosqlite==0.19.0
|
||||
alembic==1.12.1
|
||||
|
||||
# HTTP clients for external APIs
|
||||
httpx==0.25.2
|
||||
aiofiles==23.2.0
|
||||
requests==2.31.0
|
||||
|
||||
# Data processing and time series
|
||||
pandas==2.1.3
|
||||
numpy==1.25.2
|
||||
|
||||
# Validation and serialization
|
||||
pydantic==2.5.0
|
||||
pydantic-settings==2.0.3
|
||||
|
||||
# Authentication and security
|
||||
python-jose[cryptography]==3.3.0
|
||||
|
||||
# Logging and monitoring
|
||||
structlog==23.2.0
|
||||
prometheus-client==0.19.0
|
||||
|
||||
# Message queues
|
||||
aio-pika==9.3.1
|
||||
|
||||
# Background job processing
|
||||
redis==5.0.1
|
||||
|
||||
# Date and time handling
|
||||
pytz==2023.3
|
||||
python-dateutil==2.8.2
|
||||
|
||||
# XML parsing (for some APIs)
|
||||
lxml==4.9.3
|
||||
|
||||
# Geospatial processing
|
||||
pyproj==3.6.1
|
||||
|
||||
# Note: pytest and testing dependencies are in tests/requirements.txt
|
||||
|
||||
# Development
|
||||
python-multipart==0.0.6
|
||||
|
||||
# External API specific
|
||||
beautifulsoup4==4.12.2 # For web scraping if needed
|
||||
xmltodict==0.13.0 # For XML API responses
|
||||
1
services/external/shared/shared
vendored
Symbolic link
1
services/external/shared/shared
vendored
Symbolic link
@@ -0,0 +1 @@
|
||||
/Users/urtzialfaro/Documents/bakery-ia/shared
|
||||
314
services/external/tests/conftest.py
vendored
Normal file
314
services/external/tests/conftest.py
vendored
Normal file
@@ -0,0 +1,314 @@
|
||||
# services/external/tests/conftest.py
|
||||
"""
|
||||
Pytest configuration and fixtures for External Service tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.main import app
|
||||
from app.core.config import settings
|
||||
from app.core.database import Base, get_db
|
||||
from app.models.weather import WeatherData, WeatherStation
|
||||
from app.models.traffic import TrafficData, TrafficMeasurementPoint
|
||||
|
||||
|
||||
# Test database configuration
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for the test session"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_engine():
|
||||
"""Create test database engine"""
|
||||
engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False}
|
||||
)
|
||||
|
||||
# Create tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create test database session"""
|
||||
async_session = async_sessionmaker(
|
||||
test_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client():
|
||||
"""Create test client"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def override_get_db(test_db_session):
|
||||
"""Override get_db dependency for testing"""
|
||||
async def _override_get_db():
|
||||
yield test_db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
yield
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# Test data fixtures
|
||||
@pytest.fixture
|
||||
def sample_tenant_id() -> UUID:
|
||||
"""Sample tenant ID for testing"""
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_weather_data() -> dict:
|
||||
"""Sample weather data for testing"""
|
||||
return {
|
||||
"city": "madrid",
|
||||
"location_id": "40.4168,-3.7038",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"temperature": 18.5,
|
||||
"humidity": 65.0,
|
||||
"pressure": 1013.2,
|
||||
"wind_speed": 10.2,
|
||||
"condition": "partly_cloudy",
|
||||
"description": "Parcialmente nublado",
|
||||
"source": "aemet",
|
||||
"data_type": "current",
|
||||
"is_forecast": False,
|
||||
"data_quality_score": 95.0
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_traffic_data() -> dict:
|
||||
"""Sample traffic data for testing"""
|
||||
return {
|
||||
"city": "madrid",
|
||||
"location_id": "PM_M30_001",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"measurement_point_id": "PM_M30_001",
|
||||
"measurement_point_name": "M-30 Norte - Nudo Norte",
|
||||
"measurement_point_type": "M30",
|
||||
"traffic_volume": 850,
|
||||
"average_speed": 65.2,
|
||||
"congestion_level": "medium",
|
||||
"occupation_percentage": 45.8,
|
||||
"latitude": 40.4501,
|
||||
"longitude": -3.6919,
|
||||
"district": "Chamartín",
|
||||
"source": "madrid_opendata",
|
||||
"data_quality_score": 92.0,
|
||||
"is_synthetic": False
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_weather_forecast() -> list[dict]:
|
||||
"""Sample weather forecast data"""
|
||||
base_date = datetime.now(timezone.utc)
|
||||
return [
|
||||
{
|
||||
"city": "madrid",
|
||||
"location_id": "40.4168,-3.7038",
|
||||
"date": base_date,
|
||||
"forecast_date": base_date,
|
||||
"temperature": 20.0,
|
||||
"temperature_min": 15.0,
|
||||
"temperature_max": 25.0,
|
||||
"precipitation": 0.0,
|
||||
"humidity": 60.0,
|
||||
"wind_speed": 12.0,
|
||||
"condition": "sunny",
|
||||
"description": "Soleado",
|
||||
"source": "aemet",
|
||||
"data_type": "forecast",
|
||||
"is_forecast": True,
|
||||
"data_quality_score": 85.0
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def populated_weather_db(test_db_session: AsyncSession, sample_weather_data: dict):
|
||||
"""Database populated with weather test data"""
|
||||
weather_record = WeatherData(**sample_weather_data)
|
||||
test_db_session.add(weather_record)
|
||||
await test_db_session.commit()
|
||||
yield test_db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def populated_traffic_db(test_db_session: AsyncSession, sample_traffic_data: dict):
|
||||
"""Database populated with traffic test data"""
|
||||
traffic_record = TrafficData(**sample_traffic_data)
|
||||
test_db_session.add(traffic_record)
|
||||
await test_db_session.commit()
|
||||
yield test_db_session
|
||||
|
||||
|
||||
# Mock external API fixtures
|
||||
@pytest.fixture
|
||||
def mock_aemet_response():
|
||||
"""Mock AEMET API response"""
|
||||
return {
|
||||
"date": datetime.now(timezone.utc),
|
||||
"temperature": 18.5,
|
||||
"humidity": 65.0,
|
||||
"pressure": 1013.2,
|
||||
"wind_speed": 10.2,
|
||||
"description": "Parcialmente nublado",
|
||||
"source": "aemet"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_madrid_traffic_xml():
|
||||
"""Mock Madrid Open Data traffic XML"""
|
||||
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<pms>
|
||||
<pm codigo="PM_M30_001" nombre="M-30 Norte - Nudo Norte">
|
||||
<intensidad>850</intensidad>
|
||||
<ocupacion>45</ocupacion>
|
||||
<velocidad>65</velocidad>
|
||||
<fechahora>2024-01-15T10:30:00</fechahora>
|
||||
</pm>
|
||||
<pm codigo="PM_URB_002" nombre="Gran Vía - Plaza España">
|
||||
<intensidad>320</intensidad>
|
||||
<ocupacion>78</ocupacion>
|
||||
<velocidad>25</velocidad>
|
||||
<fechahora>2024-01-15T10:30:00</fechahora>
|
||||
</pm>
|
||||
</pms>"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_messaging():
|
||||
"""Mock messaging service"""
|
||||
class MockMessaging:
|
||||
def __init__(self):
|
||||
self.published_events = []
|
||||
|
||||
async def publish_weather_updated(self, data):
|
||||
self.published_events.append(("weather_updated", data))
|
||||
return True
|
||||
|
||||
async def publish_traffic_updated(self, data):
|
||||
self.published_events.append(("traffic_updated", data))
|
||||
return True
|
||||
|
||||
async def publish_collection_job_started(self, data):
|
||||
self.published_events.append(("job_started", data))
|
||||
return True
|
||||
|
||||
async def publish_collection_job_completed(self, data):
|
||||
self.published_events.append(("job_completed", data))
|
||||
return True
|
||||
|
||||
return MockMessaging()
|
||||
|
||||
|
||||
# Mock external clients
|
||||
@pytest.fixture
|
||||
def mock_aemet_client():
|
||||
"""Mock AEMET client"""
|
||||
class MockAEMETClient:
|
||||
async def get_current_weather(self, lat, lon):
|
||||
return {
|
||||
"date": datetime.now(timezone.utc),
|
||||
"temperature": 18.5,
|
||||
"humidity": 65.0,
|
||||
"pressure": 1013.2,
|
||||
"wind_speed": 10.2,
|
||||
"description": "Parcialmente nublado",
|
||||
"source": "aemet"
|
||||
}
|
||||
|
||||
async def get_forecast(self, lat, lon, days):
|
||||
return [
|
||||
{
|
||||
"forecast_date": datetime.now(timezone.utc),
|
||||
"temperature": 20.0,
|
||||
"temperature_min": 15.0,
|
||||
"temperature_max": 25.0,
|
||||
"precipitation": 0.0,
|
||||
"humidity": 60.0,
|
||||
"wind_speed": 12.0,
|
||||
"description": "Soleado",
|
||||
"source": "aemet"
|
||||
}
|
||||
]
|
||||
|
||||
return MockAEMETClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_madrid_client():
|
||||
"""Mock Madrid traffic client"""
|
||||
class MockMadridClient:
|
||||
async def fetch_current_traffic_xml(self):
|
||||
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<pms>
|
||||
<pm codigo="PM_TEST_001" nombre="Test Point">
|
||||
<intensidad>500</intensidad>
|
||||
<ocupacion>50</ocupacion>
|
||||
<velocidad>50</velocidad>
|
||||
<fechahora>2024-01-15T10:30:00</fechahora>
|
||||
</pm>
|
||||
</pms>"""
|
||||
|
||||
return MockMadridClient()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_madrid_processor():
|
||||
"""Mock Madrid traffic processor"""
|
||||
class MockMadridProcessor:
|
||||
async def process_current_traffic_xml(self, xml_content):
|
||||
return [
|
||||
{
|
||||
"city": "madrid",
|
||||
"location_id": "PM_TEST_001",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"measurement_point_id": "PM_TEST_001",
|
||||
"measurement_point_name": "Test Point",
|
||||
"measurement_point_type": "TEST",
|
||||
"traffic_volume": 500,
|
||||
"average_speed": 50.0,
|
||||
"congestion_level": "medium",
|
||||
"occupation_percentage": 50.0,
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"district": "Centro",
|
||||
"source": "madrid_opendata",
|
||||
"data_quality_score": 90.0,
|
||||
"is_synthetic": False
|
||||
}
|
||||
]
|
||||
|
||||
return MockMadridProcessor()
|
||||
9
services/external/tests/requirements.txt
vendored
Normal file
9
services/external/tests/requirements.txt
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
# Testing dependencies for External Service
|
||||
pytest==7.4.3
|
||||
pytest-asyncio==0.21.1
|
||||
pytest-mock==3.12.0
|
||||
httpx==0.25.2
|
||||
fastapi[all]==0.104.1
|
||||
sqlalchemy[asyncio]==2.0.23
|
||||
aiosqlite==0.19.0
|
||||
coverage==7.3.2
|
||||
393
services/external/tests/unit/test_repositories.py
vendored
Normal file
393
services/external/tests/unit/test_repositories.py
vendored
Normal file
@@ -0,0 +1,393 @@
|
||||
# services/external/tests/unit/test_repositories.py
|
||||
"""
|
||||
Unit tests for External Service Repositories
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
from app.repositories.weather_repository import WeatherRepository
|
||||
from app.repositories.traffic_repository import TrafficRepository
|
||||
from app.models.weather import WeatherData, WeatherStation, WeatherDataJob
|
||||
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataJob
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestWeatherRepository:
|
||||
"""Test Weather Repository operations"""
|
||||
|
||||
async def test_create_weather_data(self, test_db_session, sample_weather_data):
|
||||
"""Test creating weather data"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
record = await repository.create_weather_data(sample_weather_data)
|
||||
|
||||
assert record is not None
|
||||
assert record.id is not None
|
||||
assert record.city == sample_weather_data["city"]
|
||||
assert record.temperature == sample_weather_data["temperature"]
|
||||
|
||||
async def test_get_current_weather(self, populated_weather_db, sample_weather_data):
|
||||
"""Test getting current weather data"""
|
||||
repository = WeatherRepository(populated_weather_db)
|
||||
|
||||
result = await repository.get_current_weather("madrid")
|
||||
|
||||
assert result is not None
|
||||
assert result.city == "madrid"
|
||||
assert result.temperature == sample_weather_data["temperature"]
|
||||
|
||||
async def test_get_weather_forecast(self, test_db_session, sample_weather_forecast):
|
||||
"""Test getting weather forecast"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
# Create forecast data
|
||||
for forecast_item in sample_weather_forecast:
|
||||
await repository.create_weather_data(forecast_item)
|
||||
|
||||
result = await repository.get_weather_forecast("madrid", 7)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].is_forecast is True
|
||||
|
||||
async def test_get_historical_weather(self, test_db_session, sample_weather_data):
|
||||
"""Test getting historical weather data"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
# Create historical data
|
||||
historical_data = sample_weather_data.copy()
|
||||
historical_data["date"] = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
await repository.create_weather_data(historical_data)
|
||||
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=2)
|
||||
end_date = datetime.now(timezone.utc)
|
||||
|
||||
result = await repository.get_historical_weather("madrid", start_date, end_date)
|
||||
|
||||
assert len(result) >= 1
|
||||
|
||||
async def test_create_weather_station(self, test_db_session):
|
||||
"""Test creating weather station"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
station_data = {
|
||||
"station_id": "TEST_001",
|
||||
"name": "Test Station",
|
||||
"city": "madrid",
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"altitude": 650.0,
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
station = await repository.create_weather_station(station_data)
|
||||
|
||||
assert station is not None
|
||||
assert station.station_id == "TEST_001"
|
||||
assert station.name == "Test Station"
|
||||
|
||||
async def test_get_weather_stations(self, test_db_session):
|
||||
"""Test getting weather stations"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
# Create test station
|
||||
station_data = {
|
||||
"station_id": "TEST_001",
|
||||
"name": "Test Station",
|
||||
"city": "madrid",
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"is_active": True
|
||||
}
|
||||
await repository.create_weather_station(station_data)
|
||||
|
||||
stations = await repository.get_weather_stations("madrid")
|
||||
|
||||
assert len(stations) == 1
|
||||
assert stations[0].station_id == "TEST_001"
|
||||
|
||||
async def test_create_weather_job(self, test_db_session, sample_tenant_id):
|
||||
"""Test creating weather data collection job"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
job_data = {
|
||||
"job_type": "current",
|
||||
"city": "madrid",
|
||||
"status": "pending",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
|
||||
job = await repository.create_weather_job(job_data)
|
||||
|
||||
assert job is not None
|
||||
assert job.job_type == "current"
|
||||
assert job.status == "pending"
|
||||
|
||||
async def test_update_weather_job(self, test_db_session, sample_tenant_id):
|
||||
"""Test updating weather job"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
# Create job first
|
||||
job_data = {
|
||||
"job_type": "current",
|
||||
"city": "madrid",
|
||||
"status": "pending",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
job = await repository.create_weather_job(job_data)
|
||||
|
||||
# Update job
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"completed_at": datetime.utcnow(),
|
||||
"success_count": 1
|
||||
}
|
||||
|
||||
success = await repository.update_weather_job(job.id, update_data)
|
||||
|
||||
assert success is True
|
||||
|
||||
async def test_get_weather_jobs(self, test_db_session, sample_tenant_id):
|
||||
"""Test getting weather jobs"""
|
||||
repository = WeatherRepository(test_db_session)
|
||||
|
||||
# Create test job
|
||||
job_data = {
|
||||
"job_type": "forecast",
|
||||
"city": "madrid",
|
||||
"status": "completed",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
await repository.create_weather_job(job_data)
|
||||
|
||||
jobs = await repository.get_weather_jobs()
|
||||
|
||||
assert len(jobs) >= 1
|
||||
assert any(job.job_type == "forecast" for job in jobs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTrafficRepository:
|
||||
"""Test Traffic Repository operations"""
|
||||
|
||||
async def test_create_traffic_data(self, test_db_session, sample_traffic_data):
|
||||
"""Test creating traffic data"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Convert sample data to list for bulk create
|
||||
traffic_list = [sample_traffic_data]
|
||||
|
||||
count = await repository.bulk_create_traffic_data(traffic_list)
|
||||
|
||||
assert count == 1
|
||||
|
||||
async def test_get_current_traffic(self, populated_traffic_db, sample_traffic_data):
|
||||
"""Test getting current traffic data"""
|
||||
repository = TrafficRepository(populated_traffic_db)
|
||||
|
||||
result = await repository.get_current_traffic("madrid")
|
||||
|
||||
assert len(result) >= 1
|
||||
assert result[0].city == "madrid"
|
||||
|
||||
async def test_get_current_traffic_with_filters(self, populated_traffic_db):
|
||||
"""Test getting current traffic with filters"""
|
||||
repository = TrafficRepository(populated_traffic_db)
|
||||
|
||||
result = await repository.get_current_traffic("madrid", district="Chamartín")
|
||||
|
||||
# Should return results based on filter
|
||||
assert isinstance(result, list)
|
||||
|
||||
async def test_get_historical_traffic(self, test_db_session, sample_traffic_data):
|
||||
"""Test getting historical traffic data"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create historical data
|
||||
historical_data = sample_traffic_data.copy()
|
||||
historical_data["date"] = datetime.now(timezone.utc) - timedelta(days=1)
|
||||
await repository.bulk_create_traffic_data([historical_data])
|
||||
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=2)
|
||||
end_date = datetime.now(timezone.utc)
|
||||
|
||||
result = await repository.get_historical_traffic("madrid", start_date, end_date)
|
||||
|
||||
assert len(result) >= 1
|
||||
|
||||
async def test_create_measurement_point(self, test_db_session):
|
||||
"""Test creating traffic measurement point"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
point_data = {
|
||||
"point_id": "TEST_POINT_001",
|
||||
"name": "Test Measurement Point",
|
||||
"city": "madrid",
|
||||
"point_type": "TEST",
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"district": "Centro",
|
||||
"road_name": "Test Road",
|
||||
"is_active": True
|
||||
}
|
||||
|
||||
point = await repository.create_measurement_point(point_data)
|
||||
|
||||
assert point is not None
|
||||
assert point.point_id == "TEST_POINT_001"
|
||||
assert point.name == "Test Measurement Point"
|
||||
|
||||
async def test_get_measurement_points(self, test_db_session):
|
||||
"""Test getting measurement points"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create test point
|
||||
point_data = {
|
||||
"point_id": "TEST_POINT_001",
|
||||
"name": "Test Point",
|
||||
"city": "madrid",
|
||||
"point_type": "TEST",
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"is_active": True
|
||||
}
|
||||
await repository.create_measurement_point(point_data)
|
||||
|
||||
points = await repository.get_measurement_points("madrid")
|
||||
|
||||
assert len(points) == 1
|
||||
assert points[0].point_id == "TEST_POINT_001"
|
||||
|
||||
async def test_get_measurement_points_with_filters(self, test_db_session):
|
||||
"""Test getting measurement points with filters"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create test points with different types
|
||||
for i, point_type in enumerate(["M30", "URB", "TEST"]):
|
||||
point_data = {
|
||||
"point_id": f"TEST_POINT_{i:03d}",
|
||||
"name": f"Test Point {i}",
|
||||
"city": "madrid",
|
||||
"point_type": point_type,
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"is_active": True
|
||||
}
|
||||
await repository.create_measurement_point(point_data)
|
||||
|
||||
# Filter by type
|
||||
points = await repository.get_measurement_points("madrid", road_type="M30")
|
||||
|
||||
assert len(points) == 1
|
||||
assert points[0].point_type == "M30"
|
||||
|
||||
async def test_get_traffic_analytics(self, populated_traffic_db):
|
||||
"""Test getting traffic analytics"""
|
||||
repository = TrafficRepository(populated_traffic_db)
|
||||
|
||||
analytics = await repository.get_traffic_analytics("madrid")
|
||||
|
||||
assert isinstance(analytics, dict)
|
||||
assert "total_measurements" in analytics
|
||||
assert "average_volume" in analytics
|
||||
|
||||
async def test_create_traffic_job(self, test_db_session, sample_tenant_id):
|
||||
"""Test creating traffic collection job"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
job_data = {
|
||||
"job_type": "current",
|
||||
"city": "madrid",
|
||||
"status": "pending",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
|
||||
job = await repository.create_traffic_job(job_data)
|
||||
|
||||
assert job is not None
|
||||
assert job.job_type == "current"
|
||||
assert job.status == "pending"
|
||||
|
||||
async def test_update_traffic_job(self, test_db_session, sample_tenant_id):
|
||||
"""Test updating traffic job"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create job first
|
||||
job_data = {
|
||||
"job_type": "current",
|
||||
"city": "madrid",
|
||||
"status": "pending",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
job = await repository.create_traffic_job(job_data)
|
||||
|
||||
# Update job
|
||||
update_data = {
|
||||
"status": "completed",
|
||||
"completed_at": datetime.utcnow(),
|
||||
"success_count": 10
|
||||
}
|
||||
|
||||
success = await repository.update_traffic_job(job.id, update_data)
|
||||
|
||||
assert success is True
|
||||
|
||||
async def test_get_traffic_jobs(self, test_db_session, sample_tenant_id):
|
||||
"""Test getting traffic jobs"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create test job
|
||||
job_data = {
|
||||
"job_type": "historical",
|
||||
"city": "madrid",
|
||||
"status": "completed",
|
||||
"scheduled_at": datetime.utcnow(),
|
||||
"tenant_id": sample_tenant_id
|
||||
}
|
||||
await repository.create_traffic_job(job_data)
|
||||
|
||||
jobs = await repository.get_traffic_jobs()
|
||||
|
||||
assert len(jobs) >= 1
|
||||
assert any(job.job_type == "historical" for job in jobs)
|
||||
|
||||
async def test_bulk_create_performance(self, test_db_session):
|
||||
"""Test bulk create performance"""
|
||||
repository = TrafficRepository(test_db_session)
|
||||
|
||||
# Create large dataset
|
||||
bulk_data = []
|
||||
for i in range(100):
|
||||
data = {
|
||||
"city": "madrid",
|
||||
"location_id": f"PM_TEST_{i:03d}",
|
||||
"date": datetime.now(timezone.utc),
|
||||
"measurement_point_id": f"PM_TEST_{i:03d}",
|
||||
"measurement_point_name": f"Test Point {i}",
|
||||
"measurement_point_type": "TEST",
|
||||
"traffic_volume": 100 + i,
|
||||
"average_speed": 50.0,
|
||||
"congestion_level": "medium",
|
||||
"occupation_percentage": 50.0,
|
||||
"latitude": 40.4168,
|
||||
"longitude": -3.7038,
|
||||
"source": "test"
|
||||
}
|
||||
bulk_data.append(data)
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
count = await repository.bulk_create_traffic_data(bulk_data)
|
||||
|
||||
end_time = time.time()
|
||||
execution_time = end_time - start_time
|
||||
|
||||
assert count == 100
|
||||
assert execution_time < 3.0 # Should complete in under 3 seconds
|
||||
445
services/external/tests/unit/test_services.py
vendored
Normal file
445
services/external/tests/unit/test_services.py
vendored
Normal file
@@ -0,0 +1,445 @@
|
||||
# services/external/tests/unit/test_services.py
|
||||
"""
|
||||
Unit tests for External Service Services
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.weather_service import WeatherService
|
||||
from app.services.traffic_service import TrafficService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestWeatherService:
|
||||
"""Test Weather Service business logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def weather_service(self):
|
||||
"""Create weather service instance"""
|
||||
return WeatherService()
|
||||
|
||||
async def test_get_current_weather_from_cache(self, weather_service):
|
||||
"""Test getting current weather from cache"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_weather = AsyncMock()
|
||||
mock_weather.date = datetime.now(timezone.utc) - timedelta(minutes=30) # Fresh data
|
||||
mock_weather.to_dict.return_value = {"temperature": 18.5, "city": "madrid"}
|
||||
mock_repository.get_current_weather.return_value = mock_weather
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
result = await weather_service.get_current_weather("madrid")
|
||||
|
||||
assert result is not None
|
||||
assert result["temperature"] == 18.5
|
||||
assert result["city"] == "madrid"
|
||||
|
||||
async def test_get_current_weather_fetch_from_api(self, weather_service, mock_aemet_response):
|
||||
"""Test getting current weather from API when cache is stale"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
# No cached data or stale data
|
||||
mock_repository.get_current_weather.return_value = None
|
||||
mock_stored = AsyncMock()
|
||||
mock_stored.to_dict.return_value = {"temperature": 20.0}
|
||||
mock_repository.create_weather_data.return_value = mock_stored
|
||||
|
||||
# Mock AEMET client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get_current_weather.return_value = mock_aemet_response
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
weather_service.aemet_client = mock_client
|
||||
|
||||
result = await weather_service.get_current_weather("madrid")
|
||||
|
||||
assert result is not None
|
||||
assert result["temperature"] == 20.0
|
||||
mock_client.get_current_weather.assert_called_once()
|
||||
|
||||
async def test_get_weather_forecast_from_cache(self, weather_service):
|
||||
"""Test getting weather forecast from cache"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_forecast = [AsyncMock(), AsyncMock()]
|
||||
for item in mock_forecast:
|
||||
item.created_at = datetime.now(timezone.utc) - timedelta(hours=1) # Fresh
|
||||
item.to_dict.return_value = {"temperature": 22.0}
|
||||
mock_repository.get_weather_forecast.return_value = mock_forecast
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
result = await weather_service.get_weather_forecast("madrid", 7)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(item["temperature"] == 22.0 for item in result)
|
||||
|
||||
async def test_get_weather_forecast_fetch_from_api(self, weather_service):
|
||||
"""Test getting weather forecast from API when cache is stale"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
# No cached data
|
||||
mock_repository.get_weather_forecast.return_value = []
|
||||
mock_stored = AsyncMock()
|
||||
mock_stored.to_dict.return_value = {"temperature": 25.0}
|
||||
mock_repository.create_weather_data.return_value = mock_stored
|
||||
|
||||
# Mock AEMET client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get_forecast.return_value = [
|
||||
{"forecast_date": datetime.now(), "temperature": 25.0}
|
||||
]
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
weather_service.aemet_client = mock_client
|
||||
|
||||
result = await weather_service.get_weather_forecast("madrid", 7)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["temperature"] == 25.0
|
||||
mock_client.get_forecast.assert_called_once()
|
||||
|
||||
async def test_get_historical_weather(self, weather_service, sample_tenant_id):
|
||||
"""Test getting historical weather data"""
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
end_date = datetime.now(timezone.utc)
|
||||
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_historical = [AsyncMock(), AsyncMock()]
|
||||
for item in mock_historical:
|
||||
item.to_dict.return_value = {"temperature": 18.0}
|
||||
mock_repository.get_historical_weather.return_value = mock_historical
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
result = await weather_service.get_historical_weather(
|
||||
"madrid", start_date, end_date, sample_tenant_id
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(item["temperature"] == 18.0 for item in result)
|
||||
|
||||
async def test_get_weather_stations(self, weather_service):
|
||||
"""Test getting weather stations"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_stations = [AsyncMock()]
|
||||
mock_stations[0].to_dict.return_value = {"station_id": "TEST_001"}
|
||||
mock_repository.get_weather_stations.return_value = mock_stations
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
result = await weather_service.get_weather_stations("madrid")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["station_id"] == "TEST_001"
|
||||
|
||||
async def test_trigger_weather_collection(self, weather_service, sample_tenant_id):
|
||||
"""Test triggering weather data collection"""
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_job = AsyncMock()
|
||||
mock_job.id = uuid4()
|
||||
mock_job.to_dict.return_value = {"id": str(mock_job.id), "status": "pending"}
|
||||
mock_repository.create_weather_job.return_value = mock_job
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
result = await weather_service.trigger_weather_collection(
|
||||
"madrid", "current", sample_tenant_id
|
||||
)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
mock_repository.create_weather_job.assert_called_once()
|
||||
|
||||
async def test_process_weather_collection_job(self, weather_service):
|
||||
"""Test processing weather collection job"""
|
||||
job_id = uuid4()
|
||||
|
||||
with patch('app.services.weather_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
|
||||
# Mock job
|
||||
mock_job = AsyncMock()
|
||||
mock_job.id = job_id
|
||||
mock_job.job_type = "current"
|
||||
mock_job.city = "madrid"
|
||||
|
||||
mock_repository.get_weather_jobs.return_value = [mock_job]
|
||||
mock_repository.update_weather_job.return_value = True
|
||||
|
||||
# Mock updated job after completion
|
||||
mock_updated_job = AsyncMock()
|
||||
mock_updated_job.to_dict.return_value = {"id": str(job_id), "status": "completed"}
|
||||
|
||||
# Mock methods for different calls
|
||||
def mock_get_jobs_side_effect():
|
||||
return [mock_updated_job] # Return completed job
|
||||
|
||||
mock_repository.get_weather_jobs.side_effect = [
|
||||
[mock_job], # First call returns pending job
|
||||
[mock_updated_job] # Second call returns completed job
|
||||
]
|
||||
|
||||
with patch('app.services.weather_service.WeatherRepository', return_value=mock_repository):
|
||||
with patch.object(weather_service, '_collect_current_weather', return_value=1):
|
||||
result = await weather_service.process_weather_collection_job(job_id)
|
||||
|
||||
assert result["status"] == "completed"
|
||||
|
||||
async def test_map_weather_condition(self, weather_service):
|
||||
"""Test weather condition mapping"""
|
||||
test_cases = [
|
||||
("Soleado", "clear"),
|
||||
("Nublado", "cloudy"),
|
||||
("Parcialmente nublado", "partly_cloudy"),
|
||||
("Lluvioso", "rainy"),
|
||||
("Nevando", "snowy"),
|
||||
("Tormenta", "stormy"),
|
||||
("Desconocido", "unknown")
|
||||
]
|
||||
|
||||
for description, expected in test_cases:
|
||||
result = weather_service._map_weather_condition(description)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTrafficService:
|
||||
"""Test Traffic Service business logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def traffic_service(self):
|
||||
"""Create traffic service instance"""
|
||||
return TrafficService()
|
||||
|
||||
async def test_get_current_traffic_from_cache(self, traffic_service):
|
||||
"""Test getting current traffic from cache"""
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_traffic = [AsyncMock()]
|
||||
mock_traffic[0].date = datetime.now(timezone.utc) - timedelta(minutes=5) # Fresh
|
||||
mock_traffic[0].to_dict.return_value = {"traffic_volume": 850}
|
||||
mock_repository.get_current_traffic.return_value = mock_traffic
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
result = await traffic_service.get_current_traffic("madrid")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["traffic_volume"] == 850
|
||||
|
||||
async def test_get_current_traffic_fetch_from_api(self, traffic_service, mock_madrid_traffic_xml):
|
||||
"""Test getting current traffic from API when cache is stale"""
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
# No cached data
|
||||
mock_repository.get_current_traffic.return_value = []
|
||||
mock_repository.bulk_create_traffic_data.return_value = 2
|
||||
|
||||
# Mock clients
|
||||
mock_client = AsyncMock()
|
||||
mock_client.fetch_current_traffic_xml.return_value = mock_madrid_traffic_xml
|
||||
|
||||
mock_processor = AsyncMock()
|
||||
mock_processor.process_current_traffic_xml.return_value = [
|
||||
{"traffic_volume": 850, "measurement_point_id": "PM_M30_001"},
|
||||
{"traffic_volume": 320, "measurement_point_id": "PM_URB_002"}
|
||||
]
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
traffic_service.madrid_client = mock_client
|
||||
traffic_service.madrid_processor = mock_processor
|
||||
|
||||
result = await traffic_service.get_current_traffic("madrid")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["traffic_volume"] == 850
|
||||
mock_client.fetch_current_traffic_xml.assert_called_once()
|
||||
|
||||
async def test_get_historical_traffic(self, traffic_service, sample_tenant_id):
|
||||
"""Test getting historical traffic data"""
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=7)
|
||||
end_date = datetime.now(timezone.utc)
|
||||
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_historical = [AsyncMock(), AsyncMock()]
|
||||
for item in mock_historical:
|
||||
item.to_dict.return_value = {"traffic_volume": 500}
|
||||
mock_repository.get_historical_traffic.return_value = mock_historical
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
result = await traffic_service.get_historical_traffic(
|
||||
"madrid", start_date, end_date, tenant_id=sample_tenant_id
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(item["traffic_volume"] == 500 for item in result)
|
||||
|
||||
async def test_get_measurement_points(self, traffic_service):
|
||||
"""Test getting measurement points"""
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_points = [AsyncMock()]
|
||||
mock_points[0].to_dict.return_value = {"point_id": "PM_TEST_001"}
|
||||
mock_repository.get_measurement_points.return_value = mock_points
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
result = await traffic_service.get_measurement_points("madrid")
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["point_id"] == "PM_TEST_001"
|
||||
|
||||
async def test_get_traffic_analytics(self, traffic_service):
|
||||
"""Test getting traffic analytics"""
|
||||
start_date = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
end_date = datetime.now(timezone.utc)
|
||||
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_analytics = {
|
||||
"total_measurements": 1000,
|
||||
"average_volume": 650.5,
|
||||
"peak_hour": "08:00"
|
||||
}
|
||||
mock_repository.get_traffic_analytics.return_value = mock_analytics
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
result = await traffic_service.get_traffic_analytics(
|
||||
"madrid", start_date, end_date
|
||||
)
|
||||
|
||||
assert result["total_measurements"] == 1000
|
||||
assert result["average_volume"] == 650.5
|
||||
assert "generated_at" in result
|
||||
|
||||
async def test_trigger_traffic_collection(self, traffic_service, sample_tenant_id):
|
||||
"""Test triggering traffic data collection"""
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
mock_job = AsyncMock()
|
||||
mock_job.id = uuid4()
|
||||
mock_job.to_dict.return_value = {"id": str(mock_job.id), "status": "pending"}
|
||||
mock_repository.create_traffic_job.return_value = mock_job
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
result = await traffic_service.trigger_traffic_collection(
|
||||
"madrid", "current", user_id=sample_tenant_id
|
||||
)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
mock_repository.create_traffic_job.assert_called_once()
|
||||
|
||||
async def test_process_traffic_collection_job(self, traffic_service):
|
||||
"""Test processing traffic collection job"""
|
||||
job_id = uuid4()
|
||||
|
||||
with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
|
||||
mock_db = AsyncMock()
|
||||
mock_get_db.return_value.__aenter__.return_value = mock_db
|
||||
|
||||
mock_repository = AsyncMock()
|
||||
|
||||
# Mock job
|
||||
mock_job = AsyncMock()
|
||||
mock_job.id = job_id
|
||||
mock_job.job_type = "current"
|
||||
mock_job.city = "madrid"
|
||||
mock_job.location_pattern = None
|
||||
|
||||
mock_repository.get_traffic_jobs.return_value = [mock_job]
|
||||
mock_repository.update_traffic_job.return_value = True
|
||||
|
||||
# Mock updated job after completion
|
||||
mock_updated_job = AsyncMock()
|
||||
mock_updated_job.to_dict.return_value = {"id": str(job_id), "status": "completed"}
|
||||
|
||||
mock_repository.get_traffic_jobs.side_effect = [
|
||||
[mock_job], # First call returns pending job
|
||||
[mock_updated_job] # Second call returns completed job
|
||||
]
|
||||
|
||||
with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
|
||||
with patch.object(traffic_service, '_collect_current_traffic', return_value=125):
|
||||
result = await traffic_service.process_traffic_collection_job(job_id)
|
||||
|
||||
assert result["status"] == "completed"
|
||||
|
||||
async def test_is_traffic_data_fresh(self, traffic_service):
|
||||
"""Test traffic data freshness check"""
|
||||
from app.models.traffic import TrafficData
|
||||
|
||||
# Fresh data (5 minutes old)
|
||||
fresh_data = [AsyncMock()]
|
||||
fresh_data[0].date = datetime.utcnow() - timedelta(minutes=5)
|
||||
|
||||
result = traffic_service._is_traffic_data_fresh(fresh_data)
|
||||
assert result is True
|
||||
|
||||
# Stale data (15 minutes old)
|
||||
stale_data = [AsyncMock()]
|
||||
stale_data[0].date = datetime.utcnow() - timedelta(minutes=15)
|
||||
|
||||
result = traffic_service._is_traffic_data_fresh(stale_data)
|
||||
assert result is False
|
||||
|
||||
# Empty data
|
||||
result = traffic_service._is_traffic_data_fresh([])
|
||||
assert result is False
|
||||
|
||||
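The three cases above (5 minutes old passes, 15 minutes old fails, empty fails) pin the freshness window somewhere between 5 and 15 minutes. A minimal sketch of the helper under test, assuming a 10-minute threshold (the real TrafficService implementation may differ):

from datetime import datetime, timedelta

def _is_traffic_data_fresh(traffic_data, max_age_minutes: int = 10) -> bool:
    """Return True when the newest measurement is younger than the threshold."""
    if not traffic_data:
        return False
    newest = max(record.date for record in traffic_data)
    return datetime.utcnow() - newest < timedelta(minutes=max_age_minutes)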
    async def test_collect_current_traffic(self, traffic_service):
        """Test current traffic collection"""
        with patch('app.services.traffic_service.get_db_transaction') as mock_get_db:
            mock_db = AsyncMock()
            mock_get_db.return_value.__aenter__.return_value = mock_db

            mock_repository = AsyncMock()
            mock_repository.bulk_create_traffic_data.return_value = 10

            with patch('app.services.traffic_service.TrafficRepository', return_value=mock_repository):
                with patch.object(traffic_service, '_fetch_current_traffic_from_api', return_value=[{} for _ in range(10)]):
                    result = await traffic_service._collect_current_traffic("madrid", None)

                    assert result == 10
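The tests above rely on traffic_service and sample_tenant_id fixtures that are not part of this hunk. A minimal conftest satisfying those signatures could look like this (the TrafficService constructor arguments are an assumption; the real fixtures may wire in settings or a database manager):

import pytest
from uuid import uuid4
from app.services.traffic_service import TrafficService

@pytest.fixture
def sample_tenant_id():
    return uuid4()

@pytest.fixture
def traffic_service():
    # Construction details are assumed; adjust to the real service signature.
    return TrafficService()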
@@ -9,7 +9,7 @@ from typing import Dict, Any, List, Optional
 from datetime import datetime
 
 # Import the shared clients
-from shared.clients import get_data_client, get_service_clients
+from shared.clients import get_sales_client, get_external_client, get_service_clients
 from app.core.config import settings
 
 logger = structlog.get_logger()
@@ -21,12 +21,13 @@ class DataClient:
     """
 
     def __init__(self):
-        # Get the shared data client configured for this service
-        self.data_client = get_data_client(settings, "forecasting")
+        # Get the new specialized clients
+        self.sales_client = get_sales_client(settings, "forecasting")
+        self.external_client = get_external_client(settings, "forecasting")
 
         # Or alternatively, get all clients at once:
-        # self.clients = get_service_clients(settings, "training")
-        # Then use: self.clients.data.get_sales_data(...)
+        # self.clients = get_service_clients(settings, "forecasting")
+        # Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
 
 
     async def fetch_weather_forecast(
@@ -41,7 +42,7 @@ class DataClient:
         All the error handling and retry logic is now in the base client!
         """
         try:
-            weather_data = await self.data_client.get_weather_forecast(
+            weather_data = await self.external_client.get_weather_forecast(
                 tenant_id=tenant_id,
                 days=days,
                 latitude=latitude,
@@ -8,7 +8,7 @@ import structlog
 from typing import Dict, Any, List, Optional
 
 # Import shared clients - no more code duplication!
-from shared.clients import get_service_clients, get_training_client, get_data_client
+from shared.clients import get_service_clients, get_training_client, get_sales_client
 from shared.database.base import create_database_manager
 from app.core.config import settings
 
@@ -30,7 +30,7 @@ class ModelClient:
 
         # Option 2: Get specific clients
         # self.training_client = get_training_client(settings, "forecasting")
-        # self.data_client = get_data_client(settings, "forecasting")
+        # self.sales_client = get_sales_client(settings, "forecasting")
 
     async def get_available_models(
         self,
@@ -409,7 +409,9 @@ class PredictionService:
             # Traffic-based features
             'high_traffic': int(traffic > 150) if traffic > 0 else 0,
             'low_traffic': int(traffic < 50) if traffic > 0 else 0,
-            'traffic_normalized': float((traffic - 100) / 50) if traffic > 0 else 0.0,
+            # Fix: Use same normalization as training (when std=0, normalized=0.0)
+            # Training uses constant 100.0 values, so std=0 and normalized=0.0
+            'traffic_normalized': 0.0,  # Match training behavior for consistent predictions
             'traffic_squared': traffic ** 2,
             'traffic_log': float(np.log1p(traffic)),
 
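The fix above is easier to see with the normalization written out: a z-score of a feature that was constant during training has zero standard deviation, so the only value the prediction side can emit consistently is 0.0. A small illustrative sketch (not service code):

import numpy as np

def z_normalize(values: np.ndarray) -> np.ndarray:
    std = values.std()
    if std == 0:
        # Constant feature (e.g. traffic fixed at 100.0 during training) -> normalized to 0.0
        return np.zeros_like(values, dtype=float)
    return (values - values.mean()) / std

print(z_normalize(np.array([100.0, 100.0, 100.0])))  # -> [0. 0. 0.]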
33 services/sales/Dockerfile Normal file
@@ -0,0 +1,33 @@
# services/sales/Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY services/sales/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy shared modules first
COPY shared/ /app/shared/

# Copy application code
COPY services/sales/app/ /app/app/

# Set Python path to include shared modules
ENV PYTHONPATH=/app

# Expose port
EXPOSE 8000

# Health check (raise_for_status makes the check fail on non-2xx responses, not only on connection errors)
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD python -c "import requests; requests.get('http://localhost:8000/health', timeout=5).raise_for_status()" || exit 1

# Run the application
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
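Note that both COPY instructions use repository-relative paths (shared/ and services/sales/), so the image only builds when the build context is the repository root, for example:

docker build -f services/sales/Dockerfile -t sales-service .

Building from inside services/sales/ would make the COPY of shared/ fail, and the health check assumes the requests package is listed in services/sales/requirements.txt.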
1 services/sales/app/__init__.py Normal file
@@ -0,0 +1 @@
# services/sales/app/__init__.py
1 services/sales/app/api/__init__.py Normal file
@@ -0,0 +1 @@
# services/sales/app/api/__init__.py
397 services/sales/app/api/import_data.py Normal file
@@ -0,0 +1,397 @@
# services/sales/app/api/import_data.py
|
||||
"""
|
||||
Sales Data Import API Endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Path
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from app.services.data_import_service import DataImportService
|
||||
from shared.auth.decorators import get_current_user_dep, get_current_tenant_id_dep
|
||||
|
||||
router = APIRouter(tags=["data-import"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def get_import_service():
|
||||
"""Dependency injection for DataImportService"""
|
||||
return DataImportService()
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales/import/validate-json")
|
||||
async def validate_json_data(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
data: Dict[str, Any] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
import_service: DataImportService = Depends(get_import_service)
|
||||
):
|
||||
"""Validate JSON sales data"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="No data provided")
|
||||
|
||||
logger.info("Validating JSON data", tenant_id=tenant_id, record_count=len(data.get("records", [])))
|
||||
|
||||
# Validate the data - handle different input formats
|
||||
if "records" in data:
|
||||
# New format with records array
|
||||
validation_data = {
|
||||
"tenant_id": str(tenant_id),
|
||||
"data": json.dumps(data.get("records", [])),
|
||||
"data_format": "json"
|
||||
}
|
||||
else:
|
||||
# Legacy format where the entire payload is the validation data
|
||||
validation_data = data.copy()
|
||||
validation_data["tenant_id"] = str(tenant_id)
|
||||
if "data_format" not in validation_data:
|
||||
validation_data["data_format"] = "json"
|
||||
|
||||
validation_result = await import_service.validate_import_data(validation_data)
|
||||
|
||||
logger.info("JSON validation completed", tenant_id=tenant_id, valid=validation_result.is_valid)
|
||||
|
||||
return {
|
||||
"is_valid": validation_result.is_valid,
|
||||
"total_records": validation_result.total_records,
|
||||
"valid_records": validation_result.valid_records,
|
||||
"invalid_records": validation_result.invalid_records,
|
||||
"errors": validation_result.errors,
|
||||
"warnings": validation_result.warnings,
|
||||
"summary": validation_result.summary
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate JSON data", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales/import/validate")
|
||||
async def validate_sales_data_universal(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
file: Optional[UploadFile] = File(None),
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
file_format: Optional[str] = Form(None),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
import_service: DataImportService = Depends(get_import_service)
|
||||
):
|
||||
"""Universal validation endpoint for sales data - supports files and JSON"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
# Handle file upload validation
|
||||
if file:
|
||||
logger.info("Validating uploaded file", tenant_id=tenant_id, filename=file.filename)
|
||||
|
||||
# Auto-detect format from filename
|
||||
filename = file.filename.lower()
|
||||
if filename.endswith('.csv'):
|
||||
detected_format = 'csv'
|
||||
elif filename.endswith('.xlsx') or filename.endswith('.xls'):
|
||||
detected_format = 'excel'
|
||||
elif filename.endswith('.json'):
|
||||
detected_format = 'json'
|
||||
else:
|
||||
detected_format = file_format or 'csv' # Default to CSV
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
if detected_format in ['xlsx', 'xls', 'excel']:
|
||||
# For Excel files, encode as base64
|
||||
import base64
|
||||
file_content = base64.b64encode(content).decode('utf-8')
|
||||
else:
|
||||
# For CSV/JSON, decode as text
|
||||
file_content = content.decode('utf-8')
|
||||
|
||||
validation_data = {
|
||||
"tenant_id": str(tenant_id),
|
||||
"data": file_content,
|
||||
"data_format": detected_format,
|
||||
"filename": file.filename
|
||||
}
|
||||
|
||||
# Handle JSON data validation
|
||||
elif data:
|
||||
logger.info("Validating JSON data", tenant_id=tenant_id)
|
||||
|
||||
validation_data = data.copy()
|
||||
validation_data["tenant_id"] = str(tenant_id)
|
||||
if "data_format" not in validation_data:
|
||||
validation_data["data_format"] = "json"
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="No file or data provided for validation")
|
||||
|
||||
# Perform validation
|
||||
validation_result = await import_service.validate_import_data(validation_data)
|
||||
|
||||
logger.info("Validation completed",
|
||||
tenant_id=tenant_id,
|
||||
valid=validation_result.is_valid,
|
||||
total_records=validation_result.total_records)
|
||||
|
||||
return {
|
||||
"is_valid": validation_result.is_valid,
|
||||
"total_records": validation_result.total_records,
|
||||
"valid_records": validation_result.valid_records,
|
||||
"invalid_records": validation_result.invalid_records,
|
||||
"errors": validation_result.errors,
|
||||
"warnings": validation_result.warnings,
|
||||
"summary": validation_result.summary,
|
||||
"message": "Validation completed successfully" if validation_result.is_valid else "Validation found errors",
|
||||
"details": {
|
||||
"total_records": validation_result.total_records,
|
||||
"format": validation_data.get("data_format", "unknown")
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate sales data", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate data: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales/import/validate-csv")
|
||||
async def validate_csv_data_legacy(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
file: UploadFile = File(...),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
import_service: DataImportService = Depends(get_import_service)
|
||||
):
|
||||
"""Legacy CSV validation endpoint - redirects to universal validator"""
|
||||
return await validate_sales_data_universal(
|
||||
tenant_id=tenant_id,
|
||||
file=file,
|
||||
current_user=current_user,
|
||||
current_tenant=current_tenant,
|
||||
import_service=import_service
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales/import")
|
||||
async def import_sales_data(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
file: Optional[UploadFile] = File(None),
|
||||
file_format: Optional[str] = Form(None),
|
||||
update_existing: bool = Form(False, description="Whether to update existing records"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
import_service: DataImportService = Depends(get_import_service)
|
||||
):
|
||||
"""Enhanced import sales data - supports multiple file formats and JSON"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
# Handle file upload (form data)
|
||||
if file:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
logger.info("Starting enhanced file import", tenant_id=tenant_id, filename=file.filename)
|
||||
|
||||
# Auto-detect format from filename
|
||||
filename = file.filename.lower()
|
||||
if filename.endswith('.csv'):
|
||||
detected_format = 'csv'
|
||||
elif filename.endswith('.xlsx') or filename.endswith('.xls'):
|
||||
detected_format = 'excel'
|
||||
elif filename.endswith('.json'):
|
||||
detected_format = 'json'
|
||||
else:
|
||||
detected_format = file_format or 'csv' # Default to CSV
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
if detected_format in ['xlsx', 'xls', 'excel']:
|
||||
# For Excel files, encode as base64
|
||||
import base64
|
||||
file_content = base64.b64encode(content).decode('utf-8')
|
||||
else:
|
||||
# For CSV/JSON, decode as text
|
||||
file_content = content.decode('utf-8')
|
||||
|
||||
# Import the file using enhanced service
|
||||
import_result = await import_service.process_import(
|
||||
str(tenant_id), # Ensure string type
|
||||
file_content,
|
||||
detected_format,
|
||||
filename=file.filename
|
||||
)
|
||||
|
||||
# Handle JSON data
|
||||
elif data:
|
||||
logger.info("Starting enhanced JSON data import", tenant_id=tenant_id, record_count=len(data.get("records", [])))
|
||||
|
||||
# Import the data - handle different input formats
|
||||
if "records" in data:
|
||||
# New format with records array
|
||||
records_json = json.dumps(data.get("records", []))
|
||||
import_result = await import_service.process_import(
|
||||
str(tenant_id),
|
||||
records_json,
|
||||
"json"
|
||||
)
|
||||
else:
|
||||
# Legacy format - data field contains the data directly
|
||||
import_result = await import_service.process_import(
|
||||
str(tenant_id),
|
||||
data.get("data", ""),
|
||||
data.get("data_format", "json")
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="No data or file provided")
|
||||
|
||||
logger.info("Enhanced import completed",
|
||||
tenant_id=tenant_id,
|
||||
created=import_result.records_created,
|
||||
updated=import_result.records_updated,
|
||||
failed=import_result.records_failed,
|
||||
processing_time=import_result.processing_time_seconds)
|
||||
|
||||
# Return enhanced response matching frontend expectations
|
||||
response = {
|
||||
"success": import_result.success,
|
||||
"records_processed": import_result.records_processed,
|
||||
"records_created": import_result.records_created,
|
||||
"records_updated": import_result.records_updated,
|
||||
"records_failed": import_result.records_failed,
|
||||
"errors": import_result.errors,
|
||||
"warnings": import_result.warnings,
|
||||
"processing_time_seconds": import_result.processing_time_seconds,
|
||||
"records_imported": import_result.records_created, # Frontend compatibility
|
||||
"message": f"Successfully imported {import_result.records_created} records" if import_result.success else "Import completed with errors"
|
||||
}
|
||||
|
||||
# Add file-specific information if available
|
||||
if file:
|
||||
response["file_info"] = {
|
||||
"name": file.filename,
|
||||
"format": detected_format,
|
||||
"size_bytes": len(content) if 'content' in locals() else 0
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to import sales data", error=str(e), tenant_id=tenant_id, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to import data: {str(e)}")
|
||||
|
||||
|
||||
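The filename-based format detection and the Excel/base64 handling above are duplicated between the validate and import endpoints. A small helper along these lines (hypothetical, not part of the diff) would capture the shared logic in one place:

import base64
from typing import Optional, Tuple
from fastapi import UploadFile

async def read_upload(file: UploadFile, fallback_format: Optional[str] = None) -> Tuple[str, str]:
    """Return (content, detected_format); Excel payloads are returned as base64 text."""
    name = (file.filename or "").lower()
    if name.endswith(".csv"):
        detected = "csv"
    elif name.endswith((".xlsx", ".xls")):
        detected = "excel"
    elif name.endswith(".json"):
        detected = "json"
    else:
        detected = fallback_format or "csv"  # default mirrors the endpoints above

    raw = await file.read()
    if detected == "excel":
        return base64.b64encode(raw).decode("utf-8"), detected
    return raw.decode("utf-8"), detected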
@router.post("/tenants/{tenant_id}/sales/import/csv")
|
||||
async def import_csv_data(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
file: UploadFile = File(...),
|
||||
update_existing: bool = Form(False, description="Whether to update existing records"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
import_service: DataImportService = Depends(get_import_service)
|
||||
):
|
||||
"""Import CSV sales data file"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="File must be a CSV file")
|
||||
|
||||
logger.info("Starting CSV data import", tenant_id=tenant_id, filename=file.filename)
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
file_content = content.decode('utf-8')
|
||||
|
||||
# Import the data
|
||||
import_result = await import_service.process_import(
|
||||
tenant_id,
|
||||
file_content,
|
||||
"csv",
|
||||
filename=file.filename
|
||||
)
|
||||
|
||||
logger.info("CSV import completed",
|
||||
tenant_id=tenant_id,
|
||||
filename=file.filename,
|
||||
created=import_result.records_created,
|
||||
updated=import_result.records_updated,
|
||||
failed=import_result.records_failed)
|
||||
|
||||
return {
|
||||
"success": import_result.success,
|
||||
"records_processed": import_result.records_processed,
|
||||
"records_created": import_result.records_created,
|
||||
"records_updated": import_result.records_updated,
|
||||
"records_failed": import_result.records_failed,
|
||||
"errors": import_result.errors,
|
||||
"warnings": import_result.warnings,
|
||||
"processing_time_seconds": import_result.processing_time_seconds
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to import CSV data", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to import CSV data: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/sales/import/template")
|
||||
async def get_import_template(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
format: str = "csv",
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep)
|
||||
):
|
||||
"""Get sales data import template"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
if format not in ["csv", "json"]:
|
||||
raise HTTPException(status_code=400, detail="Format must be 'csv' or 'json'")
|
||||
|
||||
if format == "csv":
|
||||
template = "date,product_name,product_category,product_sku,quantity_sold,unit_price,revenue,cost_of_goods,discount_applied,location_id,sales_channel,source,notes,weather_condition,is_holiday,is_weekend"
|
||||
else:
|
||||
template = {
|
||||
"records": [
|
||||
{
|
||||
"date": "2024-01-01T10:00:00Z",
|
||||
"product_name": "Sample Product",
|
||||
"product_category": "Sample Category",
|
||||
"product_sku": "SAMPLE001",
|
||||
"quantity_sold": 1,
|
||||
"unit_price": 10.50,
|
||||
"revenue": 10.50,
|
||||
"cost_of_goods": 5.25,
|
||||
"discount_applied": 0.0,
|
||||
"location_id": "LOC001",
|
||||
"sales_channel": "in_store",
|
||||
"source": "manual",
|
||||
"notes": "Sample sales record",
|
||||
"weather_condition": "sunny",
|
||||
"is_holiday": False,
|
||||
"is_weekend": False
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return {"template": template, "format": format}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get import template", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get import template: {str(e)}")
|
||||
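For reference, a data row matching the CSV template header could look like this (all values are illustrative):

2024-01-13T08:30:00Z,Croissant,Pastry,CRO001,24,1.80,43.20,18.00,0.0,LOC001,in_store,pos,Morning batch,sunny,False,True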
325 services/sales/app/api/sales.py Normal file
@@ -0,0 +1,325 @@
# services/sales/app/api/sales.py
|
||||
"""
|
||||
Sales API Endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
|
||||
from app.schemas.sales import (
|
||||
SalesDataCreate,
|
||||
SalesDataUpdate,
|
||||
SalesDataResponse,
|
||||
SalesDataQuery
|
||||
)
|
||||
from app.services.sales_service import SalesService
|
||||
from shared.auth.decorators import get_current_user_dep, get_current_tenant_id_dep
|
||||
|
||||
router = APIRouter(tags=["sales"])
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def get_sales_service():
|
||||
"""Dependency injection for SalesService"""
|
||||
return SalesService()
|
||||
|
||||
@router.get("/tenants/{tenant_id}/sales/products")
|
||||
async def get_products_list(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get list of products using repository pattern"""
|
||||
try:
|
||||
logger.debug("Getting products list with repository pattern", tenant_id=tenant_id)
|
||||
|
||||
products = await sales_service.get_products_list(str(tenant_id))
|
||||
|
||||
logger.debug("Products list retrieved using repository",
|
||||
count=len(products),
|
||||
tenant_id=tenant_id)
|
||||
return products
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get products list",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get products list: {str(e)}")
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales", response_model=SalesDataResponse)
|
||||
async def create_sales_record(
|
||||
sales_data: SalesDataCreate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Create a new sales record"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
logger.info(
|
||||
"Creating sales record",
|
||||
product=sales_data.product_name,
|
||||
quantity=sales_data.quantity_sold,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Create the record
|
||||
record = await sales_service.create_sales_record(
|
||||
sales_data,
|
||||
tenant_id,
|
||||
user_id=UUID(current_user["user_id"]) if current_user.get("user_id") else None
|
||||
)
|
||||
|
||||
logger.info("Successfully created sales record", record_id=record.id, tenant_id=tenant_id)
|
||||
return record
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning("Validation error creating sales record", error=str(ve), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error("Failed to create sales record", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create sales record: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/sales", response_model=List[SalesDataResponse])
|
||||
async def get_sales_records(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start date filter"),
|
||||
end_date: Optional[datetime] = Query(None, description="End date filter"),
|
||||
product_name: Optional[str] = Query(None, description="Product name filter"),
|
||||
product_category: Optional[str] = Query(None, description="Product category filter"),
|
||||
location_id: Optional[str] = Query(None, description="Location filter"),
|
||||
sales_channel: Optional[str] = Query(None, description="Sales channel filter"),
|
||||
source: Optional[str] = Query(None, description="Data source filter"),
|
||||
is_validated: Optional[bool] = Query(None, description="Validation status filter"),
|
||||
limit: int = Query(50, ge=1, le=1000, description="Number of records to return"),
|
||||
offset: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
order_by: str = Query("date", description="Field to order by"),
|
||||
order_direction: str = Query("desc", description="Order direction (asc/desc)"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get sales records for a tenant with filtering and pagination"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
# Build query parameters
|
||||
query_params = SalesDataQuery(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
product_name=product_name,
|
||||
product_category=product_category,
|
||||
location_id=location_id,
|
||||
sales_channel=sales_channel,
|
||||
source=source,
|
||||
is_validated=is_validated,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
order_direction=order_direction
|
||||
)
|
||||
|
||||
records = await sales_service.get_sales_records(tenant_id, query_params)
|
||||
|
||||
logger.info("Retrieved sales records", count=len(records), tenant_id=tenant_id)
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales records", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get sales records: {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/sales/analytics/summary")
|
||||
async def get_sales_analytics(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start date filter"),
|
||||
end_date: Optional[datetime] = Query(None, description="End date filter"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get sales analytics summary for a tenant"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
analytics = await sales_service.get_sales_analytics(tenant_id, start_date, end_date)
|
||||
|
||||
logger.info("Retrieved sales analytics", tenant_id=tenant_id)
|
||||
return analytics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales analytics", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get sales analytics: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/products/{product_name}/sales", response_model=List[SalesDataResponse])
|
||||
async def get_product_sales(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
product_name: str = Path(..., description="Product name"),
|
||||
start_date: Optional[datetime] = Query(None, description="Start date filter"),
|
||||
end_date: Optional[datetime] = Query(None, description="End date filter"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get sales records for a specific product"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
records = await sales_service.get_product_sales(tenant_id, product_name, start_date, end_date)
|
||||
|
||||
logger.info("Retrieved product sales", count=len(records), product=product_name, tenant_id=tenant_id)
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get product sales", error=str(e), tenant_id=tenant_id, product=product_name)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get product sales: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/sales/categories", response_model=List[str])
|
||||
async def get_product_categories(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get distinct product categories from sales data"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
categories = await sales_service.get_product_categories(tenant_id)
|
||||
|
||||
return categories
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get product categories", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get product categories: {str(e)}")
|
||||
|
||||
|
||||
# ================================================================
|
||||
# PARAMETERIZED ROUTES - Keep these at the end to avoid conflicts
|
||||
# ================================================================
|
||||
|
||||
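The ordering constraint called out above is the usual FastAPI matching rule: routes are tried in registration order, so a parameterized path registered too early swallows its literal siblings. A standalone sketch (not service code) of the failure mode:

from uuid import UUID
from fastapi import FastAPI

app = FastAPI()

@app.get("/sales/categories")        # literal route registered first -> reachable
async def categories():
    return ["bread", "pastry"]

@app.get("/sales/{record_id}")       # parameterized route kept last
async def get_record(record_id: UUID):
    return {"id": str(record_id)}

# If the order were reversed, GET /sales/categories would be routed to
# get_record and fail UUID validation with a 422 instead of returning the list.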
@router.get("/tenants/{tenant_id}/sales/{record_id}", response_model=SalesDataResponse)
|
||||
async def get_sales_record(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
record_id: UUID = Path(..., description="Sales record ID"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Get a specific sales record"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
record = await sales_service.get_sales_record(record_id, tenant_id)
|
||||
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Sales record not found")
|
||||
|
||||
return record
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to get sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get sales record: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/tenants/{tenant_id}/sales/{record_id}", response_model=SalesDataResponse)
|
||||
async def update_sales_record(
|
||||
update_data: SalesDataUpdate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
record_id: UUID = Path(..., description="Sales record ID"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Update a sales record"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
updated_record = await sales_service.update_sales_record(record_id, update_data, tenant_id)
|
||||
|
||||
logger.info("Updated sales record", record_id=record_id, tenant_id=tenant_id)
|
||||
return updated_record
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning("Validation error updating sales record", error=str(ve), record_id=record_id)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error("Failed to update sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update sales record: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/sales/{record_id}")
|
||||
async def delete_sales_record(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
record_id: UUID = Path(..., description="Sales record ID"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Delete a sales record"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
success = await sales_service.delete_sales_record(record_id, tenant_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Sales record not found")
|
||||
|
||||
logger.info("Deleted sales record", record_id=record_id, tenant_id=tenant_id)
|
||||
return {"message": "Sales record deleted successfully"}
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning("Error deleting sales record", error=str(ve), record_id=record_id)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error("Failed to delete sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to delete sales record: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/sales/{record_id}/validate", response_model=SalesDataResponse)
|
||||
async def validate_sales_record(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
record_id: UUID = Path(..., description="Sales record ID"),
|
||||
validation_notes: Optional[str] = Query(None, description="Validation notes"),
|
||||
current_tenant: str = Depends(get_current_tenant_id_dep),
|
||||
sales_service: SalesService = Depends(get_sales_service)
|
||||
):
|
||||
"""Mark a sales record as validated"""
|
||||
try:
|
||||
# Verify tenant access
|
||||
if str(tenant_id) != current_tenant:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this tenant")
|
||||
|
||||
validated_record = await sales_service.validate_sales_record(record_id, tenant_id, validation_notes)
|
||||
|
||||
logger.info("Validated sales record", record_id=record_id, tenant_id=tenant_id)
|
||||
return validated_record
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning("Error validating sales record", error=str(ve), record_id=record_id)
|
||||
raise HTTPException(status_code=400, detail=str(ve))
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to validate sales record: {str(e)}")
|
||||
1 services/sales/app/core/__init__.py Normal file
@@ -0,0 +1 @@
# services/sales/app/core/__init__.py
53 services/sales/app/core/config.py Normal file
@@ -0,0 +1,53 @@
# services/sales/app/core/config.py
|
||||
"""
|
||||
Sales Service Configuration
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from pydantic import Field
|
||||
from shared.config.base import BaseServiceSettings
|
||||
|
||||
|
||||
class Settings(BaseServiceSettings):
|
||||
"""Sales service settings extending base configuration"""
|
||||
|
||||
# Override service-specific settings
|
||||
SERVICE_NAME: str = "sales-service"
|
||||
VERSION: str = "1.0.0"
|
||||
APP_NAME: str = "Bakery Sales Service"
|
||||
DESCRIPTION: str = "Sales data management and analytics service"
|
||||
|
||||
# API Configuration
|
||||
API_V1_STR: str = "/api/v1"
|
||||
|
||||
# Override database URL to use SALES_DATABASE_URL
|
||||
DATABASE_URL: str = Field(
|
||||
default="postgresql+asyncpg://sales_user:sales_pass123@sales-db:5432/sales_db",
|
||||
env="SALES_DATABASE_URL"
|
||||
)
|
||||
|
||||
# Sales-specific Redis database
|
||||
REDIS_DB: int = Field(default=2, env="SALES_REDIS_DB")
|
||||
|
||||
# File upload configuration
|
||||
MAX_UPLOAD_SIZE: int = 10 * 1024 * 1024 # 10MB
|
||||
UPLOAD_PATH: str = Field(default="/tmp/uploads", env="SALES_UPLOAD_PATH")
|
||||
ALLOWED_FILE_EXTENSIONS: List[str] = [".csv", ".xlsx", ".xls"]
|
||||
|
||||
# Pagination
|
||||
DEFAULT_PAGE_SIZE: int = 50
|
||||
MAX_PAGE_SIZE: int = 1000
|
||||
|
||||
# Data validation
|
||||
MIN_QUANTITY: float = 0.01
|
||||
MAX_QUANTITY: float = 10000.0
|
||||
MIN_REVENUE: float = 0.01
|
||||
MAX_REVENUE: float = 100000.0
|
||||
|
||||
# Sales-specific cache TTL (5 minutes)
|
||||
SALES_CACHE_TTL: int = 300
|
||||
PRODUCT_CACHE_TTL: int = 600 # 10 minutes
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings()
|
||||
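Assuming BaseServiceSettings follows the usual pydantic settings behaviour, the SALES_DATABASE_URL environment variable overrides the default connection string when Settings is instantiated; a quick check (illustrative only):

import os
os.environ["SALES_DATABASE_URL"] = "postgresql+asyncpg://sales_user:example@localhost:5433/sales_db"

from app.core.config import Settings

assert Settings().DATABASE_URL.endswith("/sales_db")
assert Settings().SERVICE_NAME == "sales-service"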
86 services/sales/app/core/database.py Normal file
@@ -0,0 +1,86 @@
# services/sales/app/core/database.py
|
||||
"""
|
||||
Sales Service Database Configuration using shared database manager
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.core.config import settings
|
||||
from shared.database.base import DatabaseManager, Base
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Create database manager instance
|
||||
database_manager = DatabaseManager(
|
||||
database_url=settings.DATABASE_URL,
|
||||
service_name="sales-service",
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
echo=settings.DB_ECHO
|
||||
)
|
||||
|
||||
|
||||
async def get_db():
|
||||
"""
|
||||
Database dependency for FastAPI - using shared database manager
|
||||
"""
|
||||
async for session in database_manager.get_db():
|
||||
yield session
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database tables using shared database manager"""
|
||||
try:
|
||||
logger.info("Initializing Sales Service database...")
|
||||
|
||||
# Import all models to ensure they're registered
|
||||
from app.models import sales # noqa: F401
|
||||
|
||||
# Create all tables using database manager
|
||||
await database_manager.create_tables(Base.metadata)
|
||||
|
||||
logger.info("Sales Service database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize database", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
async def close_db():
|
||||
"""Close database connections using shared database manager"""
|
||||
try:
|
||||
await database_manager.close_connections()
|
||||
logger.info("Database connections closed")
|
||||
except Exception as e:
|
||||
logger.error("Error closing database connections", error=str(e))
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_transaction():
|
||||
"""
|
||||
Context manager for database transactions using shared database manager
|
||||
"""
|
||||
async with database_manager.get_session() as session:
|
||||
try:
|
||||
async with session.begin():
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error("Transaction error", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_session():
|
||||
"""
|
||||
Context manager for background tasks using shared database manager
|
||||
"""
|
||||
async with database_manager.get_background_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def health_check():
|
||||
"""Database health check using shared database manager"""
|
||||
return await database_manager.health_check()
|
||||
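A minimal usage sketch of the transaction helper together with a repository (the bulk_create call is hypothetical; it stands in for whatever write method the repository actually exposes):

from uuid import UUID
from app.core.database import get_db_transaction
from app.repositories.sales_repository import SalesRepository

async def import_batch(tenant_id: UUID, records: list) -> int:
    # session.begin() inside get_db_transaction commits on success and rolls back on error
    async with get_db_transaction() as session:
        repository = SalesRepository(session)
        return await repository.bulk_create(records, tenant_id)  # hypothetical method name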
@@ -1,8 +1,6 @@
|
||||
# ================================================================
|
||||
# services/data/app/main.py - FIXED VERSION
|
||||
# ================================================================
|
||||
# services/sales/app/main.py
|
||||
"""
|
||||
Data Service Main Application - Fixed middleware issue
|
||||
Sales Service Main Application
|
||||
"""
|
||||
|
||||
import structlog
|
||||
@@ -12,15 +10,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.database import init_db
|
||||
from app.api.sales import router as sales_router
|
||||
from app.api.weather import router as weather_router
|
||||
from app.api.traffic import router as traffic_router
|
||||
from app.core.database import init_db, close_db
|
||||
from shared.monitoring import setup_logging, HealthChecker
|
||||
from shared.monitoring.metrics import setup_metrics_early
|
||||
|
||||
# Setup logging first
|
||||
setup_logging("data-service", settings.LOG_LEVEL)
|
||||
setup_logging("sales-service", settings.LOG_LEVEL)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global variables for lifespan access
|
||||
@@ -29,45 +24,42 @@ health_checker = None
|
||||
|
||||
# Create FastAPI app FIRST
|
||||
app = FastAPI(
|
||||
title="Bakery Data Service",
|
||||
description="External data integration service for weather, traffic, and sales data",
|
||||
title="Bakery Sales Service",
|
||||
description="Sales data management service for bakery operations",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# Setup metrics BEFORE any middleware and BEFORE lifespan
|
||||
metrics_collector = setup_metrics_early(app, "data-service")
|
||||
metrics_collector = setup_metrics_early(app, "sales-service")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan events - NO MIDDLEWARE ADDED HERE"""
|
||||
"""Application lifespan events"""
|
||||
global health_checker
|
||||
|
||||
# Startup
|
||||
logger.info("Starting Data Service...")
|
||||
logger.info("Starting Sales Service...")
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
await init_db()
|
||||
logger.info("Database initialized")
|
||||
|
||||
# Register custom metrics (metrics_collector already exists)
|
||||
# Register custom metrics
|
||||
metrics_collector.register_counter("sales_records_created_total", "Total sales records created")
|
||||
metrics_collector.register_counter("sales_records_updated_total", "Total sales records updated")
|
||||
metrics_collector.register_counter("sales_queries_total", "Sales record queries")
|
||||
metrics_collector.register_counter("weather_api_calls_total", "Weather API calls")
|
||||
metrics_collector.register_counter("weather_api_success_total", "Successful weather API calls")
|
||||
metrics_collector.register_counter("weather_api_failures_total", "Failed weather API calls")
|
||||
metrics_collector.register_counter("traffic_api_calls_total", "Traffic API calls")
|
||||
metrics_collector.register_counter("product_queries_total", "Product catalog queries")
|
||||
metrics_collector.register_counter("import_jobs_total", "Data import jobs")
|
||||
metrics_collector.register_counter("template_downloads_total", "Template downloads")
|
||||
metrics_collector.register_counter("export_jobs_total", "Data export jobs")
|
||||
|
||||
metrics_collector.register_histogram("sales_create_duration_seconds", "Sales record creation duration")
|
||||
metrics_collector.register_histogram("sales_list_duration_seconds", "Sales record list duration")
|
||||
metrics_collector.register_histogram("import_duration_seconds", "Data import duration")
|
||||
metrics_collector.register_histogram("weather_current_duration_seconds", "Current weather API duration")
|
||||
metrics_collector.register_histogram("weather_forecast_duration_seconds", "Weather forecast API duration")
|
||||
metrics_collector.register_histogram("external_api_duration_seconds", "External API call duration")
|
||||
metrics_collector.register_histogram("sales_query_duration_seconds", "Sales query duration")
|
||||
metrics_collector.register_histogram("import_processing_duration_seconds", "Import processing duration")
|
||||
metrics_collector.register_histogram("export_generation_duration_seconds", "Export generation duration")
|
||||
|
||||
# Setup health checker
|
||||
health_checker = HealthChecker("data-service")
|
||||
health_checker = HealthChecker("sales-service")
|
||||
|
||||
# Add database health check
|
||||
async def check_database():
|
||||
@@ -85,16 +77,17 @@ async def lifespan(app: FastAPI):
|
||||
# Store health checker in app state
|
||||
app.state.health_checker = health_checker
|
||||
|
||||
logger.info("Data Service started successfully")
|
||||
logger.info("Sales Service started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start Data Service: {e}")
|
||||
logger.error(f"Failed to start Sales Service: {e}")
|
||||
raise
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down Data Service...")
|
||||
logger.info("Shutting down Sales Service...")
|
||||
await close_db()
|
||||
|
||||
# Set lifespan AFTER metrics setup
|
||||
app.router.lifespan_context = lifespan
|
||||
@@ -102,16 +95,17 @@ app.router.lifespan_context = lifespan
|
||||
# CORS middleware (added after metrics setup)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=getattr(settings, 'CORS_ORIGINS', ["*"]),
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
# Include routers - import router BEFORE sales router to avoid conflicts
|
||||
from app.api.sales import router as sales_router
|
||||
from app.api.import_data import router as import_router
|
||||
app.include_router(import_router, prefix="/api/v1", tags=["import"])
|
||||
app.include_router(sales_router, prefix="/api/v1", tags=["sales"])
|
||||
app.include_router(weather_router, prefix="/api/v1", tags=["weather"])
|
||||
app.include_router(traffic_router, prefix="/api/v1", tags=["traffic"])
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
@@ -121,11 +115,27 @@ async def health_check():
|
||||
return await health_checker.check_health()
|
||||
else:
|
||||
return {
|
||||
"service": "data-service",
|
||||
"service": "sales-service",
|
||||
"status": "healthy",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
# Root endpoint
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"service": "Sales Service",
|
||||
"version": "1.0.0",
|
||||
"status": "running",
|
||||
"endpoints": {
|
||||
"health": "/health",
|
||||
"docs": "/docs",
|
||||
"sales": "/api/v1/sales",
|
||||
"products": "/api/v1/products"
|
||||
}
|
||||
}
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
5 services/sales/app/models/__init__.py Normal file
@@ -0,0 +1,5 @@
# services/sales/app/models/__init__.py

from .sales import SalesData, Product, SalesImportJob

__all__ = ["SalesData", "Product", "SalesImportJob"]
238 services/sales/app/models/sales.py Normal file
@@ -0,0 +1,238 @@
# services/sales/app/models/sales.py
|
||||
"""
|
||||
Sales data models for Sales Service
|
||||
Enhanced with additional fields and relationships
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index, Boolean, Numeric
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class SalesData(Base):
|
||||
"""Enhanced sales data model"""
|
||||
__tablename__ = "sales_data"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
date = Column(DateTime(timezone=True), nullable=False, index=True)
|
||||
|
||||
# Product information
|
||||
product_name = Column(String(255), nullable=False, index=True)
|
||||
product_category = Column(String(100), nullable=True, index=True)
|
||||
product_sku = Column(String(100), nullable=True, index=True)
|
||||
|
||||
# Sales data
|
||||
quantity_sold = Column(Integer, nullable=False)
|
||||
unit_price = Column(Numeric(10, 2), nullable=True)
|
||||
revenue = Column(Numeric(10, 2), nullable=False)
|
||||
cost_of_goods = Column(Numeric(10, 2), nullable=True) # For profit calculation
|
||||
discount_applied = Column(Numeric(5, 2), nullable=True, default=0.0) # Percentage
|
||||
|
||||
# Location and channel
|
||||
location_id = Column(String(100), nullable=True, index=True)
|
||||
sales_channel = Column(String(50), nullable=True, default="in_store") # in_store, online, delivery
|
||||
|
||||
# Data source and quality
|
||||
source = Column(String(50), nullable=False, default="manual") # manual, pos, online, import
|
||||
is_validated = Column(Boolean, default=False)
|
||||
validation_notes = Column(Text, nullable=True)
|
||||
|
||||
# Additional metadata
|
||||
notes = Column(Text, nullable=True)
|
||||
weather_condition = Column(String(50), nullable=True) # For correlation analysis
|
||||
is_holiday = Column(Boolean, default=False)
|
||||
is_weekend = Column(Boolean, default=False)
|
||||
|
||||
# Audit fields
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc))
|
||||
created_by = Column(UUID(as_uuid=True), nullable=True) # User ID
|
||||
|
||||
# Performance-optimized indexes
|
||||
__table_args__ = (
|
||||
# Core query patterns
|
||||
Index('idx_sales_tenant_date', 'tenant_id', 'date'),
|
||||
Index('idx_sales_tenant_product', 'tenant_id', 'product_name'),
|
||||
Index('idx_sales_tenant_location', 'tenant_id', 'location_id'),
|
||||
Index('idx_sales_tenant_category', 'tenant_id', 'product_category'),
|
||||
|
||||
# Analytics queries
|
||||
Index('idx_sales_date_range', 'date', 'tenant_id'),
|
||||
Index('idx_sales_product_date', 'product_name', 'date', 'tenant_id'),
|
||||
Index('idx_sales_channel_date', 'sales_channel', 'date', 'tenant_id'),
|
||||
|
||||
# Data quality queries
|
||||
Index('idx_sales_source_validated', 'source', 'is_validated', 'tenant_id'),
|
||||
Index('idx_sales_sku_date', 'product_sku', 'date', 'tenant_id'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary for API responses"""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'tenant_id': str(self.tenant_id),
|
||||
'date': self.date.isoformat() if self.date else None,
|
||||
'product_name': self.product_name,
|
||||
'product_category': self.product_category,
|
||||
'product_sku': self.product_sku,
|
||||
'quantity_sold': self.quantity_sold,
|
||||
'unit_price': float(self.unit_price) if self.unit_price else None,
|
||||
'revenue': float(self.revenue) if self.revenue else None,
|
||||
'cost_of_goods': float(self.cost_of_goods) if self.cost_of_goods else None,
|
||||
'discount_applied': float(self.discount_applied) if self.discount_applied else None,
|
||||
'location_id': self.location_id,
|
||||
'sales_channel': self.sales_channel,
|
||||
'source': self.source,
|
||||
'is_validated': self.is_validated,
|
||||
'validation_notes': self.validation_notes,
|
||||
'notes': self.notes,
|
||||
'weather_condition': self.weather_condition,
|
||||
'is_holiday': self.is_holiday,
|
||||
'is_weekend': self.is_weekend,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
'created_by': str(self.created_by) if self.created_by else None,
|
||||
}
|
||||
|
||||
@property
|
||||
def profit_margin(self) -> Optional[float]:
|
||||
"""Calculate profit margin if cost data is available"""
|
||||
if self.revenue and self.cost_of_goods:
|
||||
return float((self.revenue - self.cost_of_goods) / self.revenue * 100)
|
||||
return None
|
||||
|
||||
|
||||
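The composite indexes above are shaped for tenant-scoped time-range scans; the kind of query idx_sales_tenant_date is built for would look roughly like this (illustrative only, using the SalesData model defined above):

from sqlalchemy import select, and_

def sales_in_range(tenant_id, start, end):
    # Filters on tenant_id plus a date window, matching the (tenant_id, date) index order
    return (
        select(SalesData)
        .where(and_(SalesData.tenant_id == tenant_id,
                    SalesData.date >= start,
                    SalesData.date <= end))
        .order_by(SalesData.date.desc())
    )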
class Product(Base):
|
||||
"""Product catalog model - future expansion"""
|
||||
__tablename__ = "products"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Product identification
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
sku = Column(String(100), nullable=True, index=True)
|
||||
category = Column(String(100), nullable=True, index=True)
|
||||
subcategory = Column(String(100), nullable=True)
|
||||
|
||||
# Product details
|
||||
description = Column(Text, nullable=True)
|
||||
unit_of_measure = Column(String(20), nullable=False, default="unit")
|
||||
weight = Column(Float, nullable=True) # in grams
|
||||
volume = Column(Float, nullable=True) # in ml
|
||||
|
||||
# Pricing
|
||||
base_price = Column(Numeric(10, 2), nullable=True)
|
||||
cost_price = Column(Numeric(10, 2), nullable=True)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_seasonal = Column(Boolean, default=False)
|
||||
seasonal_start = Column(DateTime(timezone=True), nullable=True)
|
||||
seasonal_end = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Audit fields
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_products_tenant_name', 'tenant_id', 'name', unique=True),
|
||||
Index('idx_products_tenant_sku', 'tenant_id', 'sku'),
|
||||
Index('idx_products_category', 'tenant_id', 'category', 'is_active'),
|
||||
Index('idx_products_seasonal', 'is_seasonal', 'seasonal_start', 'seasonal_end'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary for API responses"""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'tenant_id': str(self.tenant_id),
|
||||
'name': self.name,
|
||||
'sku': self.sku,
|
||||
'category': self.category,
|
||||
'subcategory': self.subcategory,
|
||||
'description': self.description,
|
||||
'unit_of_measure': self.unit_of_measure,
|
||||
'weight': self.weight,
|
||||
'volume': self.volume,
|
||||
'base_price': float(self.base_price) if self.base_price else None,
|
||||
'cost_price': float(self.cost_price) if self.cost_price else None,
|
||||
'is_active': self.is_active,
|
||||
'is_seasonal': self.is_seasonal,
|
||||
'seasonal_start': self.seasonal_start.isoformat() if self.seasonal_start else None,
|
||||
'seasonal_end': self.seasonal_end.isoformat() if self.seasonal_end else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
class SalesImportJob(Base):
|
||||
"""Track sales data import jobs"""
|
||||
__tablename__ = "sales_import_jobs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Job details
|
||||
filename = Column(String(255), nullable=False)
|
||||
file_size = Column(Integer, nullable=True)
|
||||
import_type = Column(String(50), nullable=False, default="csv") # csv, xlsx, api
|
||||
|
||||
# Processing status
|
||||
status = Column(String(20), nullable=False, default="pending") # pending, processing, completed, failed
|
||||
progress_percentage = Column(Float, default=0.0)
|
||||
|
||||
# Results
|
||||
total_rows = Column(Integer, default=0)
|
||||
processed_rows = Column(Integer, default=0)
|
||||
successful_imports = Column(Integer, default=0)
|
||||
failed_imports = Column(Integer, default=0)
|
||||
duplicate_rows = Column(Integer, default=0)
|
||||
|
||||
# Error tracking
|
||||
error_message = Column(Text, nullable=True)
|
||||
validation_errors = Column(Text, nullable=True) # JSON string of validation errors
|
||||
|
||||
# Timestamps
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
created_by = Column(UUID(as_uuid=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_import_jobs_tenant_status', 'tenant_id', 'status', 'created_at'),
|
||||
Index('idx_import_jobs_status_date', 'status', 'created_at'),
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary for API responses"""
|
||||
return {
|
||||
'id': str(self.id),
|
||||
'tenant_id': str(self.tenant_id),
|
||||
'filename': self.filename,
|
||||
'file_size': self.file_size,
|
||||
'import_type': self.import_type,
|
||||
'status': self.status,
|
||||
'progress_percentage': self.progress_percentage,
|
||||
'total_rows': self.total_rows,
|
||||
'processed_rows': self.processed_rows,
|
||||
'successful_imports': self.successful_imports,
|
||||
'failed_imports': self.failed_imports,
|
||||
'duplicate_rows': self.duplicate_rows,
|
||||
'error_message': self.error_message,
|
||||
'validation_errors': self.validation_errors,
|
||||
'started_at': self.started_at.isoformat() if self.started_at else None,
|
||||
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'created_by': str(self.created_by) if self.created_by else None,
|
||||
}
|
||||
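As a point of reference, a minimal sketch (not part of this commit) of how SalesImportJob's progress fields might be updated during processing; the field names come from the model above, while the session handling and helper name are assumptions.

# Hypothetical helper: persist incremental progress for an import job.
from datetime import datetime, timezone

async def record_progress(session, job, processed: int, total: int) -> None:
    """Update a SalesImportJob row as rows are processed (sketch only)."""
    job.processed_rows = processed
    job.total_rows = total
    job.progress_percentage = (processed / total) * 100 if total else 0.0
    job.status = "processing" if processed < total else "completed"
    if job.status == "completed":
        job.completed_at = datetime.now(timezone.utc)
    await session.commit()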
6 services/sales/app/repositories/__init__.py Normal file
@@ -0,0 +1,6 @@
# services/sales/app/repositories/__init__.py

from .sales_repository import SalesRepository
from .product_repository import ProductRepository

__all__ = ["SalesRepository", "ProductRepository"]
193 services/sales/app/repositories/product_repository.py Normal file
@@ -0,0 +1,193 @@
# services/sales/app/repositories/product_repository.py
"""
Product Repository using Repository Pattern
"""

from typing import List, Optional
from uuid import UUID
from sqlalchemy import select, and_, or_, desc, asc
from sqlalchemy.ext.asyncio import AsyncSession
import structlog

from app.models.sales import Product
from app.schemas.sales import ProductCreate, ProductUpdate
from shared.database.repository import BaseRepository

logger = structlog.get_logger()


class ProductRepository(BaseRepository[Product, ProductCreate, ProductUpdate]):
    """Repository for product operations"""

    def __init__(self, db_session: AsyncSession):
        super().__init__(Product, db_session)

    async def create_product(self, product_data: ProductCreate, tenant_id: UUID) -> Product:
        """Create a new product"""
        try:
            # Prepare data
            create_data = product_data.model_dump()
            create_data['tenant_id'] = tenant_id

            # Create product
            product = await self.create(create_data)
            logger.info(
                "Created product",
                product_id=product.id,
                name=product.name,
                tenant_id=tenant_id
            )
            return product

        except Exception as e:
            logger.error("Failed to create product", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_tenant(self, tenant_id: UUID, include_inactive: bool = False) -> List[Product]:
        """Get all products for a tenant"""
        try:
            stmt = select(Product).where(Product.tenant_id == tenant_id)

            if not include_inactive:
                stmt = stmt.where(Product.is_active == True)

            stmt = stmt.order_by(Product.category, Product.name)

            result = await self.db_session.execute(stmt)
            products = result.scalars().all()

            logger.info(
                "Retrieved products",
                count=len(products),
                tenant_id=tenant_id,
                include_inactive=include_inactive
            )
            return list(products)

        except Exception as e:
            logger.error("Failed to get products", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_category(self, tenant_id: UUID, category: str) -> List[Product]:
        """Get active products in a category"""
        try:
            stmt = select(Product).where(
                and_(
                    Product.tenant_id == tenant_id,
                    Product.category == category,
                    Product.is_active == True
                )
            ).order_by(Product.name)

            result = await self.db_session.execute(stmt)
            products = result.scalars().all()

            return list(products)

        except Exception as e:
            logger.error("Failed to get products by category", error=str(e), tenant_id=tenant_id, category=category)
            raise

    async def get_by_name(self, tenant_id: UUID, name: str) -> Optional[Product]:
        """Get product by name"""
        try:
            stmt = select(Product).where(
                and_(
                    Product.tenant_id == tenant_id,
                    Product.name == name
                )
            )

            result = await self.db_session.execute(stmt)
            product = result.scalar_one_or_none()

            return product

        except Exception as e:
            logger.error("Failed to get product by name", error=str(e), tenant_id=tenant_id, name=name)
            raise

    async def get_by_sku(self, tenant_id: UUID, sku: str) -> Optional[Product]:
        """Get product by SKU"""
        try:
            stmt = select(Product).where(
                and_(
                    Product.tenant_id == tenant_id,
                    Product.sku == sku
                )
            )

            result = await self.db_session.execute(stmt)
            product = result.scalar_one_or_none()

            return product

        except Exception as e:
            logger.error("Failed to get product by SKU", error=str(e), tenant_id=tenant_id, sku=sku)
            raise

    async def search_products(self, tenant_id: UUID, query: str, limit: int = 50) -> List[Product]:
        """Search active products by name, SKU, or description"""
        try:
            stmt = select(Product).where(
                and_(
                    Product.tenant_id == tenant_id,
                    Product.is_active == True,
                    or_(
                        Product.name.ilike(f"%{query}%"),
                        Product.sku.ilike(f"%{query}%"),
                        Product.description.ilike(f"%{query}%")
                    )
                )
            ).order_by(Product.name).limit(limit)

            result = await self.db_session.execute(stmt)
            products = result.scalars().all()

            return list(products)

        except Exception as e:
            logger.error("Failed to search products", error=str(e), tenant_id=tenant_id, query=query)
            raise

    async def get_categories(self, tenant_id: UUID) -> List[str]:
        """Get distinct product categories for a tenant"""
        try:
            stmt = select(Product.category).where(
                and_(
                    Product.tenant_id == tenant_id,
                    Product.is_active == True,
                    Product.category.is_not(None)
                )
            ).distinct()

            result = await self.db_session.execute(stmt)
            categories = [row[0] for row in result if row[0]]

            return sorted(categories)

        except Exception as e:
            logger.error("Failed to get product categories", error=str(e), tenant_id=tenant_id)
            raise

    async def deactivate_product(self, product_id: UUID) -> Product:
        """Deactivate a product"""
        try:
            product = await self.update(product_id, {'is_active': False})
            logger.info("Deactivated product", product_id=product_id)
            return product

        except Exception as e:
            logger.error("Failed to deactivate product", error=str(e), product_id=product_id)
            raise

    async def activate_product(self, product_id: UUID) -> Product:
        """Activate a product"""
        try:
            product = await self.update(product_id, {'is_active': True})
            logger.info("Activated product", product_id=product_id)
            return product

        except Exception as e:
            logger.error("Failed to activate product", error=str(e), product_id=product_id)
            raise
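For reference, a minimal sketch of how ProductRepository might be wired up outside a request handler. The engine URL and session-factory names are assumptions, not part of this commit; only ProductRepository and get_by_tenant come from the code above.

# Hypothetical usage sketch: one repository per async session.
import asyncio
from uuid import uuid4

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

from app.repositories.product_repository import ProductRepository

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/sales")  # placeholder DSN
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)

async def list_active_products(tenant_id):
    # Scoping the repository to a single session keeps unit-of-work boundaries explicit.
    async with SessionFactory() as session:
        repo = ProductRepository(session)
        return await repo.get_by_tenant(tenant_id, include_inactive=False)

if __name__ == "__main__":
    print(asyncio.run(list_active_products(uuid4())))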
296 services/sales/app/repositories/sales_repository.py Normal file
@@ -0,0 +1,296 @@
# services/sales/app/repositories/sales_repository.py
"""
Sales Repository using Repository Pattern
"""

from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime
from sqlalchemy import select, func, and_, or_, desc, asc
from sqlalchemy.ext.asyncio import AsyncSession
import structlog

from app.models.sales import SalesData
from app.schemas.sales import SalesDataCreate, SalesDataUpdate, SalesDataQuery
from shared.database.repository import BaseRepository

logger = structlog.get_logger()


class SalesRepository(BaseRepository[SalesData, SalesDataCreate, SalesDataUpdate]):
    """Repository for sales data operations"""

    def __init__(self, session: AsyncSession):
        super().__init__(SalesData, session)

    async def create_sales_record(self, sales_data: SalesDataCreate, tenant_id: UUID) -> SalesData:
        """Create a new sales record"""
        try:
            # Prepare data
            create_data = sales_data.model_dump()
            create_data['tenant_id'] = tenant_id

            # Calculate weekend flag if not provided
            if sales_data.date and create_data.get('is_weekend') is None:
                create_data['is_weekend'] = sales_data.date.weekday() >= 5

            # Create record
            record = await self.create(create_data)
            logger.info(
                "Created sales record",
                record_id=record.id,
                product=record.product_name,
                quantity=record.quantity_sold,
                tenant_id=tenant_id
            )
            return record

        except Exception as e:
            logger.error("Failed to create sales record", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_tenant(
        self,
        tenant_id: UUID,
        query_params: Optional[SalesDataQuery] = None
    ) -> List[SalesData]:
        """Get sales records by tenant with optional filtering"""
        try:
            # Build base query
            stmt = select(SalesData).where(SalesData.tenant_id == tenant_id)

            # Apply filters if query_params provided
            if query_params:
                if query_params.start_date:
                    stmt = stmt.where(SalesData.date >= query_params.start_date)
                if query_params.end_date:
                    stmt = stmt.where(SalesData.date <= query_params.end_date)
                if query_params.product_name:
                    stmt = stmt.where(SalesData.product_name.ilike(f"%{query_params.product_name}%"))
                if query_params.product_category:
                    stmt = stmt.where(SalesData.product_category == query_params.product_category)
                if query_params.location_id:
                    stmt = stmt.where(SalesData.location_id == query_params.location_id)
                if query_params.sales_channel:
                    stmt = stmt.where(SalesData.sales_channel == query_params.sales_channel)
                if query_params.source:
                    stmt = stmt.where(SalesData.source == query_params.source)
                if query_params.is_validated is not None:
                    stmt = stmt.where(SalesData.is_validated == query_params.is_validated)

                # Apply ordering
                if query_params.order_by and hasattr(SalesData, query_params.order_by):
                    order_col = getattr(SalesData, query_params.order_by)
                    if query_params.order_direction == 'asc':
                        stmt = stmt.order_by(asc(order_col))
                    else:
                        stmt = stmt.order_by(desc(order_col))
                else:
                    stmt = stmt.order_by(desc(SalesData.date))

                # Apply pagination
                stmt = stmt.offset(query_params.offset).limit(query_params.limit)
            else:
                # Default ordering
                stmt = stmt.order_by(desc(SalesData.date)).limit(50)

            result = await self.session.execute(stmt)
            records = result.scalars().all()

            logger.info(
                "Retrieved sales records",
                count=len(records),
                tenant_id=tenant_id
            )
            return list(records)

        except Exception as e:
            logger.error("Failed to get sales records", error=str(e), tenant_id=tenant_id)
            raise

    async def get_by_product(
        self,
        tenant_id: UUID,
        product_name: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[SalesData]:
        """Get sales records for a specific product"""
        try:
            stmt = select(SalesData).where(
                and_(
                    SalesData.tenant_id == tenant_id,
                    SalesData.product_name == product_name
                )
            )

            if start_date:
                stmt = stmt.where(SalesData.date >= start_date)
            if end_date:
                stmt = stmt.where(SalesData.date <= end_date)

            stmt = stmt.order_by(desc(SalesData.date))

            result = await self.session.execute(stmt)
            records = result.scalars().all()

            return list(records)

        except Exception as e:
            logger.error("Failed to get product sales", error=str(e), tenant_id=tenant_id, product=product_name)
            raise

    async def get_analytics(
        self,
        tenant_id: UUID,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Get sales analytics for a tenant"""
        try:
            # Build base query
            base_query = select(SalesData).where(SalesData.tenant_id == tenant_id)

            if start_date:
                base_query = base_query.where(SalesData.date >= start_date)
            if end_date:
                base_query = base_query.where(SalesData.date <= end_date)

            # Total revenue and quantity
            summary_query = select(
                func.sum(SalesData.revenue).label('total_revenue'),
                func.sum(SalesData.quantity_sold).label('total_quantity'),
                func.count().label('total_transactions'),
                func.avg(SalesData.revenue).label('avg_transaction_value')
            ).where(SalesData.tenant_id == tenant_id)

            if start_date:
                summary_query = summary_query.where(SalesData.date >= start_date)
            if end_date:
                summary_query = summary_query.where(SalesData.date <= end_date)

            result = await self.session.execute(summary_query)
            summary = result.first()

            # Top products
            top_products_query = select(
                SalesData.product_name,
                func.sum(SalesData.revenue).label('revenue'),
                func.sum(SalesData.quantity_sold).label('quantity')
            ).where(SalesData.tenant_id == tenant_id)

            if start_date:
                top_products_query = top_products_query.where(SalesData.date >= start_date)
            if end_date:
                top_products_query = top_products_query.where(SalesData.date <= end_date)

            top_products_query = top_products_query.group_by(
                SalesData.product_name
            ).order_by(
                desc(func.sum(SalesData.revenue))
            ).limit(10)

            top_products_result = await self.session.execute(top_products_query)
            top_products = [
                {
                    'product_name': row.product_name,
                    'revenue': float(row.revenue) if row.revenue else 0,
                    'quantity': row.quantity or 0
                }
                for row in top_products_result
            ]

            # Sales by channel
            channel_query = select(
                SalesData.sales_channel,
                func.sum(SalesData.revenue).label('revenue'),
                func.count().label('transactions')
            ).where(SalesData.tenant_id == tenant_id)

            if start_date:
                channel_query = channel_query.where(SalesData.date >= start_date)
            if end_date:
                channel_query = channel_query.where(SalesData.date <= end_date)

            channel_query = channel_query.group_by(SalesData.sales_channel)

            channel_result = await self.session.execute(channel_query)
            sales_by_channel = {
                row.sales_channel: {
                    'revenue': float(row.revenue) if row.revenue else 0,
                    'transactions': row.transactions or 0
                }
                for row in channel_result
            }

            return {
                'total_revenue': float(summary.total_revenue) if summary.total_revenue else 0,
                'total_quantity': summary.total_quantity or 0,
                'total_transactions': summary.total_transactions or 0,
                'average_transaction_value': float(summary.avg_transaction_value) if summary.avg_transaction_value else 0,
                'top_products': top_products,
                'sales_by_channel': sales_by_channel
            }

        except Exception as e:
            logger.error("Failed to get sales analytics", error=str(e), tenant_id=tenant_id)
            raise

    async def get_product_categories(self, tenant_id: UUID) -> List[str]:
        """Get distinct product categories for a tenant"""
        try:
            stmt = select(SalesData.product_category).where(
                and_(
                    SalesData.tenant_id == tenant_id,
                    SalesData.product_category.is_not(None)
                )
            ).distinct()

            result = await self.session.execute(stmt)
            categories = [row[0] for row in result if row[0]]

            return sorted(categories)

        except Exception as e:
            logger.error("Failed to get product categories", error=str(e), tenant_id=tenant_id)
            raise

    async def validate_record(self, record_id: UUID, validation_notes: Optional[str] = None) -> SalesData:
        """Mark a sales record as validated"""
        try:
            record = await self.get_by_id(record_id)
            if not record:
                raise ValueError(f"Sales record {record_id} not found")

            update_data = {
                'is_validated': True,
                'validation_notes': validation_notes
            }

            updated_record = await self.update(record_id, update_data)

            logger.info("Validated sales record", record_id=record_id)
            return updated_record

        except Exception as e:
            logger.error("Failed to validate sales record", error=str(e), record_id=record_id)
            raise

    async def get_product_statistics(self, tenant_id: str) -> List[str]:
        """Get distinct product names for a tenant (the basis for product statistics)"""
        try:
            stmt = select(SalesData.product_name).where(
                and_(
                    SalesData.tenant_id == tenant_id,
                    SalesData.product_name.is_not(None)
                )
            ).distinct()

            result = await self.session.execute(stmt)
            products = [row[0] for row in result if row[0]]

            return sorted(products)

        except Exception as e:
            logger.error("Failed to get product statistics", error=str(e), tenant_id=tenant_id)
            raise
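get_analytics above re-applies the same tenant and date filters to three separate statements; a small helper along these lines (hypothetical, not part of this commit; the function name is assumed) would remove that repetition.

# Sketch: apply shared tenant/date scoping once per analytics statement.
from datetime import datetime
from typing import Optional
from uuid import UUID

from sqlalchemy import Select

from app.models.sales import SalesData

def scoped(stmt: Select, tenant_id: UUID,
           start_date: Optional[datetime] = None,
           end_date: Optional[datetime] = None) -> Select:
    """Return stmt restricted to one tenant and an optional date window."""
    stmt = stmt.where(SalesData.tenant_id == tenant_id)
    if start_date:
        stmt = stmt.where(SalesData.date >= start_date)
    if end_date:
        stmt = stmt.where(SalesData.date <= end_date)
    return stmt

# usage inside get_analytics (sketch):
#   summary_query = scoped(select(func.sum(SalesData.revenue), ...), tenant_id, start_date, end_date)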
25 services/sales/app/schemas/__init__.py Normal file
@@ -0,0 +1,25 @@
# services/sales/app/schemas/__init__.py

from .sales import (
    SalesDataCreate,
    SalesDataUpdate,
    SalesDataResponse,
    SalesDataQuery,
    ProductCreate,
    ProductUpdate,
    ProductResponse,
    SalesAnalytics,
    ProductSalesAnalytics
)

__all__ = [
    "SalesDataCreate",
    "SalesDataUpdate",
    "SalesDataResponse",
    "SalesDataQuery",
    "ProductCreate",
    "ProductUpdate",
    "ProductResponse",
    "SalesAnalytics",
    "ProductSalesAnalytics"
]
198 services/sales/app/schemas/sales.py Normal file
@@ -0,0 +1,198 @@
# services/sales/app/schemas/sales.py
"""
Sales Service Pydantic Schemas
"""

from pydantic import BaseModel, Field, validator
from typing import Optional, List
from datetime import datetime
from uuid import UUID
from decimal import Decimal


class SalesDataBase(BaseModel):
    """Base sales data schema"""
    product_name: str = Field(..., min_length=1, max_length=255, description="Product name")
    product_category: Optional[str] = Field(None, max_length=100, description="Product category")
    product_sku: Optional[str] = Field(None, max_length=100, description="Product SKU")

    quantity_sold: int = Field(..., gt=0, description="Quantity sold")
    unit_price: Optional[Decimal] = Field(None, ge=0, description="Unit price")
    revenue: Decimal = Field(..., gt=0, description="Total revenue")
    cost_of_goods: Optional[Decimal] = Field(None, ge=0, description="Cost of goods sold")
    discount_applied: Optional[Decimal] = Field(0, ge=0, le=100, description="Discount percentage")

    location_id: Optional[str] = Field(None, max_length=100, description="Location identifier")
    sales_channel: Optional[str] = Field("in_store", description="Sales channel")
    source: str = Field("manual", description="Data source")

    notes: Optional[str] = Field(None, description="Additional notes")
    weather_condition: Optional[str] = Field(None, max_length=50, description="Weather condition")
    is_holiday: bool = Field(False, description="Holiday flag")
    is_weekend: bool = Field(False, description="Weekend flag")

    @validator('sales_channel')
    def validate_sales_channel(cls, v):
        allowed_channels = ['in_store', 'online', 'delivery', 'wholesale']
        if v not in allowed_channels:
            raise ValueError(f'Sales channel must be one of: {allowed_channels}')
        return v

    @validator('source')
    def validate_source(cls, v):
        allowed_sources = ['manual', 'pos', 'online', 'import', 'api', 'csv']
        if v not in allowed_sources:
            raise ValueError(f'Source must be one of: {allowed_sources}')
        return v


class SalesDataCreate(SalesDataBase):
    """Schema for creating sales data"""
    tenant_id: Optional[UUID] = Field(None, description="Tenant ID (set automatically)")
    date: datetime = Field(..., description="Sale date and time")


class SalesDataUpdate(BaseModel):
    """Schema for updating sales data"""
    product_name: Optional[str] = Field(None, min_length=1, max_length=255)
    product_category: Optional[str] = Field(None, max_length=100)
    product_sku: Optional[str] = Field(None, max_length=100)

    quantity_sold: Optional[int] = Field(None, gt=0)
    unit_price: Optional[Decimal] = Field(None, ge=0)
    revenue: Optional[Decimal] = Field(None, gt=0)
    cost_of_goods: Optional[Decimal] = Field(None, ge=0)
    discount_applied: Optional[Decimal] = Field(None, ge=0, le=100)

    location_id: Optional[str] = Field(None, max_length=100)
    sales_channel: Optional[str] = None

    notes: Optional[str] = None
    weather_condition: Optional[str] = Field(None, max_length=50)
    is_holiday: Optional[bool] = None
    is_weekend: Optional[bool] = None

    validation_notes: Optional[str] = None
    is_validated: Optional[bool] = None


class SalesDataResponse(SalesDataBase):
    """Schema for sales data responses"""
    id: UUID
    tenant_id: UUID
    date: datetime

    is_validated: bool = False
    validation_notes: Optional[str] = None

    created_at: datetime
    updated_at: datetime
    created_by: Optional[UUID] = None

    profit_margin: Optional[float] = Field(None, description="Calculated profit margin")

    class Config:
        from_attributes = True


class SalesDataQuery(BaseModel):
    """Schema for sales data queries"""
    start_date: Optional[datetime] = None
    end_date: Optional[datetime] = None
    product_name: Optional[str] = None
    product_category: Optional[str] = None
    location_id: Optional[str] = None
    sales_channel: Optional[str] = None
    source: Optional[str] = None
    is_validated: Optional[bool] = None

    limit: int = Field(50, ge=1, le=1000, description="Number of records to return")
    offset: int = Field(0, ge=0, description="Number of records to skip")

    order_by: str = Field("date", description="Field to order by")
    order_direction: str = Field("desc", description="Order direction")

    @validator('order_direction')
    def validate_order_direction(cls, v):
        if v.lower() not in ['asc', 'desc']:
            raise ValueError('Order direction must be "asc" or "desc"')
        return v.lower()


# Product schemas
class ProductBase(BaseModel):
    """Base product schema"""
    name: str = Field(..., min_length=1, max_length=255, description="Product name")
    sku: Optional[str] = Field(None, max_length=100, description="Stock Keeping Unit")
    category: Optional[str] = Field(None, max_length=100, description="Product category")
    subcategory: Optional[str] = Field(None, max_length=100, description="Product subcategory")

    description: Optional[str] = Field(None, description="Product description")
    unit_of_measure: str = Field("unit", description="Unit of measure")
    weight: Optional[float] = Field(None, gt=0, description="Weight in grams")
    volume: Optional[float] = Field(None, gt=0, description="Volume in ml")

    base_price: Optional[Decimal] = Field(None, ge=0, description="Base selling price")
    cost_price: Optional[Decimal] = Field(None, ge=0, description="Cost price")

    is_seasonal: bool = Field(False, description="Seasonal product flag")
    seasonal_start: Optional[datetime] = Field(None, description="Season start date")
    seasonal_end: Optional[datetime] = Field(None, description="Season end date")


class ProductCreate(ProductBase):
    """Schema for creating products"""
    tenant_id: Optional[UUID] = Field(None, description="Tenant ID (set automatically)")


class ProductUpdate(BaseModel):
    """Schema for updating products"""
    name: Optional[str] = Field(None, min_length=1, max_length=255)
    sku: Optional[str] = Field(None, max_length=100)
    category: Optional[str] = Field(None, max_length=100)
    subcategory: Optional[str] = Field(None, max_length=100)
    description: Optional[str] = None
    unit_of_measure: Optional[str] = None
    weight: Optional[float] = Field(None, gt=0)
    volume: Optional[float] = Field(None, gt=0)
    base_price: Optional[Decimal] = Field(None, ge=0)
    cost_price: Optional[Decimal] = Field(None, ge=0)
    is_active: Optional[bool] = None
    is_seasonal: Optional[bool] = None
    seasonal_start: Optional[datetime] = None
    seasonal_end: Optional[datetime] = None


class ProductResponse(ProductBase):
    """Schema for product responses"""
    id: UUID
    tenant_id: UUID
    is_active: bool = True

    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True


# Analytics schemas
class SalesAnalytics(BaseModel):
    """Sales analytics response"""
    total_revenue: Decimal
    total_quantity: int
    total_transactions: int
    average_transaction_value: Decimal
    top_products: List[dict]
    sales_by_channel: dict
    sales_by_day: List[dict]


class ProductSalesAnalytics(BaseModel):
    """Product-specific sales analytics"""
    product_name: str
    total_revenue: Decimal
    total_quantity: int
    total_transactions: int
    average_price: Decimal
    growth_rate: Optional[float] = None
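A minimal sketch of these schemas in use (not part of this commit; the example values are made up): the sales_channel validator rejects anything outside the allowed list.

# Hypothetical example showing the validators above in action.
from datetime import datetime
from decimal import Decimal

from pydantic import ValidationError

from app.schemas.sales import SalesDataCreate

# A well-formed record passes validation.
record = SalesDataCreate(
    product_name="Baguette",              # example value
    quantity_sold=12,
    revenue=Decimal("18.00"),
    sales_channel="in_store",
    date=datetime(2024, 1, 15, 9, 30),
)

# An unknown channel is rejected by validate_sales_channel.
try:
    SalesDataCreate(
        product_name="Croissant",
        quantity_sold=4,
        revenue=Decimal("6.40"),
        sales_channel="phone",            # not in the allowed list
        date=datetime(2024, 1, 15, 10, 0),
    )
except ValidationError as exc:
    print(exc.errors()[0]["msg"])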
8 services/sales/app/services/__init__.py Normal file
@@ -0,0 +1,8 @@
# services/sales/app/services/__init__.py

from .sales_service import SalesService
from .product_service import ProductService
from .data_import_service import DataImportService
from .messaging import SalesEventPublisher, sales_publisher

__all__ = ["SalesService", "ProductService", "DataImportService", "SalesEventPublisher", "sales_publisher"]
@@ -1,5 +1,6 @@
# services/sales/app/services/data_import_service.py
"""
Enhanced Data Import Service
Data Import Service
Service for importing sales data using repository pattern and enhanced error handling
"""

@@ -15,16 +16,46 @@ import re

from app.repositories.sales_repository import SalesRepository
from app.models.sales import SalesData
from app.schemas.sales import SalesDataCreate, SalesImportResult, SalesValidationResult
from shared.database.unit_of_work import UnitOfWork
from shared.database.transactions import transactional
from shared.database.exceptions import DatabaseError, ValidationError
from app.schemas.sales import SalesDataCreate
from app.core.database import get_db_transaction

logger = structlog.get_logger()


class EnhancedDataImportService:
    """Enhanced data import service using repository pattern"""
# Import result schemas (dataclass definitions)
from dataclasses import dataclass
from typing import List, Dict, Any

@dataclass
class SalesValidationResult:
    is_valid: bool
    total_records: int
    valid_records: int
    invalid_records: int
    errors: List[Dict[str, Any]]
    warnings: List[Dict[str, Any]]
    summary: Dict[str, Any]

@dataclass
class SalesImportResult:
    success: bool
    records_processed: int
    records_created: int
    records_updated: int
    records_failed: int
    errors: List[Dict[str, Any]]
    warnings: List[Dict[str, Any]]
    processing_time_seconds: float


class DataImportService:
    """Enhanced data import service using repository pattern with STRICT validation for production"""

    # PRODUCTION VALIDATION CONFIGURATION
    STRICT_VALIDATION = True  # Set to False for lenient validation, True for production quality
    MAX_QUANTITY_PER_DAY = 10000  # Maximum reasonable quantity per product per day
    MAX_REVENUE_PER_ITEM = 100000  # Maximum reasonable revenue per line item
    MAX_UNIT_PRICE = 10000  # Maximum reasonable price per unit for bakery items

    # Common column mappings for different languages/formats
    COLUMN_MAPPINGS = {
@@ -46,14 +77,14 @@ class EnhancedDataImportService:
|
||||
'%Y-%m-%d %H:%M:%S', '%d/%m/%Y %H:%M',
|
||||
]
|
||||
|
||||
def __init__(self, database_manager):
|
||||
"""Initialize service with database manager"""
|
||||
self.database_manager = database_manager
|
||||
def __init__(self):
|
||||
"""Initialize enhanced import service"""
|
||||
pass
|
||||
|
||||
async def validate_import_data(self, data: Dict[str, Any]) -> SalesValidationResult:
|
||||
"""Validate import data before processing"""
|
||||
"""Enhanced validation with better error handling and suggestions"""
|
||||
try:
|
||||
logger.info("Starting import data validation", tenant_id=data.get("tenant_id"))
|
||||
logger.info("Starting enhanced import data validation", tenant_id=data.get("tenant_id"))
|
||||
|
||||
validation_result = SalesValidationResult(
|
||||
is_valid=True,
|
||||
@@ -156,7 +187,7 @@ class EnhancedDataImportService:
|
||||
"code": "NO_CONTENT"
|
||||
})
|
||||
else:
|
||||
# Analyze structure
|
||||
# Enhanced column analysis
|
||||
headers = list(rows[0].keys()) if rows else []
|
||||
column_mapping = self._detect_columns(headers)
|
||||
|
||||
@@ -188,17 +219,65 @@ class EnhancedDataImportService:
|
||||
"code": "MISSING_QUANTITY_COLUMN"
|
||||
})
|
||||
|
||||
# Calculate estimated valid/invalid records
|
||||
# Enhanced data quality estimation
|
||||
if not errors:
|
||||
estimated_invalid = max(0, int(validation_result.total_records * 0.1))
|
||||
sample_size = min(10, len(rows))
|
||||
sample_rows = rows[:sample_size]
|
||||
quality_issues = 0
|
||||
|
||||
for i, row in enumerate(sample_rows):
|
||||
parsed_data = await self._parse_row_data(row, column_mapping, i + 1)
|
||||
if parsed_data.get("skip") or parsed_data.get("errors"):
|
||||
quality_issues += 1
|
||||
|
||||
estimated_error_rate = (quality_issues / sample_size) * 100 if sample_size > 0 else 0
|
||||
estimated_invalid = int(validation_result.total_records * estimated_error_rate / 100)
|
||||
|
||||
validation_result.valid_records = validation_result.total_records - estimated_invalid
|
||||
validation_result.invalid_records = estimated_invalid
|
||||
|
||||
# STRICT: Any data quality issues should fail validation for production
|
||||
if estimated_error_rate > 0:
|
||||
errors.append({
|
||||
"type": "data_quality_error",
|
||||
"message": f"Falló la validación de calidad: {estimated_error_rate:.0f}% de los datos tienen errores críticos",
|
||||
"field": "data",
|
||||
"row": None,
|
||||
"code": "DATA_QUALITY_FAILED"
|
||||
})
|
||||
|
||||
# Add specific error details
|
||||
if estimated_error_rate > 50:
|
||||
errors.append({
|
||||
"type": "data_quality_critical",
|
||||
"message": f"Calidad de datos crítica: más del 50% de los registros tienen errores",
|
||||
"field": "data",
|
||||
"row": None,
|
||||
"code": "DATA_QUALITY_CRITICAL"
|
||||
})
|
||||
elif estimated_error_rate > 20:
|
||||
errors.append({
|
||||
"type": "data_quality_high",
|
||||
"message": f"Alta tasa de errores detectada: {estimated_error_rate:.0f}% de los datos requieren corrección",
|
||||
"field": "data",
|
||||
"row": None,
|
||||
"code": "DATA_QUALITY_HIGH_ERROR_RATE"
|
||||
})
|
||||
else:
|
||||
# Even small error rates are now treated as validation failures
|
||||
errors.append({
|
||||
"type": "data_quality_detected",
|
||||
"message": f"Se detectaron errores de validación en {estimated_error_rate:.0f}% de los datos",
|
||||
"field": "data",
|
||||
"row": None,
|
||||
"code": "DATA_QUALITY_ERRORS_FOUND"
|
||||
})
|
||||
else:
|
||||
validation_result.valid_records = 0
|
||||
validation_result.invalid_records = validation_result.total_records
|
||||
|
||||
except Exception as csv_error:
|
||||
logger.warning("CSV analysis failed", error=str(csv_error))
|
||||
logger.warning("Enhanced CSV analysis failed", error=str(csv_error))
|
||||
warnings.append({
|
||||
"type": "analysis_warning",
|
||||
"message": f"No se pudo analizar completamente el CSV: {str(csv_error)}",
|
||||
@@ -212,7 +291,7 @@ class EnhancedDataImportService:
|
||||
validation_result.errors = errors
|
||||
validation_result.warnings = warnings
|
||||
|
||||
# Build summary
|
||||
# Enhanced summary generation
|
||||
validation_result.summary = {
|
||||
"status": "valid" if validation_result.is_valid else "invalid",
|
||||
"file_format": format_type,
|
||||
@@ -220,10 +299,11 @@ class EnhancedDataImportService:
|
||||
"file_size_mb": round(data_size / (1024 * 1024), 2),
|
||||
"estimated_processing_time_seconds": max(1, validation_result.total_records // 100),
|
||||
"validation_timestamp": datetime.utcnow().isoformat(),
|
||||
"detected_columns": list(self._detect_columns(list(csv.DictReader(io.StringIO(data_content)).fieldnames or [])).keys()) if format_type == "csv" and data_content else [],
|
||||
"suggestions": self._generate_suggestions(validation_result, format_type, len(warnings))
|
||||
}
|
||||
|
||||
logger.info("Import validation completed",
|
||||
logger.info("Enhanced import validation completed",
|
||||
is_valid=validation_result.is_valid,
|
||||
total_records=validation_result.total_records,
|
||||
error_count=len(errors),
|
||||
@@ -232,7 +312,7 @@ class EnhancedDataImportService:
|
||||
return validation_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Validation process failed", error=str(e))
|
||||
logger.error("Enhanced validation process failed", error=str(e))
|
||||
|
||||
return SalesValidationResult(
|
||||
is_valid=False,
|
||||
@@ -263,65 +343,57 @@ class EnhancedDataImportService:
|
||||
tenant_id: str,
|
||||
content: str,
|
||||
file_format: str,
|
||||
filename: Optional[str] = None,
|
||||
session = None
|
||||
filename: Optional[str] = None
|
||||
) -> SalesImportResult:
|
||||
"""Process data import using repository pattern"""
|
||||
"""Enhanced data import processing with better error handling"""
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
try:
|
||||
logger.info("Starting data import using repository pattern",
|
||||
logger.info("Starting enhanced data import",
|
||||
filename=filename,
|
||||
format=file_format,
|
||||
tenant_id=tenant_id)
|
||||
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
async with UnitOfWork(db_session) as uow:
|
||||
# Register sales repository
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
# Process data based on format
|
||||
if file_format.lower() == 'csv':
|
||||
result = await self._process_csv_data(tenant_id, content, sales_repo, filename)
|
||||
elif file_format.lower() == 'json':
|
||||
result = await self._process_json_data(tenant_id, content, sales_repo, filename)
|
||||
elif file_format.lower() in ['excel', 'xlsx']:
|
||||
result = await self._process_excel_data(tenant_id, content, sales_repo, filename)
|
||||
else:
|
||||
raise ValidationError(f"Unsupported format: {file_format}")
|
||||
|
||||
# Commit all changes
|
||||
await uow.commit()
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.utcnow()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Build final result
|
||||
final_result = SalesImportResult(
|
||||
success=result.get("success", False),
|
||||
records_processed=result.get("total_rows", 0),
|
||||
records_created=result.get("records_created", 0),
|
||||
records_updated=0, # We don't update, only create
|
||||
records_failed=result.get("total_rows", 0) - result.get("records_created", 0),
|
||||
errors=self._structure_messages(result.get("errors", [])),
|
||||
warnings=self._structure_messages(result.get("warnings", [])),
|
||||
processing_time_seconds=processing_time
|
||||
)
|
||||
|
||||
logger.info("Data import completed successfully",
|
||||
records_created=final_result.records_created,
|
||||
processing_time=processing_time)
|
||||
|
||||
return final_result
|
||||
|
||||
except (ValidationError, DatabaseError):
|
||||
raise
|
||||
async with get_db_transaction() as db:
|
||||
repository = SalesRepository(db)
|
||||
|
||||
# Process data based on format
|
||||
if file_format.lower() == 'csv':
|
||||
result = await self._process_csv_data(tenant_id, content, repository, filename)
|
||||
elif file_format.lower() == 'json':
|
||||
result = await self._process_json_data(tenant_id, content, repository, filename)
|
||||
elif file_format.lower() in ['excel', 'xlsx']:
|
||||
result = await self._process_excel_data(tenant_id, content, repository, filename)
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {file_format}")
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.utcnow()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Build enhanced final result
|
||||
final_result = SalesImportResult(
|
||||
success=result.get("success", False),
|
||||
records_processed=result.get("total_rows", 0),
|
||||
records_created=result.get("records_created", 0),
|
||||
records_updated=0, # We don't update, only create
|
||||
records_failed=result.get("total_rows", 0) - result.get("records_created", 0),
|
||||
errors=self._structure_messages(result.get("errors", [])),
|
||||
warnings=self._structure_messages(result.get("warnings", [])),
|
||||
processing_time_seconds=processing_time
|
||||
)
|
||||
|
||||
logger.info("Enhanced data import completed successfully",
|
||||
records_created=final_result.records_created,
|
||||
processing_time=processing_time)
|
||||
|
||||
return final_result
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.utcnow()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error("Data import failed", error=str(e), tenant_id=tenant_id)
|
||||
logger.error("Enhanced data import failed", error=str(e), tenant_id=tenant_id)
|
||||
|
||||
return SalesImportResult(
|
||||
success=False,
|
||||
@@ -344,10 +416,10 @@ class EnhancedDataImportService:
|
||||
self,
|
||||
tenant_id: str,
|
||||
csv_content: str,
|
||||
sales_repo: SalesRepository,
|
||||
repository: SalesRepository,
|
||||
filename: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Process CSV data using repository"""
|
||||
"""Enhanced CSV processing with better data handling"""
|
||||
try:
|
||||
reader = csv.DictReader(io.StringIO(csv_content))
|
||||
rows = list(reader)
|
||||
@@ -361,46 +433,48 @@ class EnhancedDataImportService:
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
# Column mapping
|
||||
# Enhanced column mapping
|
||||
column_mapping = self._detect_columns(list(rows[0].keys()))
|
||||
|
||||
records_created = 0
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
logger.info(f"Processing {len(rows)} records from CSV")
|
||||
logger.info(f"Processing {len(rows)} records from CSV with enhanced mapping")
|
||||
|
||||
for index, row in enumerate(rows):
|
||||
try:
|
||||
# Parse and validate data
|
||||
# Enhanced data parsing and validation
|
||||
parsed_data = await self._parse_row_data(row, column_mapping, index + 1)
|
||||
if parsed_data.get("skip"):
|
||||
errors.extend(parsed_data.get("errors", []))
|
||||
warnings.extend(parsed_data.get("warnings", []))
|
||||
continue
|
||||
|
||||
# Create sales record using repository
|
||||
record_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"date": parsed_data["date"],
|
||||
"product_name": parsed_data["product_name"],
|
||||
"quantity_sold": parsed_data["quantity_sold"],
|
||||
"revenue": parsed_data.get("revenue"),
|
||||
"location_id": parsed_data.get("location_id"),
|
||||
"source": "csv"
|
||||
}
|
||||
# Create sales record with enhanced data
|
||||
sales_data = SalesDataCreate(
|
||||
tenant_id=tenant_id,
|
||||
date=parsed_data["date"],
|
||||
product_name=parsed_data["product_name"],
|
||||
product_category=parsed_data.get("product_category"),
|
||||
quantity_sold=parsed_data["quantity_sold"],
|
||||
unit_price=parsed_data.get("unit_price"),
|
||||
revenue=parsed_data.get("revenue"),
|
||||
location_id=parsed_data.get("location_id"),
|
||||
source="csv"
|
||||
)
|
||||
|
||||
await sales_repo.create(record_data)
|
||||
created_record = await repository.create_sales_record(sales_data, tenant_id)
|
||||
records_created += 1
|
||||
|
||||
# Log progress for large imports
|
||||
# Enhanced progress logging
|
||||
if records_created % 100 == 0:
|
||||
logger.info(f"Processed {records_created} records...")
|
||||
logger.info(f"Enhanced processing: {records_created}/{len(rows)} records completed...")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Row {index + 1}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.warning("Record processing failed", error=error_msg)
|
||||
logger.warning("Enhanced record processing failed", error=error_msg)
|
||||
|
||||
success_rate = (records_created / len(rows)) * 100 if rows else 0
|
||||
|
||||
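In condensed form, the CSV path above parses each row, builds a SalesDataCreate, and hands it to SalesRepository.create_sales_record. The sketch below is a simplification under that reading, not the service's exact code; the function name and parameters are assumptions.

# Condensed, hypothetical sketch of the per-row import flow.
import csv
import io

from app.schemas.sales import SalesDataCreate

async def import_rows(csv_content, tenant_id, repository, column_mapping, parse_row):
    """Minimal per-row loop: parse, validate via parse_row, persist via the repository."""
    rows = list(csv.DictReader(io.StringIO(csv_content)))
    created, errors = 0, []
    for i, row in enumerate(rows, start=1):
        parsed = await parse_row(row, column_mapping, i)   # e.g. the service's _parse_row_data
        if parsed.get("skip"):
            errors.extend(parsed.get("errors", []))
            continue
        sales_data = SalesDataCreate(
            tenant_id=tenant_id,
            date=parsed["date"],
            product_name=parsed["product_name"],
            quantity_sold=parsed["quantity_sold"],
            revenue=parsed.get("revenue"),
            source="csv",
        )
        await repository.create_sales_record(sales_data, tenant_id)
        created += 1
    return {"success": created > 0, "total_rows": len(rows), "records_created": created, "errors": errors}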
@@ -414,19 +488,19 @@ class EnhancedDataImportService:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("CSV processing failed", error=str(e))
|
||||
raise DatabaseError(f"CSV processing error: {str(e)}")
|
||||
logger.error("Enhanced CSV processing failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def _process_json_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
json_content: str,
|
||||
sales_repo: SalesRepository,
|
||||
repository: SalesRepository,
|
||||
filename: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Process JSON data using repository"""
|
||||
"""Enhanced JSON processing with pandas integration"""
|
||||
try:
|
||||
# Parse JSON
|
||||
# Parse JSON with base64 support
|
||||
if json_content.startswith('data:'):
|
||||
json_content = base64.b64decode(json_content.split(',')[1]).decode('utf-8')
|
||||
|
||||
@@ -445,28 +519,37 @@ class EnhancedDataImportService:
|
||||
elif isinstance(data, list):
|
||||
records = data
|
||||
else:
|
||||
raise ValidationError("Invalid JSON format")
|
||||
raise ValueError("Invalid JSON format")
|
||||
|
||||
# Convert to DataFrame for consistent processing
|
||||
df = pd.DataFrame(records)
|
||||
df.columns = df.columns.str.strip().str.lower()
|
||||
|
||||
return await self._process_dataframe(tenant_id, df, sales_repo, "json", filename)
|
||||
# Convert to DataFrame for enhanced processing
|
||||
if records:
|
||||
df = pd.DataFrame(records)
|
||||
df.columns = df.columns.str.strip().str.lower()
|
||||
|
||||
return await self._process_dataframe(tenant_id, df, repository, "json", filename)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"total_rows": 0,
|
||||
"records_created": 0,
|
||||
"errors": ["No records found in JSON"],
|
||||
"warnings": []
|
||||
}
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValidationError(f"Invalid JSON: {str(e)}")
|
||||
raise ValueError(f"Invalid JSON: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error("JSON processing failed", error=str(e))
|
||||
raise DatabaseError(f"JSON processing error: {str(e)}")
|
||||
logger.error("Enhanced JSON processing failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def _process_excel_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
excel_content: str,
|
||||
sales_repo: SalesRepository,
|
||||
repository: SalesRepository,
|
||||
filename: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Process Excel data using repository"""
|
||||
"""Enhanced Excel processing with base64 support"""
|
||||
try:
|
||||
# Decode base64 content
|
||||
if excel_content.startswith('data:'):
|
||||
@@ -474,32 +557,32 @@ class EnhancedDataImportService:
|
||||
else:
|
||||
excel_bytes = base64.b64decode(excel_content)
|
||||
|
||||
# Read Excel file
|
||||
# Read Excel file with pandas
|
||||
df = pd.read_excel(io.BytesIO(excel_bytes), sheet_name=0)
|
||||
|
||||
# Clean column names
|
||||
# Enhanced column cleaning
|
||||
df.columns = df.columns.str.strip().str.lower()
|
||||
|
||||
# Remove empty rows
|
||||
df = df.dropna(how='all')
|
||||
|
||||
return await self._process_dataframe(tenant_id, df, sales_repo, "excel", filename)
|
||||
return await self._process_dataframe(tenant_id, df, repository, "excel", filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Excel processing failed", error=str(e))
|
||||
raise DatabaseError(f"Excel processing error: {str(e)}")
|
||||
logger.error("Enhanced Excel processing failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def _process_dataframe(
|
||||
self,
|
||||
tenant_id: str,
|
||||
df: pd.DataFrame,
|
||||
sales_repo: SalesRepository,
|
||||
repository: SalesRepository,
|
||||
source: str,
|
||||
filename: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Process DataFrame using repository"""
|
||||
"""Enhanced DataFrame processing with better error handling"""
|
||||
try:
|
||||
# Map columns
|
||||
# Enhanced column mapping
|
||||
column_mapping = self._detect_columns(df.columns.tolist())
|
||||
|
||||
if not column_mapping.get('date') or not column_mapping.get('product'):
|
||||
@@ -509,50 +592,57 @@ class EnhancedDataImportService:
|
||||
if not column_mapping.get('product'):
|
||||
required_missing.append("product")
|
||||
|
||||
raise ValidationError(f"Required columns missing: {', '.join(required_missing)}")
|
||||
raise ValueError(f"Required columns missing: {', '.join(required_missing)}")
|
||||
|
||||
records_created = 0
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
logger.info(f"Processing {len(df)} records from {source}")
|
||||
logger.info(f"Enhanced processing of {len(df)} records from {source}")
|
||||
|
||||
for index, row in df.iterrows():
|
||||
try:
|
||||
# Convert pandas row to dict
|
||||
row_dict = {}
|
||||
for col in df.columns:
|
||||
row_dict[col] = row[col]
|
||||
val = row[col]
|
||||
# Handle pandas NaN values
|
||||
if pd.isna(val):
|
||||
row_dict[col] = None
|
||||
else:
|
||||
row_dict[col] = val
|
||||
|
||||
# Parse and validate data
|
||||
# Enhanced data parsing
|
||||
parsed_data = await self._parse_row_data(row_dict, column_mapping, index + 1)
|
||||
if parsed_data.get("skip"):
|
||||
errors.extend(parsed_data.get("errors", []))
|
||||
warnings.extend(parsed_data.get("warnings", []))
|
||||
continue
|
||||
|
||||
# Create sales record using repository
|
||||
record_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"date": parsed_data["date"],
|
||||
"product_name": parsed_data["product_name"],
|
||||
"quantity_sold": parsed_data["quantity_sold"],
|
||||
"revenue": parsed_data.get("revenue"),
|
||||
"location_id": parsed_data.get("location_id"),
|
||||
"source": source
|
||||
}
|
||||
# Create enhanced sales record
|
||||
sales_data = SalesDataCreate(
|
||||
tenant_id=tenant_id,
|
||||
date=parsed_data["date"],
|
||||
product_name=parsed_data["product_name"],
|
||||
product_category=parsed_data.get("product_category"),
|
||||
quantity_sold=parsed_data["quantity_sold"],
|
||||
unit_price=parsed_data.get("unit_price"),
|
||||
revenue=parsed_data.get("revenue"),
|
||||
location_id=parsed_data.get("location_id"),
|
||||
source=source
|
||||
)
|
||||
|
||||
await sales_repo.create(record_data)
|
||||
created_record = await repository.create_sales_record(sales_data, tenant_id)
|
||||
records_created += 1
|
||||
|
||||
# Log progress for large imports
|
||||
# Progress logging for large datasets
|
||||
if records_created % 100 == 0:
|
||||
logger.info(f"Processed {records_created} records...")
|
||||
logger.info(f"Enhanced DataFrame processing: {records_created}/{len(df)} records completed...")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Row {index + 1}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.warning("Record processing failed", error=error_msg)
|
||||
logger.warning("Enhanced record processing failed", error=error_msg)
|
||||
|
||||
success_rate = (records_created / len(df)) * 100 if len(df) > 0 else 0
|
||||
|
||||
@@ -561,15 +651,15 @@ class EnhancedDataImportService:
|
||||
"total_rows": len(df),
|
||||
"records_created": records_created,
|
||||
"success_rate": success_rate,
|
||||
"errors": errors[:10], # Limit errors
|
||||
"errors": errors[:10], # Limit errors for performance
|
||||
"warnings": warnings[:10] # Limit warnings
|
||||
}
|
||||
|
||||
except ValidationError:
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("DataFrame processing failed", error=str(e))
|
||||
raise DatabaseError(f"Data processing error: {str(e)}")
|
||||
logger.error("Enhanced DataFrame processing failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def _parse_row_data(
|
||||
self,
|
||||
@@ -577,12 +667,12 @@ class EnhancedDataImportService:
|
||||
column_mapping: Dict[str, str],
|
||||
row_number: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse and validate row data"""
|
||||
"""Enhanced row data parsing with better validation"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
try:
|
||||
# Extract and validate date
|
||||
# Enhanced date extraction and validation
|
||||
date_str = str(row.get(column_mapping.get('date', ''), '')).strip()
|
||||
if not date_str or date_str.lower() in ['nan', 'null', 'none', '']:
|
||||
errors.append(f"Row {row_number}: Missing date")
|
||||
@@ -593,7 +683,7 @@ class EnhancedDataImportService:
|
||||
errors.append(f"Row {row_number}: Invalid date format: {date_str}")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Extract and validate product name
|
||||
# Enhanced product name extraction and cleaning
|
||||
product_name = str(row.get(column_mapping.get('product', ''), '')).strip()
|
||||
if not product_name or product_name.lower() in ['nan', 'null', 'none', '']:
|
||||
errors.append(f"Row {row_number}: Missing product name")
|
||||
@@ -601,42 +691,78 @@ class EnhancedDataImportService:
|
||||
|
||||
product_name = self._clean_product_name(product_name)
|
||||
|
||||
# Extract and validate quantity
|
||||
# STRICT quantity validation for production data quality
|
||||
quantity_raw = row.get(column_mapping.get('quantity', 'quantity'), 1)
|
||||
try:
|
||||
quantity = int(float(str(quantity_raw).replace(',', '.')))
|
||||
if quantity <= 0:
|
||||
warnings.append(f"Row {row_number}: Invalid quantity ({quantity}), using 1")
|
||||
if pd.isna(quantity_raw):
|
||||
# Allow default quantity of 1 for missing values
|
||||
quantity = 1
|
||||
else:
|
||||
quantity = int(float(str(quantity_raw).replace(',', '.')))
|
||||
if quantity <= 0:
|
||||
# STRICT: Treat invalid quantities as ERRORS, not warnings
|
||||
errors.append(f"Row {row_number}: Invalid quantity ({quantity}) - quantities must be positive")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
elif self.STRICT_VALIDATION and quantity > self.MAX_QUANTITY_PER_DAY:
|
||||
# STRICT: Check for unrealistic quantities
|
||||
errors.append(f"Row {row_number}: Unrealistic quantity ({quantity}) - exceeds maximum expected daily sales ({self.MAX_QUANTITY_PER_DAY})")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
except (ValueError, TypeError):
|
||||
warnings.append(f"Row {row_number}: Invalid quantity ({quantity_raw}), using 1")
|
||||
quantity = 1
|
||||
# STRICT: Treat non-numeric quantities as ERRORS
|
||||
errors.append(f"Row {row_number}: Invalid quantity format ({quantity_raw}) - must be a positive number")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Extract revenue (optional)
|
||||
# Enhanced revenue extraction
|
||||
revenue = None
|
||||
unit_price = None
|
||||
if 'revenue' in column_mapping and column_mapping['revenue'] in row:
|
||||
revenue_raw = row.get(column_mapping['revenue'])
|
||||
if revenue_raw and str(revenue_raw).lower() not in ['nan', 'null', 'none', '']:
|
||||
if revenue_raw and not pd.isna(revenue_raw) and str(revenue_raw).lower() not in ['nan', 'null', 'none', '']:
|
||||
try:
|
||||
revenue = float(str(revenue_raw).replace(',', '.').replace('€', '').replace('$', '').strip())
|
||||
if revenue < 0:
|
||||
revenue = None
|
||||
warnings.append(f"Row {row_number}: Negative revenue ignored")
|
||||
# STRICT: Treat negative revenue as ERROR, not warning
|
||||
errors.append(f"Row {row_number}: Negative revenue ({revenue}) - revenue must be positive or zero")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
else:
|
||||
# STRICT: Check for unrealistic revenue values
|
||||
if self.STRICT_VALIDATION and revenue > self.MAX_REVENUE_PER_ITEM:
|
||||
errors.append(f"Row {row_number}: Unrealistic revenue ({revenue}) - exceeds maximum expected value ({self.MAX_REVENUE_PER_ITEM})")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Calculate unit price if we have both revenue and quantity
|
||||
unit_price = revenue / quantity if quantity > 0 else None
|
||||
|
||||
# STRICT: Validate unit price reasonableness
|
||||
if unit_price and unit_price > self.MAX_UNIT_PRICE:  # above MAX_UNIT_PRICE: unrealistic per-unit price for bakery items
|
||||
errors.append(f"Row {row_number}: Unrealistic unit price ({unit_price:.2f}) - check quantity and revenue values")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
except (ValueError, TypeError):
|
||||
warnings.append(f"Row {row_number}: Invalid revenue format: {revenue_raw}")
|
||||
# STRICT: Treat invalid revenue format as ERROR
|
||||
errors.append(f"Row {row_number}: Invalid revenue format ({revenue_raw}) - must be a valid number")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
|
||||
# Extract location (optional)
|
||||
# Enhanced location extraction
|
||||
location_id = None
|
||||
if 'location' in column_mapping and column_mapping['location'] in row:
|
||||
location_raw = row.get(column_mapping['location'])
|
||||
if location_raw and str(location_raw).lower() not in ['nan', 'null', 'none', '']:
|
||||
if location_raw and not pd.isna(location_raw) and str(location_raw).lower() not in ['nan', 'null', 'none', '']:
|
||||
location_id = str(location_raw).strip()
|
||||
|
||||
# Enhanced product category extraction
|
||||
product_category = None
|
||||
if 'category' in column_mapping and column_mapping['category'] in row:
|
||||
category_raw = row.get(column_mapping['category'])
|
||||
if category_raw and not pd.isna(category_raw) and str(category_raw).lower() not in ['nan', 'null', 'none', '']:
|
||||
product_category = str(category_raw).strip()
|
||||
|
||||
return {
|
||||
"skip": False,
|
||||
"date": parsed_date,
|
||||
"product_name": product_name,
|
||||
"product_category": product_category,
|
||||
"quantity_sold": quantity,
|
||||
"unit_price": unit_price,
|
||||
"revenue": revenue,
|
||||
"location_id": location_id,
|
||||
"errors": errors,
|
||||
@@ -644,24 +770,39 @@ class EnhancedDataImportService:
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Row {row_number}: Parsing error: {str(e)}")
|
||||
errors.append(f"Row {row_number}: Enhanced parsing error: {str(e)}")
|
||||
return {"skip": True, "errors": errors, "warnings": warnings}
|
||||
|
||||
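The hunk above replaces the old warn-and-default behaviour with hard per-row errors. A minimal standalone sketch of that policy, assuming an illustrative MAX_QUANTITY_PER_DAY cap (the service's real constant and its STRICT_VALIDATION flag live outside this hunk):

import math
from typing import Any, Dict

MAX_QUANTITY_PER_DAY = 10_000  # hypothetical cap for a single product/day, not the service's constant

def validate_quantity(raw: Any, row_number: int) -> Dict[str, Any]:
    """Return a parsed quantity, or the errors that should make the importer skip the row."""
    errors = []
    # Missing values default to 1 instead of failing the row
    if raw is None or (isinstance(raw, float) and math.isnan(raw)):
        return {"quantity": 1, "errors": errors}
    try:
        quantity = int(float(str(raw).replace(',', '.')))
    except (ValueError, TypeError):
        errors.append(f"Row {row_number}: Invalid quantity format ({raw}) - must be a positive number")
        return {"quantity": None, "errors": errors}
    if quantity <= 0:
        errors.append(f"Row {row_number}: Invalid quantity ({quantity}) - quantities must be positive")
    elif quantity > MAX_QUANTITY_PER_DAY:
        errors.append(f"Row {row_number}: Unrealistic quantity ({quantity})")
    return {"quantity": quantity if not errors else None, "errors": errors}

# Example: validate_quantity("3,0", 12) -> {"quantity": 3, "errors": []}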
    def _detect_columns(self, columns: List[str]) -> Dict[str, str]:
        """Enhanced column detection with fuzzy matching"""
        mapping = {}
        columns_lower = [col.lower().strip() for col in columns]

        for standard_name, possible_names in self.COLUMN_MAPPINGS.items():
            best_match = None
            best_score = 0

            for col_idx, col in enumerate(columns_lower):
                for possible in possible_names:
                    # Exact match (highest priority)
                    if possible == col:
                        best_match = columns[col_idx]
                        best_score = 100
                        break
                    # Contains match
                    elif possible in col or col in possible:
                        score = len(possible) / len(col) * 90
                        if score > best_score:
                            best_match = columns[col_idx]
                            best_score = score

                if best_score == 100:  # Found exact match
                    break

            if best_match and best_score > 70:  # Threshold for matches
                mapping[standard_name] = best_match

        # Enhanced alias mapping
        if 'product' not in mapping and 'product_name' in mapping:
            mapping['product'] = mapping['product_name']
        if 'quantity' not in mapping and 'quantity_sold' in mapping:
@@ -672,13 +813,13 @@ class EnhancedDataImportService:
        return mapping

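The scoring above prefers exact header matches (score 100) over substring matches scaled by relative length, with a cut-off of 70. A self-contained sketch of the same idea, using a made-up subset of COLUMN_MAPPINGS since the real table is defined outside this hunk:

from typing import Dict, List

# Hypothetical subset of the service's COLUMN_MAPPINGS table
COLUMN_MAPPINGS = {
    "date": ["date", "fecha", "dia"],
    "product": ["product", "producto", "item"],
    "quantity": ["quantity", "cantidad", "qty"],
    "revenue": ["revenue", "ingresos", "total"],
}

def detect_columns(columns: List[str]) -> Dict[str, str]:
    """Fuzzy-map raw file headers onto standard field names; exact match beats contains match."""
    mapping: Dict[str, str] = {}
    columns_lower = [c.lower().strip() for c in columns]
    for standard_name, candidates in COLUMN_MAPPINGS.items():
        best_match, best_score = None, 0.0
        for idx, col in enumerate(columns_lower):
            for candidate in candidates:
                if candidate == col:
                    best_match, best_score = columns[idx], 100.0
                    break
                if candidate in col or col in candidate:
                    score = len(candidate) / len(col) * 90
                    if score > best_score:
                        best_match, best_score = columns[idx], score
            if best_score == 100.0:
                break
        if best_match and best_score > 70:
            mapping[standard_name] = best_match
    return mapping

# detect_columns(["Fecha", "Productos", "Cantidad", "Ingresos"])
# -> {'date': 'Fecha', 'product': 'Productos', 'quantity': 'Cantidad', 'revenue': 'Ingresos'}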
    def _parse_date(self, date_str: str) -> Optional[datetime]:
        """Enhanced date parsing with pandas and multiple format support"""
        if not date_str or str(date_str).lower() in ['nan', 'null', 'none']:
            return None

        date_str = str(date_str).strip()

        # Try pandas first (most robust)
        try:
            parsed_dt = pd.to_datetime(date_str, dayfirst=True)
            if hasattr(parsed_dt, 'to_pydatetime'):
@@ -691,7 +832,7 @@ class EnhancedDataImportService:
        except Exception:
            pass

        # Try specific formats as fallback
        for fmt in self.DATE_FORMATS:
            try:
                parsed_dt = datetime.strptime(date_str, fmt)
@@ -701,10 +842,11 @@ class EnhancedDataImportService:
            except ValueError:
                continue

        logger.warning(f"Could not parse date: {date_str}")
        return None

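A runnable sketch of the same two-step strategy, with an assumed DATE_FORMATS list since the class constant is not visible in this hunk:

from datetime import datetime
from typing import Optional

import pandas as pd

# Hypothetical fallback formats; the service's DATE_FORMATS constant is defined elsewhere in the class
DATE_FORMATS = ["%d/%m/%Y", "%d-%m-%Y", "%Y-%m-%d", "%d/%m/%y"]

def parse_date(date_str: str) -> Optional[datetime]:
    """pandas first (handles most inputs, day-first), then explicit formats as a fallback."""
    if not date_str or str(date_str).lower() in ("nan", "null", "none"):
        return None
    date_str = str(date_str).strip()
    try:
        parsed = pd.to_datetime(date_str, dayfirst=True)
        return parsed.to_pydatetime()
    except Exception:
        pass
    for fmt in DATE_FORMATS:
        try:
            return datetime.strptime(date_str, fmt)
        except ValueError:
            continue
    return None

# parse_date("15/01/2024") and parse_date("2024-01-15") both yield datetime(2024, 1, 15, 0, 0)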
    def _clean_product_name(self, product_name: str) -> str:
        """Enhanced product name cleaning and standardization"""
        if not product_name:
            return "Producto sin nombre"

@@ -717,12 +859,14 @@ class EnhancedDataImportService:
        # Capitalize first letter of each word
        cleaned = cleaned.title()

        # Enhanced corrections for Spanish bakeries
        replacements = {
            'Pan De': 'Pan de',
            'Café Con': 'Café con',
            'Te ': 'Té ',
            'Bocadillo De': 'Bocadillo de',
            'Dulce De': 'Dulce de',
            'Tarta De': 'Tarta de',
        }

        for old, new in replacements.items():
@@ -752,7 +896,7 @@ class EnhancedDataImportService:
        format_type: str,
        warning_count: int
    ) -> List[str]:
        """Generate enhanced contextual suggestions"""
        suggestions = []

        if validation_result.is_valid:
@@ -761,21 +905,39 @@ class EnhancedDataImportService:

            if validation_result.total_records > 1000:
                suggestions.append("Archivo grande: el procesamiento puede tomar varios minutos")
                suggestions.append("Considera dividir archivos muy grandes para mejor rendimiento")

            if warning_count > 0:
                suggestions.append("Revisa las advertencias antes de continuar")
                suggestions.append("Los datos con advertencias se procesarán con valores por defecto")

            # Format-specific suggestions
            if format_type == "csv":
                suggestions.append("Asegúrate de que las fechas estén en formato DD/MM/YYYY")
                suggestions.append("Verifica que los números usen punto decimal (no coma)")
            elif format_type in ["excel", "xlsx"]:
                suggestions.append("Solo se procesará la primera hoja del archivo Excel")
                suggestions.append("Evita celdas combinadas y fórmulas complejas")
        else:
            suggestions.append("Corrige los errores antes de continuar")
            suggestions.append("Verifica que el archivo tenga el formato correcto")

            if format_type not in ["csv", "excel", "xlsx", "json"]:
                suggestions.append("Usa formato CSV o Excel para mejores resultados")
                suggestions.append("El formato JSON es para usuarios avanzados")

            if validation_result.total_records == 0:
                suggestions.append("Asegúrate de que el archivo contenga datos")
                suggestions.append("Verifica que el archivo no esté corrupto")

            # Missing column suggestions
            error_codes = [error.get("code", "") for error in validation_result.errors if isinstance(error, dict)]
            if "MISSING_DATE_COLUMN" in error_codes:
                suggestions.append("Incluye una columna de fecha (fecha, date, dia)")
            if "MISSING_PRODUCT_COLUMN" in error_codes:
                suggestions.append("Incluye una columna de producto (producto, product, item)")

        return suggestions


# Legacy compatibility alias - DataImportService is the enhanced implementation
DataImportService = EnhancedDataImportService
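A compact, standalone version of the name-cleaning step. It assumes the part of the method hidden between hunks only normalises whitespace before the title() call shown above:

# Hypothetical mini-version of _clean_product_name covering only the steps visible in this hunk
REPLACEMENTS = {
    'Pan De': 'Pan de',
    'Café Con': 'Café con',
    'Te ': 'Té ',
    'Bocadillo De': 'Bocadillo de',
    'Dulce De': 'Dulce de',
    'Tarta De': 'Tarta de',
}

def clean_product_name(product_name: str) -> str:
    """Title-case the raw name, then fix the Spanish particles that title() over-capitalizes."""
    if not product_name:
        return "Producto sin nombre"
    cleaned = " ".join(str(product_name).split()).title()  # assumed whitespace normalisation
    for old, new in REPLACEMENTS.items():
        cleaned = cleaned.replace(old, new)
    return cleaned

# clean_product_name("  pan de   molde ") -> 'Pan de Molde'
# clean_product_name("tarta de queso") -> 'Tarta de Queso'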
232 services/sales/app/services/messaging.py Normal file
@@ -0,0 +1,232 @@
# services/sales/app/services/messaging.py
"""
Sales Service Messaging - Event Publishing using shared messaging infrastructure
"""

import structlog
from typing import Dict, Any, Optional
from uuid import UUID
from datetime import datetime

from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import BaseEvent, DataImportedEvent
from app.core.config import settings

logger = structlog.get_logger()


class SalesEventPublisher:
    """Sales service event publisher using RabbitMQ"""

    def __init__(self):
        self.enabled = True
        self._rabbitmq_client = None

    async def _get_rabbitmq_client(self):
        """Get or create RabbitMQ client"""
        if not self._rabbitmq_client:
            self._rabbitmq_client = RabbitMQClient(
                connection_url=settings.RABBITMQ_URL,
                service_name="sales-service"
            )
            await self._rabbitmq_client.connect()
        return self._rabbitmq_client

    async def publish_sales_created(self, sales_data: Dict[str, Any], correlation_id: Optional[str] = None) -> bool:
        """Publish sales created event"""
        try:
            if not self.enabled:
                return True

            # Create event
            event = BaseEvent(
                service_name="sales-service",
                data={
                    "record_id": str(sales_data.get("id")),
                    "tenant_id": str(sales_data.get("tenant_id")),
                    "product_name": sales_data.get("product_name"),
                    "revenue": float(sales_data.get("revenue", 0)),
                    "quantity_sold": sales_data.get("quantity_sold", 0),
                    "timestamp": datetime.now().isoformat()
                },
                event_type="sales.created",
                correlation_id=correlation_id
            )

            # Publish via RabbitMQ
            client = await self._get_rabbitmq_client()
            success = await client.publish_event(
                exchange_name="sales.events",
                routing_key="sales.created",
                event_data=event.to_dict()
            )

            if success:
                logger.info("Sales record created event published",
                            record_id=sales_data.get("id"),
                            tenant_id=sales_data.get("tenant_id"),
                            product=sales_data.get("product_name"))

            return success

        except Exception as e:
            logger.warning("Failed to publish sales created event", error=str(e))
            return False

    async def publish_sales_updated(self, sales_data: Dict[str, Any], correlation_id: Optional[str] = None) -> bool:
        """Publish sales updated event"""
        try:
            if not self.enabled:
                return True

            event = BaseEvent(
                service_name="sales-service",
                data={
                    "record_id": str(sales_data.get("id")),
                    "tenant_id": str(sales_data.get("tenant_id")),
                    "product_name": sales_data.get("product_name"),
                    "timestamp": datetime.now().isoformat()
                },
                event_type="sales.updated",
                correlation_id=correlation_id
            )

            client = await self._get_rabbitmq_client()
            success = await client.publish_event(
                exchange_name="sales.events",
                routing_key="sales.updated",
                event_data=event.to_dict()
            )

            if success:
                logger.info("Sales record updated event published",
                            record_id=sales_data.get("id"),
                            tenant_id=sales_data.get("tenant_id"))

            return success

        except Exception as e:
            logger.warning("Failed to publish sales updated event", error=str(e))
            return False

    async def publish_sales_deleted(self, record_id: UUID, tenant_id: UUID, correlation_id: Optional[str] = None) -> bool:
        """Publish sales deleted event"""
        try:
            if not self.enabled:
                return True

            event = BaseEvent(
                service_name="sales-service",
                data={
                    "record_id": str(record_id),
                    "tenant_id": str(tenant_id),
                    "timestamp": datetime.now().isoformat()
                },
                event_type="sales.deleted",
                correlation_id=correlation_id
            )

            client = await self._get_rabbitmq_client()
            success = await client.publish_event(
                exchange_name="sales.events",
                routing_key="sales.deleted",
                event_data=event.to_dict()
            )

            if success:
                logger.info("Sales record deleted event published",
                            record_id=record_id,
                            tenant_id=tenant_id)

            return success

        except Exception as e:
            logger.warning("Failed to publish sales deleted event", error=str(e))
            return False

    async def publish_data_imported(self, import_result: Dict[str, Any], correlation_id: Optional[str] = None) -> bool:
        """Publish data imported event"""
        try:
            if not self.enabled:
                return True

            event = DataImportedEvent(
                service_name="sales-service",
                data={
                    "records_created": import_result.get("records_created", 0),
                    "records_updated": import_result.get("records_updated", 0),
                    "records_failed": import_result.get("records_failed", 0),
                    "tenant_id": str(import_result.get("tenant_id")),
                    "success": import_result.get("success", False),
                    "file_name": import_result.get("file_name"),
                    "timestamp": datetime.now().isoformat()
                },
                correlation_id=correlation_id
            )

            client = await self._get_rabbitmq_client()
            success = await client.publish_event(
                exchange_name="data.events",
                routing_key="data.imported",
                event_data=event.to_dict()
            )

            if success:
                logger.info("Sales data imported event published",
                            records_created=import_result.get("records_created"),
                            tenant_id=import_result.get("tenant_id"),
                            success=import_result.get("success"))

            return success

        except Exception as e:
            logger.warning("Failed to publish data imported event", error=str(e))
            return False

    async def publish_analytics_generated(self, analytics_data: Dict[str, Any], correlation_id: Optional[str] = None) -> bool:
        """Publish analytics generated event"""
        try:
            if not self.enabled:
                return True

            event = BaseEvent(
                service_name="sales-service",
                data={
                    "tenant_id": str(analytics_data.get("tenant_id")),
                    "total_revenue": float(analytics_data.get("total_revenue", 0)),
                    "total_quantity": analytics_data.get("total_quantity", 0),
                    "total_transactions": analytics_data.get("total_transactions", 0),
                    "period_start": analytics_data.get("period_start"),
                    "period_end": analytics_data.get("period_end"),
                    "timestamp": datetime.now().isoformat()
                },
                event_type="analytics.generated",
                correlation_id=correlation_id
            )

            client = await self._get_rabbitmq_client()
            success = await client.publish_event(
                exchange_name="analytics.events",
                routing_key="analytics.generated",
                event_data=event.to_dict()
            )

            if success:
                logger.info("Sales analytics generated event published",
                            tenant_id=analytics_data.get("tenant_id"),
                            total_revenue=analytics_data.get("total_revenue"))

            return success

        except Exception as e:
            logger.warning("Failed to publish analytics generated event", error=str(e))
            return False

    async def cleanup(self):
        """Cleanup RabbitMQ connections"""
        if self._rabbitmq_client:
            await self._rabbitmq_client.disconnect()


# Global instance
sales_publisher = SalesEventPublisher()
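A hypothetical caller-side sketch for the publisher above; it assumes a reachable RabbitMQ broker at settings.RABBITMQ_URL and simply reuses the module-level sales_publisher instance:

# Requires the broker to be running; in the service this runs inside the existing event loop.
import asyncio
from uuid import uuid4

from app.services.messaging import sales_publisher

async def publish_example() -> None:
    sale = {
        "id": uuid4(),
        "tenant_id": uuid4(),
        "product_name": "Pan Integral",
        "revenue": 12.50,
        "quantity_sold": 5,
    }
    # Returns False (and only logs a warning) if the broker is unavailable
    published = await sales_publisher.publish_sales_created(sale, correlation_id=str(uuid4()))
    print("published:", published)
    await sales_publisher.cleanup()

if __name__ == "__main__":
    asyncio.run(publish_example())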
171 services/sales/app/services/product_service.py Normal file
@@ -0,0 +1,171 @@
# services/sales/app/services/product_service.py
"""
Product Service - Business Logic Layer
"""

from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog

from app.models.sales import Product
from app.repositories.product_repository import ProductRepository
from app.schemas.sales import ProductCreate, ProductUpdate
from app.core.database import get_db_transaction

logger = structlog.get_logger()


class ProductService:
    """Service layer for product operations"""

    def __init__(self):
        pass

    async def create_product(
        self,
        product_data: ProductCreate,
        tenant_id: UUID,
        user_id: Optional[UUID] = None
    ) -> Product:
        """Create a new product with business validation"""
        try:
            # Business validation
            await self._validate_product_data(product_data, tenant_id)

            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                product = await repository.create_product(product_data, tenant_id)

                logger.info("Created product", product_id=product.id, tenant_id=tenant_id)
                return product

        except Exception as e:
            logger.error("Failed to create product", error=str(e), tenant_id=tenant_id)
            raise

    async def update_product(
        self,
        product_id: UUID,
        update_data: ProductUpdate,
        tenant_id: UUID
    ) -> Product:
        """Update a product"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)

                # Verify product belongs to tenant
                existing_product = await repository.get_by_id(product_id)
                if not existing_product or existing_product.tenant_id != tenant_id:
                    raise ValueError(f"Product {product_id} not found for tenant {tenant_id}")

                # Update the product
                updated_product = await repository.update(product_id, update_data.model_dump(exclude_unset=True))

                logger.info("Updated product", product_id=product_id, tenant_id=tenant_id)
                return updated_product

        except Exception as e:
            logger.error("Failed to update product", error=str(e), product_id=product_id, tenant_id=tenant_id)
            raise

    async def get_products(self, tenant_id: UUID) -> List[Product]:
        """Get all products for a tenant"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                products = await repository.get_by_tenant(tenant_id)

                logger.info("Retrieved products", count=len(products), tenant_id=tenant_id)
                return products

        except Exception as e:
            logger.error("Failed to get products", error=str(e), tenant_id=tenant_id)
            raise

    async def get_product(self, product_id: UUID, tenant_id: UUID) -> Optional[Product]:
        """Get a specific product"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                product = await repository.get_by_id(product_id)

                # Verify product belongs to tenant
                if product and product.tenant_id != tenant_id:
                    return None

                return product

        except Exception as e:
            logger.error("Failed to get product", error=str(e), product_id=product_id, tenant_id=tenant_id)
            raise

    async def delete_product(self, product_id: UUID, tenant_id: UUID) -> bool:
        """Delete a product"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)

                # Verify product belongs to tenant
                existing_product = await repository.get_by_id(product_id)
                if not existing_product or existing_product.tenant_id != tenant_id:
                    raise ValueError(f"Product {product_id} not found for tenant {tenant_id}")

                success = await repository.delete(product_id)

                if success:
                    logger.info("Deleted product", product_id=product_id, tenant_id=tenant_id)

                return success

        except Exception as e:
            logger.error("Failed to delete product", error=str(e), product_id=product_id, tenant_id=tenant_id)
            raise

    async def get_products_by_category(self, tenant_id: UUID, category: str) -> List[Product]:
        """Get products by category"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                products = await repository.get_by_category(tenant_id, category)

                logger.info("Retrieved products by category", count=len(products), category=category, tenant_id=tenant_id)
                return products

        except Exception as e:
            logger.error("Failed to get products by category", error=str(e), category=category, tenant_id=tenant_id)
            raise

    async def search_products(self, tenant_id: UUID, search_term: str) -> List[Product]:
        """Search products by name or SKU"""
        try:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                products = await repository.search_products(tenant_id, search_term)

                logger.info("Searched products", count=len(products), search_term=search_term, tenant_id=tenant_id)
                return products

        except Exception as e:
            logger.error("Failed to search products", error=str(e), search_term=search_term, tenant_id=tenant_id)
            raise

    async def _validate_product_data(self, product_data: ProductCreate, tenant_id: UUID):
        """Validate product data according to business rules"""
        # Check if product with same SKU already exists
        if product_data.sku:
            async with get_db_transaction() as db:
                repository = ProductRepository(db)
                existing_product = await repository.get_by_sku(tenant_id, product_data.sku)
                if existing_product:
                    raise ValueError(f"Product with SKU {product_data.sku} already exists for tenant {tenant_id}")

        # Validate seasonal dates
        if product_data.is_seasonal:
            if not product_data.seasonal_start or not product_data.seasonal_end:
                raise ValueError("Seasonal products must have start and end dates")
            if product_data.seasonal_start >= product_data.seasonal_end:
                raise ValueError("Seasonal start date must be before end date")

        logger.info("Product data validation passed", tenant_id=tenant_id)
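The SKU and seasonal-window rules in _validate_product_data can be read in isolation. A small sketch of the seasonal rule, with a plain dataclass standing in for the real ProductCreate schema:

# Standalone restatement of the seasonal-date rule; ProductDraft is a stand-in, not the real schema.
from dataclasses import dataclass
from datetime import date
from typing import Optional

@dataclass
class ProductDraft:
    name: str
    is_seasonal: bool = False
    seasonal_start: Optional[date] = None
    seasonal_end: Optional[date] = None

def check_seasonal_window(product: ProductDraft) -> None:
    """Raise ValueError when a seasonal product has a missing or inverted date window."""
    if not product.is_seasonal:
        return
    if not product.seasonal_start or not product.seasonal_end:
        raise ValueError("Seasonal products must have start and end dates")
    if product.seasonal_start >= product.seasonal_end:
        raise ValueError("Seasonal start date must be before end date")

# check_seasonal_window(ProductDraft("Roscón de Reyes", True, date(2024, 12, 1), date(2025, 1, 7)))  # passes
# check_seasonal_window(ProductDraft("Turrón", True, date(2025, 1, 7), date(2024, 12, 1)))          # raises ValueError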
282 services/sales/app/services/sales_service.py Normal file
@@ -0,0 +1,282 @@
# services/sales/app/services/sales_service.py
"""
Sales Service - Business Logic Layer
"""

from typing import List, Optional, Dict, Any
from uuid import UUID
from datetime import datetime
import structlog

from app.models.sales import SalesData
from app.repositories.sales_repository import SalesRepository
from app.schemas.sales import SalesDataCreate, SalesDataUpdate, SalesDataQuery, SalesAnalytics
from app.core.database import get_db_transaction
from shared.database.exceptions import DatabaseError

logger = structlog.get_logger()


class SalesService:
    """Service layer for sales operations"""

    def __init__(self):
        pass

    async def create_sales_record(
        self,
        sales_data: SalesDataCreate,
        tenant_id: UUID,
        user_id: Optional[UUID] = None
    ) -> SalesData:
        """Create a new sales record with business validation"""
        try:
            # Business validation
            await self._validate_sales_data(sales_data, tenant_id)

            # Set user who created the record
            if user_id:
                sales_data_dict = sales_data.model_dump()
                sales_data_dict['created_by'] = user_id
                sales_data = SalesDataCreate(**sales_data_dict)

            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                record = await repository.create_sales_record(sales_data, tenant_id)

                # Additional business logic (e.g., notifications, analytics updates)
                await self._post_create_actions(record)

                return record

        except Exception as e:
            logger.error("Failed to create sales record in service", error=str(e), tenant_id=tenant_id)
            raise

    async def update_sales_record(
        self,
        record_id: UUID,
        update_data: SalesDataUpdate,
        tenant_id: UUID
    ) -> SalesData:
        """Update a sales record"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)

                # Verify record belongs to tenant
                existing_record = await repository.get_by_id(record_id)
                if not existing_record or existing_record.tenant_id != tenant_id:
                    raise ValueError(f"Sales record {record_id} not found for tenant {tenant_id}")

                # Update the record
                updated_record = await repository.update(record_id, update_data.model_dump(exclude_unset=True))

                logger.info("Updated sales record", record_id=record_id, tenant_id=tenant_id)
                return updated_record

        except Exception as e:
            logger.error("Failed to update sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
            raise

    async def get_sales_records(
        self,
        tenant_id: UUID,
        query_params: Optional[SalesDataQuery] = None
    ) -> List[SalesData]:
        """Get sales records for a tenant"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                records = await repository.get_by_tenant(tenant_id, query_params)

                logger.info("Retrieved sales records", count=len(records), tenant_id=tenant_id)
                return records

        except Exception as e:
            logger.error("Failed to get sales records", error=str(e), tenant_id=tenant_id)
            raise

    async def get_sales_record(self, record_id: UUID, tenant_id: UUID) -> Optional[SalesData]:
        """Get a specific sales record"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                record = await repository.get_by_id(record_id)

                # Verify record belongs to tenant
                if record and record.tenant_id != tenant_id:
                    return None

                return record

        except Exception as e:
            logger.error("Failed to get sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
            raise

    async def delete_sales_record(self, record_id: UUID, tenant_id: UUID) -> bool:
        """Delete a sales record"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)

                # Verify record belongs to tenant
                existing_record = await repository.get_by_id(record_id)
                if not existing_record or existing_record.tenant_id != tenant_id:
                    raise ValueError(f"Sales record {record_id} not found for tenant {tenant_id}")

                success = await repository.delete(record_id)

                if success:
                    logger.info("Deleted sales record", record_id=record_id, tenant_id=tenant_id)

                return success

        except Exception as e:
            logger.error("Failed to delete sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
            raise

    async def get_product_sales(
        self,
        tenant_id: UUID,
        product_name: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[SalesData]:
        """Get sales records for a specific product"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                records = await repository.get_by_product(tenant_id, product_name, start_date, end_date)

                logger.info(
                    "Retrieved product sales",
                    count=len(records),
                    product=product_name,
                    tenant_id=tenant_id
                )
                return records

        except Exception as e:
            logger.error("Failed to get product sales", error=str(e), tenant_id=tenant_id, product=product_name)
            raise

    async def get_sales_analytics(
        self,
        tenant_id: UUID,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Get sales analytics for a tenant"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                analytics = await repository.get_analytics(tenant_id, start_date, end_date)

                logger.info("Retrieved sales analytics", tenant_id=tenant_id)
                return analytics

        except Exception as e:
            logger.error("Failed to get sales analytics", error=str(e), tenant_id=tenant_id)
            raise

    async def get_product_categories(self, tenant_id: UUID) -> List[str]:
        """Get distinct product categories"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)
                categories = await repository.get_product_categories(tenant_id)

                return categories

        except Exception as e:
            logger.error("Failed to get product categories", error=str(e), tenant_id=tenant_id)
            raise

    async def validate_sales_record(
        self,
        record_id: UUID,
        tenant_id: UUID,
        validation_notes: Optional[str] = None
    ) -> SalesData:
        """Validate a sales record"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)

                # Verify record belongs to tenant
                existing_record = await repository.get_by_id(record_id)
                if not existing_record or existing_record.tenant_id != tenant_id:
                    raise ValueError(f"Sales record {record_id} not found for tenant {tenant_id}")

                validated_record = await repository.validate_record(record_id, validation_notes)

                logger.info("Validated sales record", record_id=record_id, tenant_id=tenant_id)
                return validated_record

        except Exception as e:
            logger.error("Failed to validate sales record", error=str(e), record_id=record_id, tenant_id=tenant_id)
            raise

    async def _validate_sales_data(self, sales_data: SalesDataCreate, tenant_id: UUID):
        """Validate sales data according to business rules"""
        # Example business validations

        # Check if revenue matches quantity * unit_price (if unit_price provided)
        if sales_data.unit_price and sales_data.quantity_sold:
            expected_revenue = sales_data.unit_price * sales_data.quantity_sold
            # Apply discount if any
            if sales_data.discount_applied:
                expected_revenue *= (1 - sales_data.discount_applied / 100)

            # Allow for small rounding differences
            if abs(float(sales_data.revenue) - float(expected_revenue)) > 0.01:
                logger.warning(
                    "Revenue mismatch detected",
                    expected=float(expected_revenue),
                    actual=float(sales_data.revenue),
                    tenant_id=tenant_id
                )

        # Check date validity (not in future)
        if sales_data.date > datetime.utcnow():
            raise ValueError("Sales date cannot be in the future")

        # Additional business rules can be added here
        logger.info("Sales data validation passed", tenant_id=tenant_id)

    async def _post_create_actions(self, record: SalesData):
        """Actions to perform after creating a sales record"""
        try:
            # Here you could:
            # - Send notifications
            # - Update analytics caches
            # - Trigger ML model updates
            # - Update inventory levels (future integration)

            logger.info("Post-create actions completed", record_id=record.id)

        except Exception as e:
            # Don't fail the main operation for auxiliary actions
            logger.warning("Failed to execute post-create actions", error=str(e), record_id=record.id)

    async def get_products_list(self, tenant_id: str) -> List[Dict[str, Any]]:
        """Get list of all products with sales data for tenant using repository pattern"""
        try:
            async with get_db_transaction() as db:
                repository = SalesRepository(db)

                # Use repository method for product statistics
                products = await repository.get_product_statistics(tenant_id)

                logger.debug("Products list retrieved successfully",
                             tenant_id=tenant_id,
                             product_count=len(products))

                return products

        except Exception as e:
            logger.error("Failed to get products list",
                         error=str(e),
                         tenant_id=tenant_id)
            raise DatabaseError(f"Failed to get products list: {str(e)}")
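The revenue check in _validate_sales_data reduces to: expected = unit_price * quantity * (1 - discount/100), compared against the declared revenue with a one-cent tolerance. A numeric sketch:

from decimal import Decimal

def revenue_matches(unit_price: Decimal, quantity: int, revenue: Decimal,
                    discount_pct: Decimal = Decimal("0")) -> bool:
    """True when the declared revenue is within 0.01 of the price*quantity figure after discount."""
    expected = unit_price * quantity * (Decimal("1") - discount_pct / Decimal("100"))
    return abs(float(revenue) - float(expected)) <= 0.01

# 5 units at 2.50 € with no discount -> 12.50 € is consistent
assert revenue_matches(Decimal("2.50"), 5, Decimal("12.50"))
# A 10 % discount makes 11.25 € the expected figure
assert revenue_matches(Decimal("2.50"), 5, Decimal("11.25"), Decimal("10"))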
19 services/sales/pytest.ini Normal file
@@ -0,0 +1,19 @@
[tool:pytest]
testpaths = tests
asyncio_mode = auto
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
    -v
    --tb=short
    --strict-markers
    --disable-warnings
    --cov=app
    --cov-report=term-missing
    --cov-report=html:htmlcov
markers =
    unit: Unit tests
    integration: Integration tests
    slow: Slow running tests
    external: Tests requiring external services
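Because --strict-markers is enabled, tests may only use the markers registered above; asyncio_mode = auto lets plain async def tests run, assuming pytest-asyncio is pulled in via tests/requirements.txt. A hypothetical example module (file name and assertions are illustrative):

# Hypothetical tests/test_markers_example.py
import pytest

@pytest.mark.unit
def test_quantity_default_is_one():
    # With --strict-markers, only markers declared in pytest.ini are accepted
    assert int(float("1")) == 1

@pytest.mark.integration
@pytest.mark.slow
async def test_async_placeholder():
    # asyncio_mode = auto runs async tests without an explicit asyncio marker
    assert True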
39 services/sales/requirements.txt Normal file
@@ -0,0 +1,39 @@
# services/sales/requirements.txt
# FastAPI and web framework
fastapi==0.104.1
uvicorn[standard]==0.24.0

# Database
sqlalchemy==2.0.23
psycopg2-binary==2.9.9
asyncpg==0.29.0
aiosqlite==0.19.0
alembic==1.12.1

# Data processing
pandas==2.1.3
numpy==1.25.2

# HTTP clients
httpx==0.25.2
aiofiles==23.2.0

# Validation and serialization
pydantic==2.5.0
pydantic-settings==2.0.3

# Authentication and security
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4

# Logging and monitoring
structlog==23.2.0
prometheus-client==0.19.0

# Message queues
aio-pika==9.3.1

# Note: pytest and testing dependencies are in tests/requirements.txt

# Development
python-multipart==0.0.6
1 services/sales/shared/shared Symbolic link
@@ -0,0 +1 @@
/Users/urtzialfaro/Documents/bakery-ia/shared
239 services/sales/tests/conftest.py Normal file
@@ -0,0 +1,239 @@
# services/sales/tests/conftest.py
"""
Pytest configuration and fixtures for Sales Service tests
"""

import pytest
import asyncio
from datetime import datetime, timezone
from decimal import Decimal
from typing import AsyncGenerator
from uuid import uuid4, UUID

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import StaticPool
from fastapi.testclient import TestClient

from app.main import app
from app.core.config import settings
from app.core.database import Base, get_db
from app.models.sales import SalesData
from app.schemas.sales import SalesDataCreate


# Test database configuration
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"


@pytest.fixture(scope="session")
def event_loop():
    """Create event loop for the test session"""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture
async def test_engine():
    """Create test database engine"""
    engine = create_async_engine(
        TEST_DATABASE_URL,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False}
    )

    # Create tables
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()


@pytest.fixture
async def test_db_session(test_engine) -> AsyncGenerator[AsyncSession, None]:
    """Create test database session"""
    async_session = async_sessionmaker(
        test_engine, class_=AsyncSession, expire_on_commit=False
    )

    async with async_session() as session:
        yield session


@pytest.fixture
def test_client():
    """Create test client"""
    return TestClient(app)


@pytest.fixture
async def override_get_db(test_db_session):
    """Override get_db dependency for testing"""
    async def _override_get_db():
        yield test_db_session

    app.dependency_overrides[get_db] = _override_get_db
    yield
    app.dependency_overrides.clear()


# Test data fixtures
@pytest.fixture
def sample_tenant_id() -> UUID:
    """Sample tenant ID for testing"""
    return uuid4()


@pytest.fixture
def sample_sales_data(sample_tenant_id: UUID) -> SalesDataCreate:
    """Sample sales data for testing"""
    return SalesDataCreate(
        date=datetime.now(timezone.utc),
        product_name="Pan Integral",
        product_category="Panadería",
        product_sku="PAN001",
        quantity_sold=5,
        unit_price=Decimal("2.50"),
        revenue=Decimal("12.50"),
        cost_of_goods=Decimal("6.25"),
        discount_applied=Decimal("0"),
        location_id="STORE_001",
        sales_channel="in_store",
        source="manual",
        notes="Test sale",
        weather_condition="sunny",
        is_holiday=False,
        is_weekend=False
    )


@pytest.fixture
def sample_sales_records(sample_tenant_id: UUID) -> list[dict]:
    """Multiple sample sales records"""
    base_date = datetime.now(timezone.utc)
    return [
        {
            "tenant_id": sample_tenant_id,
            "date": base_date,
            "product_name": "Croissant",
            "quantity_sold": 3,
            "revenue": Decimal("7.50"),
            "location_id": "STORE_001",
            "source": "manual"
        },
        {
            "tenant_id": sample_tenant_id,
            "date": base_date,
            "product_name": "Café Americano",
            "quantity_sold": 2,
            "revenue": Decimal("5.00"),
            "location_id": "STORE_001",
            "source": "pos"
        },
        {
            "tenant_id": sample_tenant_id,
            "date": base_date,
            "product_name": "Bocadillo Jamón",
            "quantity_sold": 1,
            "revenue": Decimal("4.50"),
            "location_id": "STORE_002",
            "source": "manual"
        }
    ]


@pytest.fixture
def sample_csv_data() -> str:
    """Sample CSV data for import testing"""
    return """date,product,quantity,revenue,location
2024-01-15,Pan Integral,5,12.50,STORE_001
2024-01-15,Croissant,3,7.50,STORE_001
2024-01-15,Café Americano,2,5.00,STORE_002
2024-01-16,Pan de Molde,8,16.00,STORE_001
2024-01-16,Magdalenas,6,9.00,STORE_002"""


@pytest.fixture
def sample_json_data() -> str:
    """Sample JSON data for import testing"""
    return """[
    {
        "date": "2024-01-15",
        "product": "Pan Integral",
        "quantity": 5,
        "revenue": 12.50,
        "location": "STORE_001"
    },
    {
        "date": "2024-01-15",
        "product": "Croissant",
        "quantity": 3,
        "revenue": 7.50,
        "location": "STORE_001"
    }
]"""


@pytest.fixture
async def populated_db(test_db_session: AsyncSession, sample_sales_records: list[dict]):
    """Database populated with test data"""
    for record_data in sample_sales_records:
        sales_record = SalesData(**record_data)
        test_db_session.add(sales_record)

    await test_db_session.commit()
    yield test_db_session


# Mock fixtures for external dependencies
@pytest.fixture
def mock_messaging():
    """Mock messaging service"""
    class MockMessaging:
        def __init__(self):
            self.published_events = []

        async def publish_sales_created(self, data):
            self.published_events.append(("sales_created", data))
            return True

        async def publish_data_imported(self, data):
            self.published_events.append(("data_imported", data))
            return True

    return MockMessaging()


# Performance testing fixtures
@pytest.fixture
def large_csv_data() -> str:
    """Large CSV data for performance testing"""
    headers = "date,product,quantity,revenue,location\n"
    rows = []

    for i in range(1000):  # 1000 records
        rows.append(f"2024-01-{(i % 30) + 1:02d},Producto_{i % 10},1,2.50,STORE_{i % 3 + 1:03d}")

    return headers + "\n".join(rows)


@pytest.fixture
def performance_test_data(sample_tenant_id: UUID) -> list[dict]:
    """Large dataset for performance testing"""
    records = []
    base_date = datetime.now(timezone.utc)

    for i in range(500):  # 500 records
        records.append({
            "tenant_id": sample_tenant_id,
            "date": base_date,
            "product_name": f"Test Product {i % 20}",
            "quantity_sold": (i % 10) + 1,
            "revenue": Decimal(str(((i % 10) + 1) * 2.5)),
            "location_id": f"STORE_{(i % 5) + 1:03d}",
            "source": "test"
        })

    return records
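A hypothetical test module consuming these fixtures; the /health path is an assumption about the service's FastAPI app, and the other assertions use only the canned fixture values:

# Hypothetical tests/test_sales_fixtures.py
import pytest
from decimal import Decimal

@pytest.mark.unit
def test_sample_sales_data_is_consistent(sample_sales_data):
    # unit_price * quantity should equal the declared revenue for the canned record
    assert sample_sales_data.unit_price * sample_sales_data.quantity_sold == Decimal("12.50")
    assert sample_sales_data.revenue == Decimal("12.50")

@pytest.mark.unit
def test_health_endpoint(test_client):
    # Assumes the app exposes a /health route; adjust to the real path if it differs
    response = test_client.get("/health")
    assert response.status_code == 200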
Some files were not shown because too many files have changed in this diff.