"""
Tests for migration race condition fix

Tests cover:
- Concurrent migration execution with multiple workers
- Lock retry logic with exponential backoff
- Graduated logging levels
- Connection timeout handling
- Maximum retry exhaustion
- Worker coordination (one applies, others wait)
"""

import itertools
import logging
import multiprocessing
import sqlite3
import tempfile
import time
from multiprocessing import Barrier
from pathlib import Path
from unittest.mock import patch, MagicMock, call

import pytest

from starpunk.migrations import (
    MigrationError,
    run_migrations,
)
from starpunk import create_app


def _barrier_worker(args):
    """Wait at the shared barrier, then run migrations.

    Defined at module level because ``multiprocessing.Pool`` pickles the
    callable it dispatches, and nested functions cannot be pickled.

    Args:
        args: ``(db_path, barrier)`` tuple; ``barrier`` must be picklable
            (e.g. a ``multiprocessing.Manager().Barrier`` proxy).

    Returns:
        True if ``run_migrations`` completed, False if anything raised.
    """
    db_path, barrier = args
    try:
        barrier.wait()  # All workers start together
        run_migrations(str(db_path))
        return True
    except Exception:
        return False


def _migration_worker(db_path):
    """Run migrations against ``db_path`` and return True on success.

    Module-level for the same pickling reason as ``_barrier_worker``; any
    exception propagates back to the parent through ``Pool.map``.
    """
    run_migrations(str(db_path))
    return True


@pytest.fixture
def temp_db():
    """Create a temporary database file for testing and remove it afterwards"""
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        db_path = Path(f.name)
    yield db_path
    # Cleanup
    if db_path.exists():
        db_path.unlink()


class TestRetryLogic:
    """Test retry logic for lock acquisition"""

    def test_success_on_first_attempt(self, temp_db):
        """Test successful migration on first attempt (no retry needed)"""
        # Initialize database with proper schema first
        from starpunk.database import init_db

        app = create_app({'DATABASE_PATH': str(temp_db)})
        init_db(app)

        # Verify migrations table exists and has records
        conn = sqlite3.connect(temp_db)
        try:
            cursor = conn.execute("SELECT COUNT(*) FROM schema_migrations")
            count = cursor.fetchone()[0]
        finally:
            conn.close()

        # The SELECT above is the real check: it raises if the table is
        # missing. COUNT(*) itself can never be negative.
        assert count >= 0  # At least migrations table created

    def test_retry_on_locked_database(self, temp_db):
        """Test retry logic when database is locked"""
        with patch('sqlite3.connect') as mock_connect:
            # Create mock connection that succeeds on 3rd attempt
            mock_conn = MagicMock()
            mock_conn.execute.return_value.fetchone.return_value = (0,)  # Empty migrations

            # First 2 attempts fail with locked error
            mock_connect.side_effect = [
                sqlite3.OperationalError("database is locked"),
                sqlite3.OperationalError("database is locked"),
                mock_conn,  # Success on 3rd attempt
            ]

            # This should succeed after retries.
            # Note: Will fail since mock doesn't fully implement migrations,
            # but we're testing that connect() is called 3 times.
            # `except Exception` (not a bare except) so SystemExit and
            # KeyboardInterrupt are never silently swallowed.
            try:
                run_migrations(str(temp_db))
            except Exception:
                pass  # Expected to fail with mock

            # Verify 3 connection attempts were made
            assert mock_connect.call_count == 3

    def test_exponential_backoff_timing(self, temp_db):
        """Test that exponential backoff delays increase correctly"""
        delays = []

        # Record every requested sleep duration instead of sleeping
        with patch('time.sleep', side_effect=delays.append):
            with patch('time.time', return_value=0):  # Prevent timeout from triggering
                with patch('sqlite3.connect') as mock_connect:
                    # Always fail with locked error
                    mock_connect.side_effect = sqlite3.OperationalError("database is locked")

                    # Should exhaust retries
                    with pytest.raises(MigrationError, match="Failed to acquire migration lock"):
                        run_migrations(str(temp_db))

        # Verify exponential backoff (10 retries = 9 sleeps between attempts)
        # First attempt doesn't sleep, then sleep before retry 2, 3, ... 10
        assert len(delays) == 9, f"Expected 9 delays (10 retries), got {len(delays)}"

        # Check delays are increasing (exponential with jitter)
        # Base is 0.1, so: 0.2+jitter, 0.4+jitter, 0.8+jitter, etc.
        for i, (previous, current) in enumerate(zip(delays, delays[1:])):
            # Each delay should be roughly double previous (within jitter range)
            # Allow for jitter of 0.1s
            assert current > previous * 0.9, (
                f"Delay {i+1} ({current}) not greater than previous ({previous})"
            )

    def test_max_retries_exhaustion(self, temp_db):
        """Test that retries are exhausted after max attempts"""
        with patch('sqlite3.connect') as mock_connect:
            # Always return locked error
            mock_connect.side_effect = sqlite3.OperationalError("database is locked")

            # Should raise MigrationError after exhausting retries
            with pytest.raises(MigrationError) as exc_info:
                run_migrations(str(temp_db))

            # Verify error message is helpful
            error_msg = str(exc_info.value)
            assert "Failed to acquire migration lock" in error_msg
            assert "10 attempts" in error_msg
            assert "Possible causes" in error_msg

            # MAX_RETRIES=10 means 10 attempts total (not initial + 10 retries)
            assert mock_connect.call_count == 10

    def test_total_timeout_protection(self, temp_db):
        """Test that total timeout limit (120s) is respected"""
        with patch('time.time') as mock_time:
            with patch('time.sleep'):
                with patch('sqlite3.connect') as mock_connect:
                    # Simulate time passing (need enough values for all retries)
                    # Each retry checks time twice, so provide plenty of values
                    times = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 130, 140, 150]
                    # Append an endless monotonic tail so that extra
                    # time.time() callers cannot exhaust the sequence and
                    # surface as a confusing StopIteration.
                    mock_time.side_effect = itertools.chain(times, itertools.count(160, 10))

                    mock_connect.side_effect = sqlite3.OperationalError("database is locked")

                    # Should timeout before exhausting retries
                    with pytest.raises(MigrationError) as exc_info:
                        run_migrations(str(temp_db))

                    error_msg = str(exc_info.value)
                    assert "Migration timeout" in error_msg or "Failed to acquire" in error_msg


