"""
Unit tests for Token Service.

Tests token generation, validation, revocation, and cleanup.
"""
import hashlib
import time
from datetime import datetime, timedelta

import pytest
from sqlalchemy import text

from gondulf.database.connection import Database
from gondulf.services.token_service import TokenService


def _hash_token(token):
    """Return the SHA-256 hex digest of *token* (UTF-8 encoded).

    This is the value the tests expect to find in ``tokens.token_hash``.
    """
    return hashlib.sha256(token.encode('utf-8')).hexdigest()


def _fetch_token_row(database, token, columns):
    """Fetch ``columns`` from the ``tokens`` row whose hash matches *token*.

    Args:
        database: The ``Database`` under test.
        token: The plaintext token returned by ``generate_token``.
        columns: A fixed, test-authored column list such as ``"me, scope"``.
            It is interpolated into the SQL, so it must never come from
            untrusted input.

    Returns:
        The fetched row, or ``None`` when no row matches.
    """
    engine = database.get_engine()
    with engine.connect() as conn:
        return conn.execute(
            text(f"SELECT {columns} FROM tokens WHERE token_hash = :hash"),
            {"hash": _hash_token(token)},
        ).fetchone()


def _lifetime(row):
    """Return ``expires_at - issued_at`` for an ``(issued_at, expires_at)`` row.

    Both columns are parsed with ``datetime.fromisoformat``, so they are
    expected to come back as ISO-8601 strings.
    """
    issued_at = datetime.fromisoformat(row[0])
    expires_at = datetime.fromisoformat(row[1])
    return expires_at - issued_at


@pytest.fixture
def test_database(tmp_path):
    """Create test database with migrations."""
    db_path = tmp_path / "test.db"
    db = Database(f"sqlite:///{db_path}")
    db.ensure_database_directory()
    db.run_migrations()
    return db


@pytest.fixture
def token_service(test_database):
    """Create token service with test database."""
    return TokenService(
        database=test_database,
        token_length=32,
        token_ttl=3600
    )


class TestTokenGeneration:
    """Tests for token generation."""

    def test_generate_token_returns_string(self, token_service):
        """Test that generate_token returns a string token."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        assert isinstance(token, str)
        assert len(token) == 43  # 32 bytes -> 43 unpadded base64url chars

    def test_generate_token_stores_hash(self, token_service, test_database):
        """Test that token is stored as SHA-256 hash, not plaintext."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        result = _fetch_token_row(test_database, token, "token_hash")

        assert result is not None
        assert result[0] == _hash_token(token)

    def test_generate_token_stores_metadata(self, token_service, test_database):
        """Test that token metadata is stored correctly."""
        me = "https://example.com"
        client_id = "https://client.example.com"
        scope = "profile"

        token = token_service.generate_token(me=me, client_id=client_id, scope=scope)

        result = _fetch_token_row(test_database, token, "me, client_id, scope")

        assert result is not None
        assert result[0] == me
        assert result[1] == client_id
        assert result[2] == scope

    def test_generate_token_sets_expiration(self, token_service, test_database):
        """Test that token expiration is calculated correctly."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        result = _fetch_token_row(test_database, token, "issued_at, expires_at")
        # Fail with a clear message rather than a TypeError on result[0].
        assert result is not None

        # Should be ~3600 seconds apart; allow 10 seconds of variance.
        assert abs(_lifetime(result) - timedelta(seconds=3600)) < timedelta(seconds=10)

    def test_generate_token_is_random(self, token_service):
        """Test that repeated generation yields distinct tokens."""
        tokens = {
            token_service.generate_token(
                me="https://example.com",
                client_id="https://client.example.com",
                scope=""
            )
            for _ in range(100)
        }

        # All 100 tokens should be unique
        assert len(tokens) == 100


class TestTokenValidation:
    """Tests for token validation."""

    def test_validate_token_success(self, token_service):
        """Test validating a valid token returns metadata."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope="profile"
        )

        metadata = token_service.validate_token(token)

        assert metadata is not None
        assert metadata['me'] == "https://example.com"
        assert metadata['client_id'] == "https://client.example.com"
        assert metadata['scope'] == "profile"

    def test_validate_token_not_found(self, token_service):
        """Test validating non-existent token returns None."""
        fake_token = "invalid_token_12345678901234567890123456"

        metadata = token_service.validate_token(fake_token)

        assert metadata is None

    def test_validate_token_expired(self, token_service, test_database):
        """Test validating expired token returns None."""
        # Generate token with short TTL
        short_ttl_service = TokenService(
            database=test_database,
            token_length=32,
            token_ttl=1  # 1 second
        )

        token = short_ttl_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        # Wait for expiration.
        # NOTE(review): a real sleep keeps this test slow; consider injecting
        # a clock into TokenService if it supports one.
        time.sleep(1.1)

        metadata = short_ttl_service.validate_token(token)

        assert metadata is None

    def test_validate_token_revoked(self, token_service):
        """Test validating revoked token returns None."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        # Revoke the token
        token_service.revoke_token(token)

        # Validation should fail
        metadata = token_service.validate_token(token)
        assert metadata is None


class TestTokenRevocation:
    """Tests for token revocation."""

    def test_revoke_token_success(self, token_service):
        """Test revoking a valid token returns True."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        result = token_service.revoke_token(token)

        assert result is True

    def test_revoke_token_not_found(self, token_service):
        """Test revoking non-existent token returns False."""
        fake_token = "invalid_token_12345678901234567890123456"

        result = token_service.revoke_token(fake_token)

        assert result is False

    def test_revoked_token_fails_validation(self, token_service):
        """Test that revoked tokens cannot be validated."""
        token = token_service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        # Revoke and try to validate
        token_service.revoke_token(token)
        metadata = token_service.validate_token(token)

        assert metadata is None


