Add all the code for training service

This commit is contained in:
Urtzi Alfaro
2025-07-19 16:59:37 +02:00
parent 42097202d2
commit f3071c00bd
21 changed files with 7504 additions and 764 deletions

View File

@@ -0,0 +1,688 @@
# services/training/tests/test_service.py
"""
Tests for training service business logic layer
"""
import pytest
from unittest.mock import AsyncMock, Mock, patch
from datetime import datetime, timedelta
import httpx
from app.services.training_service import TrainingService
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
from app.models.training import ModelTrainingLog, TrainedModel
class TestTrainingService:
    """Exercise the business-logic methods exposed by ``TrainingService``.

    Fixtures such as ``test_db_session``, ``training_job_in_db``,
    ``mock_ml_trainer`` and ``mock_data_service`` are not defined in this
    module; presumably they come from a shared ``conftest.py`` (confirm
    before renaming any of them).
    """

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        """Return a fresh service instance.

        ``mock_ml_trainer`` is requested but never read here, so it is only
        needed for whatever setup that fixture performs.
        """
        return TrainingService()

    @pytest.mark.asyncio
    async def test_create_training_job_success(
        self,
        training_service,
        test_db_session
    ):
        """A new job is returned as a pending ``ModelTrainingLog`` at 0% progress."""
        job_id = "test-job-123"
        tenant_id = "test-tenant"
        config = {"include_weather": True, "include_traffic": True}

        result = await training_service.create_training_job(
            db=test_db_session,
            tenant_id=tenant_id,
            job_id=job_id,
            config=config
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.status == "pending"
        assert result.progress == 0
        assert result.config == config

    @pytest.mark.asyncio
    async def test_create_single_product_job_success(
        self,
        training_service,
        test_db_session
    ):
        """A single-product job records the product in its config and current step."""
        job_id = "test-product-job-123"
        tenant_id = "test-tenant"
        product_name = "Pan Integral"
        config = {"include_weather": True}

        result = await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=config
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.config["single_product"] == product_name
        assert f"Initializing training for {product_name}" in result.current_step

    @pytest.mark.asyncio
    async def test_get_job_status_existing(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Looking up a stored job returns a record with matching id and status."""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert result is not None
        assert result.job_id == training_job_in_db.job_id
        assert result.status == training_job_in_db.status

    @pytest.mark.asyncio
    async def test_get_job_status_nonexistent(
        self,
        training_service,
        test_db_session
    ):
        """Looking up an unknown job id returns ``None``."""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant"
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_list_training_jobs(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Listing a tenant's jobs returns a list whose first entry is the stored job."""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10
        )

        assert isinstance(result, list)
        assert len(result) >= 1
        assert result[0].job_id == training_job_in_db.job_id

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_filter(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Every job returned with ``status_filter="pending"`` is pending."""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
            status_filter="pending"
        )

        assert isinstance(result, list)
        for job in result:
            assert job.status == "pending"

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Cancelling a stored job returns ``True`` and persists the new status."""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert result is True

        # Re-read the job so the assertion checks what was stored, not the
        # return value of the cancel call.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )
        assert updated_job.status == "cancelled"

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(
        self,
        training_service,
        test_db_session
    ):
        """Cancelling an unknown job id returns ``False``."""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant"
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(
        self,
        training_service,
        test_db_session,
        mock_data_service
    ):
        """Validation returns a dict carrying the four expected report keys."""
        config = {"min_data_points": 30}

        result = await training_service.validate_training_data(
            db=test_db_session,
            tenant_id="test-tenant",
            config=config
        )

        assert isinstance(result, dict)
        assert "is_valid" in result
        assert "issues" in result
        assert "recommendations" in result
        assert "estimated_time_minutes" in result

    @pytest.mark.asyncio
    async def test_validate_training_data_no_data(
        self,
        training_service,
        test_db_session
    ):
        """Validation is rejected with a "no sales data" issue when nothing is fetched."""
        config = {"min_data_points": 30}

        # The replacement itself must be the AsyncMock so that awaiting the
        # call yields the empty list. Passing an AsyncMock as ``return_value``
        # would make the awaited result a mock object instead of ``[]``.
        with patch(
            'app.services.training_service.TrainingService._fetch_sales_data',
            new_callable=AsyncMock,
            return_value=[],
        ):
            result = await training_service.validate_training_data(
                db=test_db_session,
                tenant_id="test-tenant",
                config=config
            )

        assert result["is_valid"] is False
        assert "No sales data found" in result["issues"][0]

    @pytest.mark.asyncio
    async def test_update_job_status(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """``_update_job_status`` persists status, progress and current step."""
        await training_service._update_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            status="running",
            progress=50,
            current_step="Training models"
        )

        # Read the job back to verify the update was stored.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )
        assert updated_job.status == "running"
        assert updated_job.progress == 50
        assert updated_job.current_step == "Training models"

    @pytest.mark.asyncio
    async def test_store_trained_models(
        self,
        training_service,
        test_db_session
    ):
        """A successful training result is stored as an active ``TrainedModel`` row."""
        tenant_id = "test-tenant"
        training_results = {
            "training_results": {
                "Pan Integral": {
                    "status": "success",
                    "model_info": {
                        "model_id": "test-model-123",
                        "model_path": "/test/models/test-model-123.pkl",
                        "type": "prophet",
                        "training_samples": 100,
                        "features": ["temperature", "humidity"],
                        "hyperparameters": {"seasonality_mode": "additive"},
                        "training_metrics": {"mae": 5.2, "rmse": 7.8},
                        "data_period": {
                            "start_date": "2024-01-01T00:00:00",
                            "end_date": "2024-01-31T00:00:00"
                        }
                    }
                }
            }
        }

        await training_service._store_trained_models(
            db=test_db_session,
            tenant_id=tenant_id,
            training_results=training_results
        )

        # Query the table directly so the check does not depend on any
        # service-level read method.
        from sqlalchemy import select
        result = await test_db_session.execute(
            select(TrainedModel).where(
                TrainedModel.tenant_id == tenant_id,
                TrainedModel.product_name == "Pan Integral"
            )
        )
        stored_model = result.scalar_one_or_none()

        assert stored_model is not None
        assert stored_model.model_id == "test-model-123"
        assert stored_model.is_active is True

    @pytest.mark.asyncio
    async def test_get_training_logs(
        self,
        training_service,
        test_db_session,
        training_job_in_db
    ):
        """Logs come back as a non-empty list mentioning the job id or "Job started"."""
        result = await training_service.get_training_logs(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id
        )

        assert isinstance(result, list)
        assert len(result) > 0

        # Join the entries so the content check is independent of which
        # individual log line carries the text.
        log_text = " ".join(result)
        assert training_job_in_db.job_id in log_text or "Job started" in log_text
class TestTrainingServiceDataFetching:
    """Exercise the ``_fetch_*`` helpers against a patched ``httpx.AsyncClient``."""

    @pytest.fixture
    def training_service(self):
        """Return a fresh service instance."""
        return TrainingService()

    @staticmethod
    def _make_request(start_date=None, end_date=None):
        """Build a stand-in request exposing only ``start_date`` and ``end_date``."""
        request = Mock()
        request.start_date = start_date
        request.end_date = end_date
        return request

    @staticmethod
    def _stub_get(mock_client, status_code, payload=None):
        """Make the patched client's ``get`` return a canned response.

        Args:
            mock_client: The mock produced by ``patch('httpx.AsyncClient')``.
            status_code: Value exposed as ``response.status_code``.
            payload: Value returned by ``response.json()``; left unconfigured
                when ``None``.

        Returns:
            The ``get`` mock, so callers can inspect how it was called.
        """
        response = Mock()
        response.status_code = status_code
        if payload is not None:
            response.json.return_value = payload
        # ``async with httpx.AsyncClient() as client`` binds ``client`` to the
        # result of ``__aenter__``, so that is where ``get`` must be stubbed.
        mock_get = mock_client.return_value.__aenter__.return_value.get
        mock_get.return_value = response
        return mock_get

    @pytest.mark.asyncio
    async def test_fetch_sales_data_success(self, training_service):
        """A 200 response yields the list stored under the ``"sales"`` key."""
        mock_response_data = {
            "sales": [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, mock_response_data)
            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=self._make_request()
            )

        assert result == mock_response_data["sales"]

    @pytest.mark.asyncio
    async def test_fetch_sales_data_error(self, training_service):
        """A 500 response yields an empty list."""
        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 500)
            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=self._make_request()
            )

        assert result == []

    @pytest.mark.asyncio
    async def test_fetch_weather_data_success(self, training_service):
        """A 200 response yields the list stored under the ``"weather"`` key."""
        mock_response_data = {
            "weather": [
                {"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, mock_response_data)
            result = await training_service._fetch_weather_data(
                tenant_id="test-tenant",
                request=self._make_request()
            )

        assert result == mock_response_data["weather"]

    @pytest.mark.asyncio
    async def test_fetch_traffic_data_success(self, training_service):
        """A 200 response yields the list stored under the ``"traffic"`` key."""
        mock_response_data = {
            "traffic": [
                {"date": "2024-01-01", "traffic_volume": 120}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, mock_response_data)
            result = await training_service._fetch_traffic_data(
                tenant_id="test-tenant",
                request=self._make_request()
            )

        assert result == mock_response_data["traffic"]

    @pytest.mark.asyncio
    async def test_fetch_data_with_date_filters(self, training_service):
        """Request dates are forwarded as ISO-8601 strings in the query params."""
        mock_request = self._make_request(
            start_date=datetime(2024, 1, 1),
            end_date=datetime(2024, 1, 31),
        )

        with patch('httpx.AsyncClient') as mock_client:
            mock_get = self._stub_get(mock_client, 200, {"sales": []})
            await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

        # Inspect the keyword arguments of the recorded ``get`` call.
        params = mock_get.call_args.kwargs["params"]
        assert "start_date" in params
        assert "end_date" in params
        assert params["start_date"] == "2024-01-01T00:00:00"
        assert params["end_date"] == "2024-01-31T00:00:00"
class TestTrainingServiceExecution:
    """Exercise the end-to-end training execution workflow.

    The data-fetching and model-storing methods are patched out on
    ``TrainingService``, so only the orchestration and job-status
    bookkeeping are tested. ``mock_messaging`` and ``mock_data_service``
    are requested but never read in the test bodies.
    """

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        """Return a fresh service instance; ``mock_ml_trainer`` is not read here."""
        return TrainingService()

    @pytest.mark.asyncio
    async def test_execute_training_job_success(
        self,
        training_service,
        test_db_session,
        mock_messaging,
        mock_data_service
    ):
        """A job whose collaborators all succeed ends as ``completed`` at 100%."""
        # The job record must exist before it can be executed.
        job_id = "test-execution-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={"include_weather": True}
        )

        request = TrainingJobRequest(
            include_weather=True,
            include_traffic=True,
            min_data_points=30
        )

        # NOTE(review): these patches rely on patch() substituting an awaitable
        # AsyncMock, which it only does when the target is an ``async def``.
        # Confirm that holds for every patched method.
        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch_sales, \
                patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
                patch('app.services.training_service.TrainingService._fetch_traffic_data') as mock_fetch_traffic, \
                patch('app.services.training_service.TrainingService._store_trained_models') as mock_store:
            mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}]
            mock_fetch_weather.return_value = []
            mock_fetch_traffic.return_value = []
            mock_store.return_value = None

            await training_service.execute_training_job(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                request=request
            )

        # Read the job back to verify the stored final state.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )
        assert updated_job.status == "completed"
        assert updated_job.progress == 100

    @pytest.mark.asyncio
    async def test_execute_training_job_failure(
        self,
        training_service,
        test_db_session,
        mock_messaging
    ):
        """A failing sales fetch propagates and marks the job ``failed``."""
        # The job record must exist before it can be executed.
        job_id = "test-failure-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={}
        )

        request = TrainingJobRequest(min_data_points=30)

        with patch('app.services.training_service.TrainingService._fetch_sales_data') as mock_fetch:
            mock_fetch.side_effect = Exception("Data service unavailable")

            with pytest.raises(Exception):
                await training_service.execute_training_job(
                    db=test_db_session,
                    job_id=job_id,
                    tenant_id="test-tenant",
                    request=request
                )

        # The failure must be recorded on the job as well as raised.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )
        assert updated_job.status == "failed"
        assert "Data service unavailable" in updated_job.error_message

    @pytest.mark.asyncio
    async def test_execute_single_product_training_success(
        self,
        training_service,
        test_db_session,
        mock_messaging,
        mock_data_service
    ):
        """A single-product job whose collaborators succeed ends as ``completed`` at 100%."""
        job_id = "test-single-product-job"
        product_name = "Pan Integral"

        await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id="test-tenant",
            product_name=product_name,
            job_id=job_id,
            config={}
        )

        request = SingleProductTrainingRequest(
            include_weather=True,
            include_traffic=False
        )

        with patch('app.services.training_service.TrainingService._fetch_product_sales_data') as mock_fetch_sales, \
                patch('app.services.training_service.TrainingService._fetch_weather_data') as mock_fetch_weather, \
                patch('app.services.training_service.TrainingService._store_single_trained_model') as mock_store:
            mock_fetch_sales.return_value = [{"date": "2024-01-01", "product_name": product_name, "quantity": 45}]
            mock_fetch_weather.return_value = []
            mock_store.return_value = None

            await training_service.execute_single_product_training(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                product_name=product_name,
                request=request
            )

        # Read the job back to verify the stored final state.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant"
        )
        assert updated_job.status == "completed"
        assert updated_job.progress == 100
