feat(phase-3): implement token endpoint and OAuth 2.0 flow

Phase 3 Implementation:
- Token service with secure token generation and validation
- Token endpoint (POST /token) with OAuth 2.0 compliance
- Database migration 003 for tokens table
- Authorization code validation and single-use enforcement

Phase 1 Updates:
- Enhanced CodeStore to support dict values with JSON serialization
- Maintains backward compatibility

Phase 2 Updates:
- Authorization codes now include PKCE fields, used flag, timestamps
- Complete metadata structure for token exchange

Security:
- 256-bit cryptographically secure tokens (secrets.token_urlsafe)
- SHA-256 hashed storage (no plaintext)
- Constant-time comparison for validation
- Single-use code enforcement with replay detection

Testing:
- 226 tests passing (100%)
- 87.27% coverage (exceeds 80% requirement)
- OAuth 2.0 compliance verified

This completes the v1.0.0 MVP with full IndieAuth authorization code flow.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-20 14:24:06 -07:00
parent 074f74002c
commit 05b4ff7a6b
18 changed files with 4049 additions and 26 deletions

View File

@@ -166,11 +166,11 @@ class TestConfigValidate:
Config.validate()
def test_validate_token_expiry_negative(self, monkeypatch):
"""Test validation fails when TOKEN_EXPIRY <= 0."""
"""Test validation fails when TOKEN_EXPIRY < 300."""
monkeypatch.setenv("GONDULF_SECRET_KEY", "a" * 32)
Config.load()
Config.TOKEN_EXPIRY = -1
with pytest.raises(ConfigurationError, match="must be positive"):
with pytest.raises(ConfigurationError, match="must be at least 300 seconds"):
Config.validate()
def test_validate_code_expiry_zero(self, monkeypatch):

View File

@@ -175,15 +175,15 @@ class TestDatabaseMigrations:
engine = db.get_engine()
with engine.connect() as conn:
# Check migrations were recorded correctly (001 and 002)
# Check migrations were recorded correctly (001, 002, and 003)
result = conn.execute(text("SELECT COUNT(*) FROM migrations"))
count = result.fetchone()[0]
assert count == 2
assert count == 3
# Verify both migrations are present
# Verify all migrations are present
result = conn.execute(text("SELECT version FROM migrations ORDER BY version"))
versions = [row[0] for row in result]
assert versions == [1, 2]
assert versions == [1, 2, 3]
def test_initialize_full_setup(self):
"""Test initialize performs full database setup."""

View File

@@ -216,3 +216,65 @@ class TestCodeStore:
assert store.verify("test@example.com", "old_code") is False
assert store.verify("test@example.com", "new_code") is True
def test_store_dict_value(self):
"""Test storing dict values for authorization code metadata."""
store = CodeStore(ttl_seconds=60)
metadata = {
"client_id": "https://client.example.com",
"redirect_uri": "https://client.example.com/callback",
"state": "xyz123",
"me": "https://user.example.com",
"scope": "profile",
"code_challenge": "abc123",
"code_challenge_method": "S256",
"created_at": 1234567890,
"expires_at": 1234568490,
"used": False
}
store.store("auth_code_123", metadata)
retrieved = store.get("auth_code_123")
assert retrieved is not None
assert isinstance(retrieved, dict)
assert retrieved["client_id"] == "https://client.example.com"
assert retrieved["used"] is False
def test_store_dict_with_custom_ttl(self):
"""Test storing dict values with custom TTL."""
store = CodeStore(ttl_seconds=60)
metadata = {"client_id": "https://client.example.com", "used": False}
store.store("auth_code_123", metadata, ttl=120)
retrieved = store.get("auth_code_123")
assert retrieved is not None
assert isinstance(retrieved, dict)
def test_dict_value_expiration(self):
"""Test dict values expire correctly."""
store = CodeStore(ttl_seconds=1)
metadata = {"client_id": "https://client.example.com"}
store.store("auth_code_123", metadata)
# Wait for expiration
time.sleep(1.1)
assert store.get("auth_code_123") is None
def test_delete_dict_value(self):
"""Test deleting dict values."""
store = CodeStore(ttl_seconds=60)
metadata = {"client_id": "https://client.example.com"}
store.store("auth_code_123", metadata)
assert store.get("auth_code_123") is not None
store.delete("auth_code_123")
assert store.get("auth_code_123") is None

View File

