688 lines
23 KiB
Python
688 lines
23 KiB
Python
# services/training/tests/test_service.py
|
|
"""
|
|
Tests for training service business logic layer
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
from datetime import datetime, timedelta
|
|
import httpx
|
|
|
|
from app.services.training_service import TrainingService
|
|
from app.schemas.training import TrainingJobRequest, SingleProductTrainingRequest
|
|
from app.models.training import ModelTrainingLog, TrainedModel
|
|
|
|
|
|
class TestTrainingService:
    """Tests for the job-management layer of ``TrainingService``.

    The ``test_db_session``, ``training_job_in_db``, ``mock_ml_trainer`` and
    ``mock_data_service`` fixtures are defined outside this file (presumably
    in ``conftest.py``); NOTE(review): their exact setup is not visible here.
    """

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        """Return a fresh service; ``mock_ml_trainer`` is requested only for its side effects."""
        return TrainingService()

    @pytest.mark.asyncio
    async def test_create_training_job_success(self, training_service, test_db_session):
        """A new job is stored as ``pending`` with zero progress and the given config."""
        job_id = "test-job-123"
        tenant_id = "test-tenant"
        config = {"include_weather": True, "include_traffic": True}

        result = await training_service.create_training_job(
            db=test_db_session,
            tenant_id=tenant_id,
            job_id=job_id,
            config=config,
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.status == "pending"
        assert result.progress == 0
        assert result.config == config

    @pytest.mark.asyncio
    async def test_create_single_product_job_success(self, training_service, test_db_session):
        """A single-product job records the product in its config and initial step."""
        job_id = "test-product-job-123"
        tenant_id = "test-tenant"
        product_name = "Pan Integral"
        config = {"include_weather": True}

        result = await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id=tenant_id,
            product_name=product_name,
            job_id=job_id,
            config=config,
        )

        assert isinstance(result, ModelTrainingLog)
        assert result.job_id == job_id
        assert result.tenant_id == tenant_id
        assert result.config["single_product"] == product_name
        assert f"Initializing training for {product_name}" in result.current_step

    @pytest.mark.asyncio
    async def test_get_job_status_existing(self, training_service, test_db_session, training_job_in_db):
        """Looking up a stored job returns that job with its current status."""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert result is not None
        assert result.job_id == training_job_in_db.job_id
        assert result.status == training_job_in_db.status

    @pytest.mark.asyncio
    async def test_get_job_status_nonexistent(self, training_service, test_db_session):
        """Looking up an unknown job id returns ``None``."""
        result = await training_service.get_job_status(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant",
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_list_training_jobs(self, training_service, test_db_session, training_job_in_db):
        """Listing a tenant's jobs includes the fixture job first."""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
        )

        assert isinstance(result, list)
        assert len(result) >= 1
        assert result[0].job_id == training_job_in_db.job_id

    @pytest.mark.asyncio
    async def test_list_training_jobs_with_filter(self, training_service, test_db_session, training_job_in_db):
        """A status filter only returns jobs with that status."""
        result = await training_service.list_training_jobs(
            db=test_db_session,
            tenant_id=training_job_in_db.tenant_id,
            limit=10,
            status_filter="pending",
        )

        assert isinstance(result, list)
        for job in result:
            assert job.status == "pending"
        # The loop above passes trivially on an empty list, so also require the
        # fixture job to be returned whenever it actually matches the filter.
        if training_job_in_db.status == "pending":
            assert training_job_in_db.job_id in {job.job_id for job in result}

    @pytest.mark.asyncio
    async def test_cancel_training_job_success(self, training_service, test_db_session, training_job_in_db):
        """Cancelling an existing job returns ``True`` and persists ``cancelled``."""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert result is True

        # Re-read the job to verify the status change was stored.
        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )
        assert updated_job is not None
        assert updated_job.status == "cancelled"

    @pytest.mark.asyncio
    async def test_cancel_nonexistent_job(self, training_service, test_db_session):
        """Cancelling an unknown job returns ``False``."""
        result = await training_service.cancel_training_job(
            db=test_db_session,
            job_id="nonexistent-job",
            tenant_id="test-tenant",
        )

        assert result is False

    @pytest.mark.asyncio
    async def test_validate_training_data_valid(self, training_service, test_db_session, mock_data_service):
        """Validation returns a report dict with the expected keys."""
        config = {"min_data_points": 30}

        result = await training_service.validate_training_data(
            db=test_db_session,
            tenant_id="test-tenant",
            config=config,
        )

        assert isinstance(result, dict)
        assert "is_valid" in result
        assert "issues" in result
        assert "recommendations" in result
        assert "estimated_time_minutes" in result

    @pytest.mark.asyncio
    async def test_validate_training_data_no_data(self, training_service, test_db_session):
        """Validation fails with a 'No sales data found' issue when there are no sales."""
        config = {"min_data_points": 30}

        # Replace the coroutine method with an AsyncMock so that awaiting it
        # yields ``[]``. (``return_value=AsyncMock(...)`` would make the awaited
        # call return the AsyncMock object itself instead of an empty list.)
        with patch.object(
            TrainingService,
            "_fetch_sales_data",
            new=AsyncMock(return_value=[]),
        ):
            result = await training_service.validate_training_data(
                db=test_db_session,
                tenant_id="test-tenant",
                config=config,
            )

        assert result["is_valid"] is False
        assert "No sales data found" in result["issues"][0]

    @pytest.mark.asyncio
    async def test_update_job_status(self, training_service, test_db_session, training_job_in_db):
        """``_update_job_status`` persists status, progress and current step."""
        await training_service._update_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            status="running",
            progress=50,
            current_step="Training models",
        )

        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert updated_job.status == "running"
        assert updated_job.progress == 50
        assert updated_job.current_step == "Training models"

    @pytest.mark.asyncio
    async def test_store_trained_models(self, training_service, test_db_session):
        """A successful per-product result is stored as an active ``TrainedModel``."""
        from sqlalchemy import select

        tenant_id = "test-tenant"
        training_results = {
            "training_results": {
                "Pan Integral": {
                    "status": "success",
                    "model_info": {
                        "model_id": "test-model-123",
                        "model_path": "/test/models/test-model-123.pkl",
                        "type": "prophet",
                        "training_samples": 100,
                        "features": ["temperature", "humidity"],
                        "hyperparameters": {"seasonality_mode": "additive"},
                        "training_metrics": {"mae": 5.2, "rmse": 7.8},
                        "data_period": {
                            "start_date": "2024-01-01T00:00:00",
                            "end_date": "2024-01-31T00:00:00",
                        },
                    },
                }
            }
        }

        await training_service._store_trained_models(
            db=test_db_session,
            tenant_id=tenant_id,
            training_results=training_results,
        )

        # Query the table directly to verify the model row was written.
        result = await test_db_session.execute(
            select(TrainedModel).where(
                TrainedModel.tenant_id == tenant_id,
                TrainedModel.product_name == "Pan Integral",
            )
        )
        stored_model = result.scalar_one_or_none()

        assert stored_model is not None
        assert stored_model.model_id == "test-model-123"
        assert stored_model.is_active is True

    @pytest.mark.asyncio
    async def test_get_training_logs(self, training_service, test_db_session, training_job_in_db):
        """Training logs are a non-empty list of strings mentioning the job."""
        result = await training_service.get_training_logs(
            db=test_db_session,
            job_id=training_job_in_db.job_id,
            tenant_id=training_job_in_db.tenant_id,
        )

        assert isinstance(result, list)
        assert len(result) > 0

        log_text = " ".join(result)
        assert training_job_in_db.job_id in log_text or "Job started" in log_text