class TestTrainingServiceEdgeCases:
    """Exercise edge cases and error conditions of ``TrainingService``."""

    @pytest.fixture
    def training_service(self):
        """Return a fresh service instance."""
        return TrainingService()

    @pytest.mark.asyncio
    async def test_database_connection_failure(self, training_service):
        """A session whose database operations fail makes job creation raise."""
        # Use an AsyncMock session so the service can await its methods and
        # hit the injected error. A ``side_effect`` on a patched session
        # *class* fires only when the class is called, which never happens
        # when the mock is passed in directly as ``db``.
        failing_session = AsyncMock()
        # ``add`` is synchronous on a real AsyncSession, so keep it a plain
        # Mock to avoid creating a never-awaited coroutine.
        failing_session.add = Mock()
        db_error = Exception("Database connection failed")
        # NOTE(review): assumes job creation awaits at least one of these
        # operations; confirm against the service implementation.
        for operation in ("commit", "flush", "execute", "refresh"):
            getattr(failing_session, operation).side_effect = db_error

        with pytest.raises(Exception):
            await training_service.create_training_job(
                db=failing_session,
                tenant_id="test-tenant",
                job_id="test-job",
                config={}
            )

    @pytest.mark.asyncio
    async def test_external_service_timeout(self, training_service):
        """A timeout from the HTTP client yields an empty list instead of raising."""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_client.return_value.__aenter__.return_value.get.side_effect = httpx.TimeoutException("Request timeout")
            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request
            )

        # Should return empty list on timeout
        assert result == []

    @pytest.mark.asyncio
    async def test_concurrent_job_creation(self, training_service, test_db_session):
        """Several jobs can be created back to back on the same session."""
        # This is not a true concurrency test: the jobs are created
        # sequentially, which only checks that multiple jobs can coexist.
        job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]

        jobs = [
            await training_service.create_training_job(
                db=test_db_session,
                tenant_id="test-tenant",
                job_id=job_id,
                config={}
            )
            for job_id in job_ids
        ]

        assert len(jobs) == len(job_ids)
        for job, expected_id in zip(jobs, job_ids):
            assert job.job_id == expected_id

    @pytest.mark.asyncio
    async def test_malformed_config_handling(self, training_service, test_db_session):
        """Unrecognised config keys are stored as given rather than rejected."""
        malformed_config = {
            "invalid_key": "invalid_value",
            "nested": {"data": None}
        }

        # Should not raise exception, just store the config as-is
        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="malformed-config-job",
            config=malformed_config
        )

        assert job.config == malformed_config