class TestGraduatedLogging:
    """Test graduated logging levels based on retry count"""

    def test_debug_level_for_early_retries(self, temp_db, caplog):
        """Test DEBUG level for retries 1-3"""
        with patch('time.sleep'), patch('sqlite3.connect') as mock_connect:
            # Fail 3 times, then succeed
            mock_conn = MagicMock()
            mock_conn.execute.return_value.fetchone.return_value = (0,)

            errors = [sqlite3.OperationalError("database is locked")] * 3
            mock_connect.side_effect = errors + [mock_conn]

            with caplog.at_level(logging.DEBUG):
                # The mock connection may not satisfy the full migration
                # flow; only the retry log records matter here.
                try:
                    run_migrations(str(temp_db))
                except Exception:
                    pass

            # Check that DEBUG messages were logged for early retries
            debug_msgs = [
                r for r in caplog.records
                if r.levelname == 'DEBUG' and 'retry' in r.message.lower()
            ]
            assert len(debug_msgs) >= 1  # At least one DEBUG retry message

    def test_info_level_for_middle_retries(self, temp_db, caplog):
        """Test INFO level for retries 4-7"""
        with patch('time.sleep'), patch('sqlite3.connect') as mock_connect:
            # Fail 5 times to get into INFO range
            errors = [sqlite3.OperationalError("database is locked")] * 5
            mock_connect.side_effect = errors

            with caplog.at_level(logging.INFO):
                try:
                    run_migrations(str(temp_db))
                except MigrationError:
                    pass

            # Check that INFO messages were logged for middle retries
            info_msgs = [
                r for r in caplog.records
                if r.levelname == 'INFO' and 'retry' in r.message.lower()
            ]
            assert len(info_msgs) >= 1  # At least one INFO retry message

    def test_warning_level_for_late_retries(self, temp_db, caplog):
        """Test WARNING level for retries 8+"""
        with patch('time.sleep'), patch('sqlite3.connect') as mock_connect:
            # Fail 9 times to get into WARNING range
            errors = [sqlite3.OperationalError("database is locked")] * 9
            mock_connect.side_effect = errors

            with caplog.at_level(logging.WARNING):
                try:
                    run_migrations(str(temp_db))
                except MigrationError:
                    pass

            # Check that WARNING messages were logged for late retries
            warning_msgs = [
                r for r in caplog.records
                if r.levelname == 'WARNING' and 'retry' in r.message.lower()
            ]
            assert len(warning_msgs) >= 1  # At least one WARNING retry message