|
|
|
|
|
|
class TestTrainingServiceDataFetching:
    """Tests for the HTTP helpers that fetch sales, weather and traffic data.

    ``httpx.AsyncClient`` is patched on the ``httpx`` module, so these tests
    assume the service resolves the client as ``httpx.AsyncClient`` inside an
    ``async with`` block. NOTE(review): confirm against the service module.
    """

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @staticmethod
    def _make_request(start_date=None, end_date=None):
        """Build a stand-in request object that exposes only a date range."""
        request = Mock()
        request.start_date = start_date
        request.end_date = end_date
        return request

    @staticmethod
    def _stub_get(mock_client, status_code, payload=None):
        """Make ``client.get`` on the patched client return a canned response.

        ``payload``, when given, is what ``response.json()`` returns. The
        ``get`` mock is returned so callers can inspect its call arguments.
        """
        response = Mock()
        response.status_code = status_code
        if payload is not None:
            response.json.return_value = payload
        mock_get = mock_client.return_value.__aenter__.return_value.get
        mock_get.return_value = response
        return mock_get

    @pytest.mark.asyncio
    async def test_fetch_sales_data_success(self, training_service):
        """A 200 response yields the ``sales`` list from the JSON body."""
        payload = {
            "sales": [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, payload)
            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=self._make_request(),
            )

        assert result == payload["sales"]

    @pytest.mark.asyncio
    async def test_fetch_sales_data_error(self, training_service):
        """A non-200 response yields an empty list instead of raising."""
        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 500)
            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=self._make_request(),
            )

        assert result == []

    @pytest.mark.asyncio
    async def test_fetch_weather_data_success(self, training_service):
        """A 200 response yields the ``weather`` list from the JSON body."""
        payload = {
            "weather": [
                {"date": "2024-01-01", "temperature": 15.2, "precipitation": 0.0}
            ]
        }

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, payload)
            result = await training_service._fetch_weather_data(
                tenant_id="test-tenant",
                request=self._make_request(),
            )

        assert result == payload["weather"]

    @pytest.mark.asyncio
    async def test_fetch_traffic_data_success(self, training_service):
        """A 200 response yields the ``traffic`` list from the JSON body."""
        payload = {"traffic": [{"date": "2024-01-01", "traffic_volume": 120}]}

        with patch('httpx.AsyncClient') as mock_client:
            self._stub_get(mock_client, 200, payload)
            result = await training_service._fetch_traffic_data(
                tenant_id="test-tenant",
                request=self._make_request(),
            )

        assert result == payload["traffic"]

    @pytest.mark.asyncio
    async def test_fetch_data_with_date_filters(self, training_service):
        """Request dates are forwarded as ISO-8601 ``params`` on the GET call."""
        request = self._make_request(datetime(2024, 1, 1), datetime(2024, 1, 31))

        with patch('httpx.AsyncClient') as mock_client:
            mock_get = self._stub_get(mock_client, 200, {"sales": []})
            await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=request,
            )

        # The dates must be passed as keyword ``params`` on the GET call.
        params = mock_get.call_args.kwargs["params"]
        assert "start_date" in params
        assert "end_date" in params
        assert params["start_date"] == "2024-01-01T00:00:00"
        assert params["end_date"] == "2024-01-31T00:00:00"