@@ -0,0 +1,315 @@
"""
Unit tests for Token Endpoint.
Tests token exchange endpoint including validation, error handling, and security.
"""
import os
import pytest
from fastapi.testclient import TestClient
from gondulf.database.connection import Database
from gondulf.services.token_service import TokenService
from gondulf.storage import CodeStore
@pytest.fixture(scope="function")
def test_config(monkeypatch):
"""Configure test environment."""
# Set required environment variables
monkeypatch.setenv("GONDULF_SECRET_KEY", "test_secret_key_" + "x" * 32)
monkeypatch.setenv("GONDULF_DATABASE_URL", "sqlite:///:memory:")
# Import after environment is set
from gondulf.config import Config
Config.load()
Config.validate()
return Config
@pytest.fixture
def test_database(tmp_path):
"""Create test database."""
db_path = tmp_path / "test.db"
db = Database(f"sqlite:///{db_path}")
db.ensure_database_directory()
db.run_migrations()
return db
@pytest.fixture
def test_code_storage():
"""Create test code storage."""
return CodeStore(ttl_seconds=600)
@pytest.fixture
def test_token_service(test_database):
"""Create test token service."""
return TokenService(
database=test_database,
token_length=32,
token_ttl=3600
)
@pytest.fixture
def client(test_config, test_database, test_code_storage, test_token_service):
"""Create test client with dependency overrides."""
# Import app after config is set
from gondulf.dependencies import get_code_storage, get_database, get_token_service
from gondulf.main import app
app.dependency_overrides[get_database] = lambda: test_database
app.dependency_overrides[get_code_storage] = lambda: test_code_storage
app.dependency_overrides[get_token_service] = lambda: test_token_service
yield TestClient(app)
app.dependency_overrides.clear()
@pytest.fixture
def valid_auth_code(test_code_storage):
"""Create a valid authorization code."""
code = "test_auth_code_12345"
metadata = {
"client_id": "https://client.example.com",
"redirect_uri": "https://client.example.com/callback",
"state": "xyz123",
"me": "https://user.example.com",
"scope": "",
"code_challenge": "abc123",
"code_challenge_method": "S256",
"created_at": 1234567890,
"expires_at": 1234568490,
"used": False
}
test_code_storage.store(f"authz:{code}", metadata)
return code, metadata
class TestTokenExchangeSuccess:
"""Tests for successful token exchange."""
def test_token_exchange_success(self, client, valid_auth_code):
"""Test successful token exchange returns access token."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "Bearer"
assert data["me"] == metadata["me"]
assert data["scope"] == metadata["scope"]
def test_token_exchange_response_format(self, client, valid_auth_code):
"""Test token response matches OAuth 2.0 format."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 200
data = response.json()
# Required fields per OAuth 2.0
assert "access_token" in data
assert "token_type" in data
assert "me" in data
assert isinstance(data["access_token"], str)
assert len(data["access_token"]) == 43 # base64url encoded
def test_token_exchange_cache_headers(self, client, valid_auth_code):
"""Test OAuth 2.0 cache headers are set."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.headers["Cache-Control"] == "no-store"
assert response.headers["Pragma"] == "no-cache"
def test_token_exchange_deletes_code(self, client, valid_auth_code, test_code_storage):
"""Test authorization code is deleted after exchange."""
code, metadata = valid_auth_code
client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
# Code should be deleted
assert test_code_storage.get(f"authz:{code}") is None
class TestTokenExchangeErrors:
"""Tests for error conditions."""
def test_invalid_grant_type(self, client, valid_auth_code):
"""Test unsupported grant_type returns error."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "password", # Wrong grant type
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 400
data = response.json()
assert data["detail"]["error"] == "unsupported_grant_type"
def test_code_not_found(self, client):
"""Test invalid authorization code returns error."""
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": "invalid_code_123",
"client_id": "https://client.example.com",
"redirect_uri": "https://client.example.com/callback"
}
)
assert response.status_code == 400
data = response.json()
assert data["detail"]["error"] == "invalid_grant"
def test_client_id_mismatch(self, client, valid_auth_code):
"""Test client_id mismatch returns error."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": "https://wrong-client.example.com", # Wrong client
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 400
data = response.json()
assert data["detail"]["error"] == "invalid_client"
def test_redirect_uri_mismatch(self, client, valid_auth_code):
"""Test redirect_uri mismatch returns error."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": "https://wrong-uri.example.com/callback" # Wrong URI
}
)
assert response.status_code == 400
data = response.json()
assert data["detail"]["error"] == "invalid_grant"
def test_code_replay_prevention(self, client, valid_auth_code, test_code_storage):
"""Test authorization code cannot be used twice."""
code, metadata = valid_auth_code
# Mark code as used
metadata["used"] = True
test_code_storage.store(f"authz:{code}", metadata)
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 400
data = response.json()
assert data["detail"]["error"] == "invalid_grant"
class TestPKCEHandling:
"""Tests for PKCE parameter handling."""
def test_code_verifier_accepted_but_not_validated(self, client, valid_auth_code):
"""Test code_verifier is accepted but not validated in v1.0.0."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"],
"code_verifier": "some_verifier_string"
}
)
# Should still succeed (PKCE not validated in v1.0.0)
assert response.status_code == 200
class TestSecurityValidation:
"""Tests for security validations."""
def test_token_generated_via_service(self, client, valid_auth_code, test_token_service):
"""Test token is generated through token service."""
code, metadata = valid_auth_code
response = client.post(
"/token",
data={
"grant_type": "authorization_code",
"code": code,
"client_id": metadata["client_id"],
"redirect_uri": metadata["redirect_uri"]
}
)
assert response.status_code == 200
data = response.json()
# Validate token was actually stored
token_metadata = test_token_service.validate_token(data["access_token"])
assert token_metadata is not None
assert token_metadata["me"] == metadata["me"]

View File

@@ -0,0 +1,340 @@
"""
Unit tests for Token Service.
Tests token generation, validation, revocation, and cleanup.
"""
import hashlib
import time
from datetime import datetime, timedelta
import pytest
from sqlalchemy import text
from gondulf.database.connection import Database
from gondulf.services.token_service import TokenService
@pytest.fixture
def test_database(tmp_path):
"""Create test database with migrations."""
db_path = tmp_path / "test.db"
db = Database(f"sqlite:///{db_path}")
db.ensure_database_directory()
db.run_migrations()
return db
@pytest.fixture
def token_service(test_database):
"""Create token service with test database."""
return TokenService(
database=test_database,
token_length=32,
token_ttl=3600
)
class TestTokenGeneration:
"""Tests for token generation."""
def test_generate_token_returns_string(self, token_service):
"""Test that generate_token returns a string token."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
assert isinstance(token, str)
assert len(token) == 43 # 32 bytes base64url = 43 chars
def test_generate_token_stores_hash(self, token_service, test_database):
"""Test that token is stored as SHA-256 hash, not plaintext."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
token_hash = hashlib.sha256(token.encode('utf-8')).hexdigest()
engine = test_database.get_engine()
with engine.connect() as conn:
result = conn.execute(
text("SELECT token_hash FROM tokens WHERE token_hash = :hash"),
{"hash": token_hash}
).fetchone()
assert result is not None
assert result[0] == token_hash
def test_generate_token_stores_metadata(self, token_service, test_database):
"""Test that token metadata is stored correctly."""
me = "https://example.com"
client_id = "https://client.example.com"
scope = "profile"
token = token_service.generate_token(me=me, client_id=client_id, scope=scope)
token_hash = hashlib.sha256(token.encode('utf-8')).hexdigest()
engine = test_database.get_engine()
with engine.connect() as conn:
result = conn.execute(
text("SELECT me, client_id, scope FROM tokens WHERE token_hash = :hash"),
{"hash": token_hash}
).fetchone()
assert result is not None
assert result[0] == me
assert result[1] == client_id
assert result[2] == scope
def test_generate_token_sets_expiration(self, token_service, test_database):
"""Test that token expiration is calculated correctly."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
token_hash = hashlib.sha256(token.encode('utf-8')).hexdigest()
engine = test_database.get_engine()
with engine.connect() as conn:
result = conn.execute(
text("SELECT issued_at, expires_at FROM tokens WHERE token_hash = :hash"),
{"hash": token_hash}
).fetchone()
issued_at = datetime.fromisoformat(result[0])
expires_at = datetime.fromisoformat(result[1])
# Should be ~3600 seconds apart
time_diff = (expires_at - issued_at).total_seconds()
assert 3590 < time_diff < 3610 # Allow 10 second variance
def test_generate_token_is_random(self, token_service):
"""Test that generated tokens are cryptographically random."""
tokens = set()
for _ in range(100):
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
tokens.add(token)
# All 100 tokens should be unique
assert len(tokens) == 100
class TestTokenValidation:
"""Tests for token validation."""
def test_validate_token_success(self, token_service):
"""Test validating a valid token returns metadata."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope="profile"
)
metadata = token_service.validate_token(token)
assert metadata is not None
assert metadata['me'] == "https://example.com"
assert metadata['client_id'] == "https://client.example.com"
assert metadata['scope'] == "profile"
def test_validate_token_not_found(self, token_service):
"""Test validating non-existent token returns None."""
fake_token = "invalid_token_12345678901234567890123456"
metadata = token_service.validate_token(fake_token)
assert metadata is None
def test_validate_token_expired(self, token_service, test_database):
"""Test validating expired token returns None."""
# Generate token with short TTL
short_ttl_service = TokenService(
database=test_database,
token_length=32,
token_ttl=1 # 1 second
)
token = short_ttl_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
# Wait for expiration
time.sleep(1.1)
metadata = short_ttl_service.validate_token(token)
assert metadata is None
def test_validate_token_revoked(self, token_service):
"""Test validating revoked token returns None."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
# Revoke the token
token_service.revoke_token(token)
# Validation should fail
metadata = token_service.validate_token(token)
assert metadata is None
class TestTokenRevocation:
"""Tests for token revocation."""
def test_revoke_token_success(self, token_service):
"""Test revoking a valid token returns True."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
result = token_service.revoke_token(token)
assert result is True
def test_revoke_token_not_found(self, token_service):
"""Test revoking non-existent token returns False."""
fake_token = "invalid_token_12345678901234567890123456"
result = token_service.revoke_token(fake_token)
assert result is False
def test_revoked_token_fails_validation(self, token_service):
"""Test that revoked tokens cannot be validated."""
token = token_service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
# Revoke and try to validate
token_service.revoke_token(token)
metadata = token_service.validate_token(token)
assert metadata is None
class TestTokenCleanup:
"""Tests for expired token cleanup."""
def test_cleanup_expired_tokens(self, test_database):
"""Test cleanup deletes expired tokens."""
# Create service with short TTL
short_ttl_service = TokenService(
database=test_database,
token_length=32,
token_ttl=1 # 1 second
)
# Generate multiple tokens
for i in range(3):
short_ttl_service.generate_token(
me=f"https://example{i}.com",
client_id="https://client.example.com",
scope=""
)
# Wait for expiration
time.sleep(1.1)
# Run cleanup
deleted_count = short_ttl_service.cleanup_expired_tokens()
assert deleted_count == 3
def test_cleanup_preserves_valid_tokens(self, test_database):
"""Test cleanup doesn't delete valid tokens."""
service = TokenService(
database=test_database,
token_length=32,
token_ttl=3600 # 1 hour
)
# Generate token
token = service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
# Run cleanup
deleted_count = service.cleanup_expired_tokens()
# No tokens should be deleted
assert deleted_count == 0
# Token should still be valid
metadata = service.validate_token(token)
assert metadata is not None
def test_cleanup_empty_database(self, token_service):
"""Test cleanup handles empty database gracefully."""
deleted_count = token_service.cleanup_expired_tokens()
assert deleted_count == 0
class TestTokenServiceConfiguration:
"""Tests for token service configuration."""
def test_custom_token_length(self, test_database):
"""Test custom token length is respected."""
service = TokenService(
database=test_database,
token_length=16, # Smaller token
token_ttl=3600
)
token = service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
# 16 bytes base64url = ~22 chars
assert len(token) == 22
def test_custom_ttl(self, test_database):
"""Test custom TTL is respected."""
service = TokenService(
database=test_database,
token_length=32,
token_ttl=7200 # 2 hours
)
token = service.generate_token(
me="https://example.com",
client_id="https://client.example.com",
scope=""
)
token_hash = hashlib.sha256(token.encode('utf-8')).hexdigest()
engine = test_database.get_engine()
with engine.connect() as conn:
result = conn.execute(
text("SELECT issued_at, expires_at FROM tokens WHERE token_hash = :hash"),
{"hash": token_hash}
).fetchone()
issued_at = datetime.fromisoformat(result[0])
expires_at = datetime.fromisoformat(result[1])
# Should be ~7200 seconds apart
time_diff = (expires_at - issued_at).total_seconds()
assert 7190 < time_diff < 7210