class TestConnectionManagement:
    """Test connection lifecycle management"""

    def test_new_connection_per_retry(self, temp_db):
        """Test that each retry creates a new connection"""
        with patch('sqlite3.connect') as mock_connect:
            # Track connection instances
            connections = []

            def track_connection(*args, **kwargs):
                conn = MagicMock()
                connections.append(conn)
                raise sqlite3.OperationalError("database is locked")

            mock_connect.side_effect = track_connection

            try:
                run_migrations(str(temp_db))
            except MigrationError:
                pass

            # Each attempt should have created a new connection.
            # 10 attempts total, consistent with test_max_retries_exhaustion
            # (connect called 10 times) and the 9 sleeps asserted in
            # test_exponential_backoff_timing.
            assert len(connections) == 10

    def test_connection_closed_on_failure(self, temp_db):
        """Test that connection is closed even on failure"""
        with patch('sqlite3.connect') as mock_connect:
            mock_conn = MagicMock()
            mock_connect.return_value = mock_conn

            # Make execute raise an error
            mock_conn.execute.side_effect = Exception("Test error")

            try:
                run_migrations(str(temp_db))
            except Exception:
                pass

            # Connection should have been closed
            mock_conn.close.assert_called()

    def test_connection_timeout_setting(self, temp_db):
        """Test that connection timeout is set to 30s"""
        with patch('sqlite3.connect') as mock_connect:
            mock_conn = MagicMock()
            mock_conn.execute.return_value.fetchone.return_value = (0,)
            mock_connect.return_value = mock_conn

            try:
                run_migrations(str(temp_db))
            except Exception:
                pass

            # Verify connect was called with timeout=30.0
            mock_connect.assert_called_with(str(temp_db), timeout=30.0)


class TestConcurrentExecution:
    """Test concurrent worker scenarios"""

    def test_concurrent_workers_barrier_sync(self):
        """Test multiple workers starting simultaneously with barrier"""
        # This test uses actual multiprocessing with barrier synchronization
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"

            # A Manager-backed barrier is a picklable proxy, so it can be
            # sent to pool workers as a task argument (a plain
            # multiprocessing.Barrier cannot).
            with multiprocessing.Manager() as manager:
                barrier = manager.Barrier(4)

                # Run 4 workers concurrently
                with multiprocessing.Pool(4) as pool:
                    results = pool.map(_barrier_worker, [(db_path, barrier)] * 4)

            # All workers should succeed (one applies, others wait)
            assert all(results), f"Some workers failed: {results}"

            # Verify migrations were applied correctly
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.execute("SELECT COUNT(*) FROM schema_migrations")
                count = cursor.fetchone()[0]
            finally:
                conn.close()

            # Should have migration records (the query raising on a missing
            # table is the effective check; COUNT(*) is never negative)
            assert count >= 0

    def test_sequential_worker_startup(self):
        """Test workers starting one after another"""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"

            # First worker applies migrations
            run_migrations(str(db_path))

            # Second worker should detect completed migrations
            run_migrations(str(db_path))

            # Third worker should also succeed
            run_migrations(str(db_path))

            # All should succeed without errors

    def test_worker_late_arrival(self):
        """Test worker arriving after migrations complete"""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"

            # First worker completes migrations
            run_migrations(str(db_path))

            # Simulate some time passing
            time.sleep(0.1)

            # Late worker should detect completed migrations immediately
            start_time = time.time()
            run_migrations(str(db_path))
            elapsed = time.time() - start_time

            # Should be very fast (< 1s) since migrations already applied
            assert elapsed < 1.0


