""" Tests for authentication module (starpunk/auth.py) """ import hashlib import secrets from datetime import datetime, timedelta from unittest.mock import MagicMock, patch import httpx import pytest from flask import g from starpunk.auth import ( AuthError, IndieLoginError, InvalidStateError, UnauthorizedError, _cleanup_expired_sessions, _generate_state_token, _hash_token, _verify_state_token, create_session, destroy_session, handle_callback, initiate_login, require_auth, verify_session, ) @pytest.fixture def app(tmp_path): """Create Flask app for testing""" from starpunk import create_app # Create test-specific data directory test_data_dir = tmp_path / "data" test_data_dir.mkdir(parents=True, exist_ok=True) app = create_app( { "TESTING": True, "SITE_URL": "http://localhost:5000", "ADMIN_ME": "https://example.com", "SESSION_SECRET": secrets.token_hex(32), "SESSION_LIFETIME": 30, "INDIELOGIN_URL": "https://indielogin.com", "DATA_PATH": test_data_dir, "NOTES_PATH": test_data_dir / "notes", "DATABASE_PATH": test_data_dir / "starpunk.db", } ) return app @pytest.fixture def db(app): """Get database connection""" from starpunk.database import get_db with app.app_context(): yield get_db(app) @pytest.fixture def client(app): """Get Flask test client""" return app.test_client() # Test helper functions class TestHelpers: def test_hash_token(self): """Test token hashing""" token = "test-token-123" expected = hashlib.sha256(token.encode()).hexdigest() assert _hash_token(token) == expected def test_hash_token_consistent(self): """Test that hashing is consistent""" token = "test-token" hash1 = _hash_token(token) hash2 = _hash_token(token) assert hash1 == hash2 def test_hash_token_different_inputs(self): """Test that different tokens produce different hashes""" token1 = "token1" token2 = "token2" assert _hash_token(token1) != _hash_token(token2) def test_generate_state_token(self): """Test state token generation""" token = _generate_state_token() assert isinstance(token, str) assert len(token) > 0 def test_generate_state_token_unique(self): """Test that generated tokens are unique""" tokens = [_generate_state_token() for _ in range(10)] assert len(set(tokens)) == 10 class TestStateTokenVerification: def test_verify_valid_state_token(self, app, db): """Test verifying a valid state token""" with app.app_context(): state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() assert _verify_state_token(state) is True # Token should be deleted after verification result = db.execute( "SELECT 1 FROM auth_state WHERE state = ?", (state,) ).fetchone() assert result is None def test_verify_invalid_state_token(self, app): """Test verifying an invalid state token""" with app.app_context(): assert _verify_state_token("invalid-token") is False def test_verify_expired_state_token(self, app, db): """Test verifying an expired state token""" with app.app_context(): state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() - timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() assert _verify_state_token(state) is False class TestCleanup: def test_cleanup_expired_sessions(self, app, db): """Test cleanup of expired sessions""" with app.app_context(): # Create expired session token_hash = _hash_token("expired-token") expires_at = datetime.utcnow() - timedelta(days=1) db.execute( """ INSERT INTO sessions (session_token_hash, me, expires_at) VALUES (?, ?, ?) """, (token_hash, "https://example.com", expires_at), ) db.commit() _cleanup_expired_sessions() # Expired session should be deleted result = db.execute( "SELECT 1 FROM sessions WHERE session_token_hash = ?", (token_hash,) ).fetchone() assert result is None def test_cleanup_expired_auth_state(self, app, db): """Test cleanup of expired auth state""" with app.app_context(): state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() - timedelta(minutes=10) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() _cleanup_expired_sessions() # Expired state should be deleted result = db.execute( "SELECT 1 FROM auth_state WHERE state = ?", (state,) ).fetchone() assert result is None def test_cleanup_keeps_valid_sessions(self, app, db): """Test that cleanup keeps valid sessions""" with app.app_context(): token_hash = _hash_token("valid-token") expires_at = datetime.utcnow() + timedelta(days=30) db.execute( """ INSERT INTO sessions (session_token_hash, me, expires_at) VALUES (?, ?, ?) """, (token_hash, "https://example.com", expires_at), ) db.commit() _cleanup_expired_sessions() # Valid session should still exist result = db.execute( "SELECT 1 FROM sessions WHERE session_token_hash = ?", (token_hash,) ).fetchone() assert result is not None class TestInitiateLogin: def test_initiate_login_success(self, app, db): """Test successful login initiation""" with app.app_context(): me_url = "https://example.com" auth_url = initiate_login(me_url) assert "indielogin.com/auth" in auth_url assert "me=https%3A%2F%2Fexample.com" in auth_url assert "client_id=" in auth_url assert "redirect_uri=" in auth_url assert "state=" in auth_url assert "response_type=code" in auth_url # State should be stored in database result = db.execute("SELECT COUNT(*) as count FROM auth_state").fetchone() assert result["count"] > 0 def test_initiate_login_invalid_url(self, app): """Test login initiation with invalid URL""" with app.app_context(): with pytest.raises(ValueError, match="Invalid URL format"): initiate_login("not-a-url") def test_initiate_login_stores_state(self, app, db): """Test that state token is stored""" with app.app_context(): me_url = "https://example.com" auth_url = initiate_login(me_url) # Extract state from URL state_param = [p for p in auth_url.split("&") if p.startswith("state=")][0] state = state_param.split("=")[1] # State should exist in database result = db.execute( "SELECT expires_at FROM auth_state WHERE state = ?", (state,) ).fetchone() assert result is not None class TestHandleCallback: @patch("starpunk.auth.httpx.post") def test_handle_callback_success(self, mock_post, app, db, client): """Test successful callback handling""" with app.test_request_context(): # Setup state token state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() # Mock IndieLogin response mock_response = MagicMock() mock_response.json.return_value = {"me": "https://example.com"} mock_post.return_value = mock_response # Handle callback code = "test-code" session_token = handle_callback(code, state) assert session_token is not None assert isinstance(session_token, str) # Session should be created token_hash = _hash_token(session_token) result = db.execute( "SELECT me FROM sessions WHERE session_token_hash = ?", (token_hash,) ).fetchone() assert result is not None assert result["me"] == "https://example.com" def test_handle_callback_invalid_state(self, app): """Test callback with invalid state""" with app.app_context(): with pytest.raises(InvalidStateError): handle_callback("code", "invalid-state") @patch("starpunk.auth.httpx.post") def test_handle_callback_unauthorized_user(self, mock_post, app, db): """Test callback with unauthorized user""" with app.app_context(): # Setup state token state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() # Mock IndieLogin response with different user mock_response = MagicMock() mock_response.json.return_value = {"me": "https://attacker.com"} mock_post.return_value = mock_response with pytest.raises(UnauthorizedError): handle_callback("code", state) @patch("starpunk.auth.httpx.post") def test_handle_callback_indielogin_error(self, mock_post, app, db): """Test callback with IndieLogin error""" with app.app_context(): # Setup state token state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() # Mock IndieLogin error mock_post.side_effect = httpx.RequestError("Connection failed") with pytest.raises(IndieLoginError): handle_callback("code", state) @patch("starpunk.auth.httpx.post") def test_handle_callback_no_identity(self, mock_post, app, db): """Test callback with no identity in response""" with app.app_context(): # Setup state token state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() # Mock IndieLogin response without 'me' field mock_response = MagicMock() mock_response.json.return_value = {} mock_post.return_value = mock_response with pytest.raises(IndieLoginError, match="No identity returned"): handle_callback("code", state) class TestCreateSession: def test_create_session_success(self, app, db, client): """Test successful session creation""" with app.test_request_context(): me = "https://example.com" session_token = create_session(me) assert session_token is not None assert isinstance(session_token, str) # Session should exist in database token_hash = _hash_token(session_token) result = db.execute( """ SELECT me, expires_at, created_at FROM sessions WHERE session_token_hash = ? """, (token_hash,), ).fetchone() assert result is not None assert result["me"] == me assert result["expires_at"] is not None def test_create_session_metadata(self, app, db, client): """Test that session stores metadata""" with app.test_request_context( headers={"User-Agent": "Test Browser"}, environ_base={"REMOTE_ADDR": "127.0.0.1"}, ): me = "https://example.com" session_token = create_session(me) token_hash = _hash_token(session_token) result = db.execute( """ SELECT user_agent, ip_address FROM sessions WHERE session_token_hash = ? """, (token_hash,), ).fetchone() assert result["user_agent"] == "Test Browser" assert result["ip_address"] == "127.0.0.1" class TestVerifySession: def test_verify_valid_session(self, app, db, client): """Test verifying a valid session""" with app.test_request_context(): # Create session me = "https://example.com" session_token = create_session(me) # Verify session session_info = verify_session(session_token) assert session_info is not None assert session_info["me"] == me assert "created_at" in session_info assert "expires_at" in session_info def test_verify_invalid_session(self, app): """Test verifying an invalid session""" with app.app_context(): session_info = verify_session("invalid-token") assert session_info is None def test_verify_expired_session(self, app, db): """Test verifying an expired session""" with app.app_context(): # Create expired session token = secrets.token_urlsafe(32) token_hash = _hash_token(token) expires_at = datetime.utcnow() - timedelta(days=1) db.execute( """ INSERT INTO sessions (session_token_hash, me, expires_at) VALUES (?, ?, ?) """, (token_hash, "https://example.com", expires_at), ) db.commit() session_info = verify_session(token) assert session_info is None def test_verify_session_updates_last_used(self, app, db, client): """Test that verification updates last_used_at""" with app.test_request_context(): # Create session me = "https://example.com" session_token = create_session(me) # Verify session verify_session(session_token) # Check last_used_at is set token_hash = _hash_token(session_token) result = db.execute( "SELECT last_used_at FROM sessions WHERE session_token_hash = ?", (token_hash,), ).fetchone() assert result["last_used_at"] is not None def test_verify_empty_token(self, app): """Test verifying empty token""" with app.app_context(): assert verify_session("") is None assert verify_session(None) is None class TestDestroySession: def test_destroy_session_success(self, app, db, client): """Test successful session destruction""" with app.test_request_context(): # Create session me = "https://example.com" session_token = create_session(me) # Destroy session destroy_session(session_token) # Session should no longer exist token_hash = _hash_token(session_token) result = db.execute( "SELECT 1 FROM sessions WHERE session_token_hash = ?", (token_hash,) ).fetchone() assert result is None def test_destroy_invalid_session(self, app): """Test destroying an invalid session (should not raise error)""" with app.app_context(): destroy_session("invalid-token") # Should not raise def test_destroy_empty_token(self, app): """Test destroying empty token""" with app.app_context(): destroy_session("") # Should not raise destroy_session(None) # Should not raise class TestRequireAuthDecorator: def test_require_auth_with_valid_session(self, app, db, client): """Test require_auth decorator with valid session""" with app.test_request_context(): # Create session me = "https://example.com" session_token = create_session(me) # Create test route @require_auth def protected_route(): return "Protected content" # Manually set cookie header environ = {"HTTP_COOKIE": f"session={session_token}"} with app.test_request_context(environ_base=environ): result = protected_route() assert result == "Protected content" assert hasattr(g, "user") assert g.user["me"] == me def test_require_auth_without_session(self, app, client): """Test require_auth decorator without session""" # Create test route @require_auth def protected_route(): return "Protected content" # Call protected route without session with app.test_request_context(): with patch("starpunk.auth.redirect") as mock_redirect: with patch("starpunk.auth.url_for") as mock_url_for: mock_url_for.return_value = "/auth/login" protected_route() mock_redirect.assert_called_once() def test_require_auth_with_expired_session(self, app, db, client): """Test require_auth decorator with expired session""" # Create expired session with app.app_context(): token = secrets.token_urlsafe(32) token_hash = _hash_token(token) expires_at = datetime.utcnow() - timedelta(days=1) db.execute( """ INSERT INTO sessions (session_token_hash, me, expires_at) VALUES (?, ?, ?) """, (token_hash, "https://example.com", expires_at), ) db.commit() # Create test route @require_auth def protected_route(): return "Protected content" # Call protected route with expired session environ = {"HTTP_COOKIE": f"session={token}"} with app.test_request_context(environ_base=environ): with patch("starpunk.auth.redirect") as mock_redirect: with patch("starpunk.auth.url_for") as mock_url_for: mock_url_for.return_value = "/auth/login" protected_route() mock_redirect.assert_called_once() class TestSecurityFeatures: def test_token_hashing_prevents_plaintext_storage(self, app, db, client): """Test that tokens are hashed, not stored in plaintext""" with app.test_request_context(): me = "https://example.com" session_token = create_session(me) # Database should not contain plaintext token result = db.execute("SELECT session_token_hash FROM sessions").fetchone() assert result["session_token_hash"] != session_token assert len(result["session_token_hash"]) == 64 # SHA-256 hex length def test_state_tokens_are_single_use(self, app, db): """Test that state tokens can only be used once""" with app.app_context(): state = secrets.token_urlsafe(32) expires_at = datetime.utcnow() + timedelta(minutes=5) db.execute( "INSERT INTO auth_state (state, expires_at) VALUES (?, ?)", (state, expires_at), ) db.commit() # First verification should succeed assert _verify_state_token(state) is True # Second verification should fail (token deleted) assert _verify_state_token(state) is False def test_session_expiry(self, app, db, client): """Test that sessions expire correctly""" with app.test_request_context(): # Create session with custom lifetime app.config["SESSION_LIFETIME"] = 1 # 1 day me = "https://example.com" session_token = create_session(me) token_hash = _hash_token(session_token) result = db.execute( "SELECT expires_at FROM sessions WHERE session_token_hash = ?", (token_hash,), ).fetchone() expires_at = datetime.fromisoformat(result["expires_at"]) created_at = datetime.utcnow() # Should expire approximately 1 day from now # (allow for minor timing differences) delta = expires_at - created_at assert delta.total_seconds() >= 86000 # At least 23.8 hours assert delta.total_seconds() <= 86401 # At most 1 day + 1 second class TestExceptionHierarchy: def test_exception_inheritance(self): """Test that custom exceptions inherit correctly""" assert issubclass(InvalidStateError, AuthError) assert issubclass(UnauthorizedError, AuthError) assert issubclass(IndieLoginError, AuthError) assert issubclass(AuthError, Exception) def test_exception_messages(self): """Test that exceptions can carry messages""" error = InvalidStateError("Test message") assert str(error) == "Test message" error = UnauthorizedError("Unauthorized") assert str(error) == "Unauthorized" error = IndieLoginError("Service error") assert str(error) == "Service error"