|
|
|
|
|
|
class TestTrainingServiceExecution:
    """Tests for the end-to-end training execution workflow.

    Data-fetching and model-persistence helpers are patched on the
    ``TrainingService`` class so only the orchestration logic is exercised.
    """

    @pytest.fixture
    def training_service(self, mock_ml_trainer):
        """Return a fresh service; ``mock_ml_trainer`` is requested only for its side effects."""
        return TrainingService()

    @pytest.mark.asyncio
    async def test_execute_training_job_success(
        self, training_service, test_db_session, mock_messaging, mock_data_service
    ):
        """A job whose data fetches succeed ends ``completed`` at 100% progress."""
        job_id = "test-execution-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={"include_weather": True},
        )

        request = TrainingJobRequest(
            include_weather=True,
            include_traffic=True,
            min_data_points=30,
        )

        with patch.object(TrainingService, "_fetch_sales_data") as mock_fetch_sales, \
                patch.object(TrainingService, "_fetch_weather_data") as mock_fetch_weather, \
                patch.object(TrainingService, "_fetch_traffic_data") as mock_fetch_traffic, \
                patch.object(TrainingService, "_store_trained_models") as mock_store:
            mock_fetch_sales.return_value = [
                {"date": "2024-01-01", "product_name": "Pan Integral", "quantity": 45}
            ]
            mock_fetch_weather.return_value = []
            mock_fetch_traffic.return_value = []
            mock_store.return_value = None

            await training_service.execute_training_job(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                request=request,
            )

        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100

    @pytest.mark.asyncio
    async def test_execute_training_job_failure(self, training_service, test_db_session, mock_messaging):
        """A failing data fetch re-raises and marks the job ``failed`` with the error text."""
        job_id = "test-failure-job"
        await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id=job_id,
            config={},
        )

        request = TrainingJobRequest(min_data_points=30)

        with patch.object(TrainingService, "_fetch_sales_data") as mock_fetch:
            mock_fetch.side_effect = Exception("Data service unavailable")

            # Deliberately broad: the service may re-raise or wrap the error;
            # the stored error_message below pins the actual cause.
            with pytest.raises(Exception):
                await training_service.execute_training_job(
                    db=test_db_session,
                    job_id=job_id,
                    tenant_id="test-tenant",
                    request=request,
                )

        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "failed"
        assert "Data service unavailable" in updated_job.error_message

    @pytest.mark.asyncio
    async def test_execute_single_product_training_success(
        self, training_service, test_db_session, mock_messaging, mock_data_service
    ):
        """A single-product job with successful fetches ends ``completed`` at 100%."""
        job_id = "test-single-product-job"
        product_name = "Pan Integral"

        await training_service.create_single_product_job(
            db=test_db_session,
            tenant_id="test-tenant",
            product_name=product_name,
            job_id=job_id,
            config={},
        )

        request = SingleProductTrainingRequest(
            include_weather=True,
            include_traffic=False,
        )

        with patch.object(TrainingService, "_fetch_product_sales_data") as mock_fetch_sales, \
                patch.object(TrainingService, "_fetch_weather_data") as mock_fetch_weather, \
                patch.object(TrainingService, "_store_single_trained_model") as mock_store:
            mock_fetch_sales.return_value = [
                {"date": "2024-01-01", "product_name": product_name, "quantity": 45}
            ]
            mock_fetch_weather.return_value = []
            mock_store.return_value = None

            await training_service.execute_single_product_training(
                db=test_db_session,
                job_id=job_id,
                tenant_id="test-tenant",
                product_name=product_name,
                request=request,
            )

        updated_job = await training_service.get_job_status(
            db=test_db_session,
            job_id=job_id,
            tenant_id="test-tenant",
        )

        assert updated_job.status == "completed"
        assert updated_job.progress == 100