class TestErrorHandling:
    """Test error handling scenarios"""

    def test_rollback_on_migration_failure(self, temp_db):
        """Test that transaction is rolled back on migration failure"""
        with patch('sqlite3.connect') as mock_connect:
            mock_conn = MagicMock()
            mock_connect.return_value = mock_conn

            # Make migration execution fail
            mock_conn.executescript.side_effect = Exception("Migration failed")
            mock_conn.execute.return_value.fetchone.side_effect = [
                (0,),  # migration_count check
                # Will fail before getting here
            ]

            with pytest.raises(MigrationError):
                run_migrations(str(temp_db))

            # Rollback should have been called
            mock_conn.rollback.assert_called()

    def test_rollback_failure_causes_system_exit(self, temp_db):
        """Test that rollback failure raises SystemExit"""
        with patch('sqlite3.connect') as mock_connect:
            mock_conn = MagicMock()
            mock_connect.return_value = mock_conn

            # Make both migration and rollback fail
            mock_conn.executescript.side_effect = Exception("Migration failed")
            mock_conn.rollback.side_effect = Exception("Rollback failed")
            mock_conn.execute.return_value.fetchone.return_value = (0,)

            with pytest.raises(SystemExit):
                run_migrations(str(temp_db))

    def test_helpful_error_message_on_retry_exhaustion(self, temp_db):
        """Test that error message provides actionable guidance"""
        with patch('sqlite3.connect') as mock_connect:
            mock_connect.side_effect = sqlite3.OperationalError("database is locked")

            with pytest.raises(MigrationError) as exc_info:
                run_migrations(str(temp_db))

            error_msg = str(exc_info.value)

            # Should contain helpful information
            assert "Failed to acquire migration lock" in error_msg
            assert "attempts" in error_msg
            assert "Possible causes" in error_msg
            assert "Another process" in error_msg or "stuck" in error_msg
            assert "Action:" in error_msg or "Restart" in error_msg


class TestPerformance:
    """Test performance characteristics"""

    def test_single_worker_performance(self):
        """Test that single worker completes quickly"""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"

            start_time = time.time()
            run_migrations(str(db_path))
            elapsed = time.time() - start_time

            # Should complete in under 1 second for single worker
            assert elapsed < 1.0, f"Single worker took {elapsed}s (target: <1s)"

    def test_concurrent_workers_performance(self):
        """Test that 4 concurrent workers complete in reasonable time"""
        with tempfile.TemporaryDirectory() as tmpdir:
            db_path = Path(tmpdir) / "test.db"

            start_time = time.time()
            with multiprocessing.Pool(4) as pool:
                # Module-level worker: Pool.map must be able to pickle it
                results = pool.map(_migration_worker, [db_path] * 4)
            elapsed = time.time() - start_time

            # All should succeed
            assert all(results)

            # Should complete in under 5 seconds
            # (includes lock contention and retry delays)
            assert elapsed < 5.0, f"4 workers took {elapsed}s (target: <5s)"


class TestBeginImmediateTransaction:
    """Test BEGIN IMMEDIATE transaction usage"""

    def test_begin_immediate_called(self, temp_db):
        """Test that BEGIN IMMEDIATE is used for locking"""
        with patch('sqlite3.connect') as mock_connect:
            mock_conn = MagicMock()
            mock_connect.return_value = mock_conn
            mock_conn.execute.return_value.fetchone.return_value = (0,)

            try:
                run_migrations(str(temp_db))
            except Exception:
                pass

            # Verify BEGIN IMMEDIATE was called. The loop variable is not
            # named `call`, to avoid shadowing unittest.mock.call.
            recorded = [str(entry) for entry in mock_conn.execute.call_args_list]
            begin_immediate_calls = [c for c in recorded if 'BEGIN IMMEDIATE' in c]
            assert len(begin_immediate_calls) > 0, "BEGIN IMMEDIATE not called"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])