class TestTokenCleanup:
    """Tests for expired token cleanup."""

    def test_cleanup_expired_tokens(self, test_database):
        """Test cleanup deletes expired tokens."""
        # Create service with short TTL
        short_ttl_service = TokenService(
            database=test_database,
            token_length=32,
            token_ttl=1  # 1 second
        )

        # Generate multiple tokens
        for i in range(3):
            short_ttl_service.generate_token(
                me=f"https://example{i}.com",
                client_id="https://client.example.com",
                scope=""
            )

        # Wait for expiration (see NOTE in test_validate_token_expired).
        time.sleep(1.1)

        # Run cleanup
        deleted_count = short_ttl_service.cleanup_expired_tokens()

        assert deleted_count == 3

    def test_cleanup_preserves_valid_tokens(self, test_database):
        """Test cleanup doesn't delete valid tokens."""
        service = TokenService(
            database=test_database,
            token_length=32,
            token_ttl=3600  # 1 hour
        )

        # Generate token
        token = service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        # Run cleanup
        deleted_count = service.cleanup_expired_tokens()

        # No tokens should be deleted
        assert deleted_count == 0

        # Token should still be valid
        metadata = service.validate_token(token)
        assert metadata is not None

    def test_cleanup_empty_database(self, token_service):
        """Test cleanup handles empty database gracefully."""
        deleted_count = token_service.cleanup_expired_tokens()

        assert deleted_count == 0


class TestTokenServiceConfiguration:
    """Tests for token service configuration."""

    def test_custom_token_length(self, test_database):
        """Test custom token length is respected."""
        service = TokenService(
            database=test_database,
            token_length=16,  # Smaller token
            token_ttl=3600
        )

        token = service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        # 16 bytes -> 22 unpadded base64url chars
        assert len(token) == 22

    def test_custom_ttl(self, test_database):
        """Test custom TTL is respected."""
        service = TokenService(
            database=test_database,
            token_length=32,
            token_ttl=7200  # 2 hours
        )

        token = service.generate_token(
            me="https://example.com",
            client_id="https://client.example.com",
            scope=""
        )

        result = _fetch_token_row(test_database, token, "issued_at, expires_at")
        # Fail with a clear message rather than a TypeError on result[0].
        assert result is not None

        # Should be ~7200 seconds apart; allow 10 seconds of variance.
        assert abs(_lifetime(result) - timedelta(seconds=7200)) < timedelta(seconds=10)