|
|
|
|
|
|
class TestTrainingServiceEdgeCases:
    """Tests for edge cases and error conditions in ``TrainingService``."""

    @pytest.fixture
    def training_service(self):
        return TrainingService()

    @pytest.mark.asyncio
    async def test_database_connection_failure(self, training_service):
        """A session whose awaitable operations fail makes job creation raise."""
        from sqlalchemy.ext.asyncio import AsyncSession

        db_methods = ("execute", "flush", "commit", "refresh")
        db_error = Exception("Database connection failed")

        # spec=AsyncSession makes the session's coroutine methods AsyncMocks
        # (awaitable) and its sync methods such as ``add`` plain mocks. The
        # failure is configured on those methods: a side_effect on the
        # AsyncSession *class* would only fire if the class itself were called.
        mock_db = AsyncMock(spec=AsyncSession)
        for method_name in db_methods:
            getattr(mock_db, method_name).side_effect = db_error

        with pytest.raises(Exception):
            await training_service.create_training_job(
                db=mock_db,
                tenant_id="test-tenant",
                job_id="test-job",
                config={},
            )

        # Ensure the exception came from the simulated database rather than an
        # unrelated error raised before the session was touched.
        assert any(getattr(mock_db, name).await_count for name in db_methods)

    @pytest.mark.asyncio
    async def test_external_service_timeout(self, training_service):
        """An HTTP timeout while fetching sales data yields an empty list."""
        mock_request = Mock()
        mock_request.start_date = None
        mock_request.end_date = None

        with patch('httpx.AsyncClient') as mock_client:
            mock_client.return_value.__aenter__.return_value.get.side_effect = httpx.TimeoutException("Request timeout")

            result = await training_service._fetch_sales_data(
                tenant_id="test-tenant",
                request=mock_request,
            )

        assert result == []

    @pytest.mark.asyncio
    async def test_concurrent_job_creation(self, training_service, test_db_session):
        """Several jobs can be created back-to-back for the same tenant.

        Jobs are created sequentially on purpose. ``test_db_session`` is
        presumably a single SQLAlchemy ``AsyncSession``, which must not be
        shared between concurrently running tasks. True concurrency would
        need one session per task.
        """
        job_ids = ["concurrent-job-1", "concurrent-job-2", "concurrent-job-3"]

        jobs = []
        for job_id in job_ids:
            job = await training_service.create_training_job(
                db=test_db_session,
                tenant_id="test-tenant",
                job_id=job_id,
                config={},
            )
            jobs.append(job)

        # One comparison checks both the count and the per-job ids, in order.
        assert [job.job_id for job in jobs] == job_ids

    @pytest.mark.asyncio
    async def test_malformed_config_handling(self, training_service, test_db_session):
        """Unknown or oddly-shaped config keys are stored verbatim without error."""
        malformed_config = {
            "invalid_key": "invalid_value",
            "nested": {"data": None},
        }

        job = await training_service.create_training_job(
            db=test_db_session,
            tenant_id="test-tenant",
            job_id="malformed-config-job",
            config=malformed_config,
        )

        assert job.config == malformed_config