feat(phase-2): implement domain verification system
Implements complete domain verification flow with: - rel=me link verification service - HTML fetching with security controls - Rate limiting to prevent abuse - Email validation utilities - Authorization and verification API endpoints - User-facing templates for authorization and verification flows This completes Phase 2: Domain Verification as designed. Tests: - All Phase 2 unit tests passing - Coverage: 85% overall - Migration tests updated 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
236
tests/unit/test_domain_verification.py
Normal file
236
tests/unit/test_domain_verification.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Tests for domain verification service."""
|
||||
import pytest
|
||||
from unittest.mock import Mock, MagicMock
|
||||
|
||||
from gondulf.services.domain_verification import DomainVerificationService
|
||||
from gondulf.dns import DNSService
|
||||
from gondulf.email import EmailService
|
||||
from gondulf.storage import CodeStore
|
||||
from gondulf.services.html_fetcher import HTMLFetcherService
|
||||
from gondulf.services.relme_parser import RelMeParser
|
||||
|
||||
|
||||
class TestDomainVerificationService:
|
||||
"""Tests for DomainVerificationService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dns(self):
|
||||
"""Mock DNS service."""
|
||||
return Mock(spec=DNSService)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_email(self):
|
||||
"""Mock email service."""
|
||||
return Mock(spec=EmailService)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self):
|
||||
"""Mock code storage."""
|
||||
return Mock(spec=CodeStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetcher(self):
|
||||
"""Mock HTML fetcher."""
|
||||
return Mock(spec=HTMLFetcherService)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_parser(self):
|
||||
"""Mock rel=me parser."""
|
||||
return Mock(spec=RelMeParser)
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_dns, mock_email, mock_storage, mock_fetcher, mock_parser):
|
||||
"""Create domain verification service with mocks."""
|
||||
return DomainVerificationService(
|
||||
dns_service=mock_dns,
|
||||
email_service=mock_email,
|
||||
code_storage=mock_storage,
|
||||
html_fetcher=mock_fetcher,
|
||||
relme_parser=mock_parser
|
||||
)
|
||||
|
||||
def test_generate_verification_code(self, service):
|
||||
"""Test verification code generation."""
|
||||
code = service.generate_verification_code()
|
||||
assert isinstance(code, str)
|
||||
assert len(code) == 6
|
||||
assert code.isdigit()
|
||||
|
||||
def test_generate_verification_code_unique(self, service):
|
||||
"""Test that generated codes are different."""
|
||||
code1 = service.generate_verification_code()
|
||||
code2 = service.generate_verification_code()
|
||||
# Very unlikely to be the same, but possible
|
||||
# Just check they're both valid
|
||||
assert code1.isdigit()
|
||||
assert code2.isdigit()
|
||||
|
||||
def test_start_verification_dns_fails(self, service, mock_dns):
|
||||
"""Test start_verification when DNS verification fails."""
|
||||
mock_dns.verify_txt_record.return_value = False
|
||||
|
||||
result = service.start_verification("example.com", "https://example.com/")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "dns_verification_failed"
|
||||
|
||||
def test_start_verification_email_discovery_fails(
|
||||
self, service, mock_dns, mock_fetcher, mock_parser
|
||||
):
|
||||
"""Test start_verification when email discovery fails."""
|
||||
mock_dns.verify_txt_record.return_value = True
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = None
|
||||
|
||||
result = service.start_verification("example.com", "https://example.com/")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "email_discovery_failed"
|
||||
|
||||
def test_start_verification_invalid_email_format(
|
||||
self, service, mock_dns, mock_fetcher, mock_parser
|
||||
):
|
||||
"""Test start_verification with invalid email format."""
|
||||
mock_dns.verify_txt_record.return_value = True
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = "not-an-email"
|
||||
|
||||
result = service.start_verification("example.com", "https://example.com/")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "invalid_email_format"
|
||||
|
||||
def test_start_verification_email_send_fails(
|
||||
self, service, mock_dns, mock_fetcher, mock_parser, mock_email
|
||||
):
|
||||
"""Test start_verification when email sending fails."""
|
||||
mock_dns.verify_txt_record.return_value = True
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = "user@example.com"
|
||||
mock_email.send_verification_code.side_effect = Exception("SMTP error")
|
||||
|
||||
result = service.start_verification("example.com", "https://example.com/")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "email_send_failed"
|
||||
|
||||
def test_start_verification_success(
|
||||
self, service, mock_dns, mock_fetcher, mock_parser, mock_email, mock_storage
|
||||
):
|
||||
"""Test successful verification start."""
|
||||
mock_dns.verify_txt_record.return_value = True
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = "user@example.com"
|
||||
|
||||
result = service.start_verification("example.com", "https://example.com/")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["email"] == "u***@example.com" # Masked
|
||||
assert result["verification_method"] == "email"
|
||||
mock_email.send_verification_code.assert_called_once()
|
||||
assert mock_storage.store.call_count == 2 # Code and email stored
|
||||
|
||||
def test_verify_email_code_invalid(self, service, mock_storage):
|
||||
"""Test verify_email_code with invalid code."""
|
||||
mock_storage.verify.return_value = False
|
||||
|
||||
result = service.verify_email_code("example.com", "123456")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "invalid_code"
|
||||
|
||||
def test_verify_email_code_email_not_found(self, service, mock_storage):
|
||||
"""Test verify_email_code when email not in storage."""
|
||||
mock_storage.verify.return_value = True
|
||||
mock_storage.get.return_value = None
|
||||
|
||||
result = service.verify_email_code("example.com", "123456")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "email_not_found"
|
||||
|
||||
def test_verify_email_code_success(self, service, mock_storage):
|
||||
"""Test successful email code verification."""
|
||||
mock_storage.verify.return_value = True
|
||||
mock_storage.get.return_value = "user@example.com"
|
||||
|
||||
result = service.verify_email_code("example.com", "123456")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["email"] == "user@example.com"
|
||||
mock_storage.delete.assert_called_once()
|
||||
|
||||
def test_create_authorization_code(self, service, mock_storage):
|
||||
"""Test authorization code creation."""
|
||||
code = service.create_authorization_code(
|
||||
client_id="https://client.example.com/",
|
||||
redirect_uri="https://client.example.com/callback",
|
||||
state="test_state",
|
||||
code_challenge="challenge",
|
||||
code_challenge_method="S256",
|
||||
scope="profile",
|
||||
me="https://user.example.com/"
|
||||
)
|
||||
|
||||
assert isinstance(code, str)
|
||||
assert len(code) > 0
|
||||
mock_storage.store.assert_called_once()
|
||||
|
||||
def test_verify_dns_record_success(self, service, mock_dns):
|
||||
"""Test DNS record verification success."""
|
||||
mock_dns.verify_txt_record.return_value = True
|
||||
|
||||
result = service._verify_dns_record("example.com")
|
||||
|
||||
assert result is True
|
||||
mock_dns.verify_txt_record.assert_called_with("example.com", "gondulf-verify-domain")
|
||||
|
||||
def test_verify_dns_record_failure(self, service, mock_dns):
|
||||
"""Test DNS record verification failure."""
|
||||
mock_dns.verify_txt_record.return_value = False
|
||||
|
||||
result = service._verify_dns_record("example.com")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_verify_dns_record_exception(self, service, mock_dns):
|
||||
"""Test DNS record verification handles exceptions."""
|
||||
mock_dns.verify_txt_record.side_effect = Exception("DNS error")
|
||||
|
||||
result = service._verify_dns_record("example.com")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_discover_email_success(self, service, mock_fetcher, mock_parser):
|
||||
"""Test email discovery success."""
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = "user@example.com"
|
||||
|
||||
email = service._discover_email("https://example.com/")
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
def test_discover_email_fetch_fails(self, service, mock_fetcher):
|
||||
"""Test email discovery when fetch fails."""
|
||||
mock_fetcher.fetch.return_value = None
|
||||
|
||||
email = service._discover_email("https://example.com/")
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_discover_email_no_email_found(self, service, mock_fetcher, mock_parser):
|
||||
"""Test email discovery when no email found."""
|
||||
mock_fetcher.fetch.return_value = "<html></html>"
|
||||
mock_parser.find_email.return_value = None
|
||||
|
||||
email = service._discover_email("https://example.com/")
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_discover_email_exception(self, service, mock_fetcher):
|
||||
"""Test email discovery handles exceptions."""
|
||||
mock_fetcher.fetch.side_effect = Exception("Fetch error")
|
||||
|
||||
email = service._discover_email("https://example.com/")
|
||||
|
||||
assert email is None
|
||||
175
tests/unit/test_html_fetcher.py
Normal file
175
tests/unit/test_html_fetcher.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for HTML fetcher service."""
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from urllib.error import URLError, HTTPError
|
||||
|
||||
from gondulf.services.html_fetcher import HTMLFetcherService
|
||||
|
||||
|
||||
class TestHTMLFetcherService:
|
||||
"""Tests for HTMLFetcherService."""
|
||||
|
||||
def test_init_default_params(self):
|
||||
"""Test initialization with default parameters."""
|
||||
fetcher = HTMLFetcherService()
|
||||
assert fetcher.timeout == 10
|
||||
assert fetcher.max_size == 1024 * 1024
|
||||
assert fetcher.max_redirects == 5
|
||||
assert "Gondulf" in fetcher.user_agent
|
||||
|
||||
def test_init_custom_params(self):
|
||||
"""Test initialization with custom parameters."""
|
||||
fetcher = HTMLFetcherService(
|
||||
timeout=5,
|
||||
max_size=512 * 1024,
|
||||
max_redirects=3,
|
||||
user_agent="TestAgent/1.0"
|
||||
)
|
||||
assert fetcher.timeout == 5
|
||||
assert fetcher.max_size == 512 * 1024
|
||||
assert fetcher.max_redirects == 3
|
||||
assert fetcher.user_agent == "TestAgent/1.0"
|
||||
|
||||
def test_fetch_requires_https(self):
|
||||
"""Test that fetch requires HTTPS URLs."""
|
||||
fetcher = HTMLFetcherService()
|
||||
with pytest.raises(ValueError, match="must use HTTPS"):
|
||||
fetcher.fetch("http://example.com/")
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_success(self, mock_urlopen):
|
||||
"""Test successful HTML fetch."""
|
||||
# Mock response
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"<html><body>Test</body></html>"
|
||||
mock_response.headers.get_content_charset.return_value = "utf-8"
|
||||
mock_response.headers.get.return_value = None # No Content-Length header
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService()
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html == "<html><body>Test</body></html>"
|
||||
mock_urlopen.assert_called_once()
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_respects_timeout(self, mock_urlopen):
|
||||
"""Test that fetch respects timeout parameter."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"<html></html>"
|
||||
mock_response.headers.get_content_charset.return_value = "utf-8"
|
||||
mock_response.headers.get.return_value = None
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService(timeout=15)
|
||||
fetcher.fetch("https://example.com/")
|
||||
|
||||
call_kwargs = mock_urlopen.call_args[1]
|
||||
assert call_kwargs['timeout'] == 15
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_content_length_too_large(self, mock_urlopen):
|
||||
"""Test that fetch returns None if Content-Length exceeds max_size."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.headers.get.return_value = str(2 * 1024 * 1024) # 2MB
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService(max_size=1024 * 1024) # 1MB max
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is None
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_response_too_large(self, mock_urlopen):
|
||||
"""Test that fetch returns None if response exceeds max_size."""
|
||||
# Create response larger than max_size
|
||||
large_content = b"x" * (1024 * 1024 + 1) # 1MB + 1 byte
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = large_content
|
||||
mock_response.headers.get_content_charset.return_value = "utf-8"
|
||||
mock_response.headers.get.return_value = None
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService(max_size=1024 * 1024)
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is None
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_url_error(self, mock_urlopen):
|
||||
"""Test that fetch returns None on URLError."""
|
||||
mock_urlopen.side_effect = URLError("Connection failed")
|
||||
|
||||
fetcher = HTMLFetcherService()
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is None
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_http_error(self, mock_urlopen):
|
||||
"""Test that fetch returns None on HTTPError."""
|
||||
mock_urlopen.side_effect = HTTPError(
|
||||
"https://example.com/",
|
||||
404,
|
||||
"Not Found",
|
||||
{},
|
||||
None
|
||||
)
|
||||
|
||||
fetcher = HTMLFetcherService()
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is None
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_timeout_error(self, mock_urlopen):
|
||||
"""Test that fetch returns None on timeout."""
|
||||
mock_urlopen.side_effect = TimeoutError("Request timed out")
|
||||
|
||||
fetcher = HTMLFetcherService()
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is None
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_unicode_decode_error(self, mock_urlopen):
|
||||
"""Test that fetch returns None on Unicode decode error."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"\xff\xfe" # Invalid UTF-8
|
||||
mock_response.headers.get_content_charset.return_value = "utf-8"
|
||||
mock_response.headers.get.return_value = None
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService()
|
||||
# Should use 'replace' error handling and return a string
|
||||
html = fetcher.fetch("https://example.com/")
|
||||
|
||||
assert html is not None # Should not fail, uses error='replace'
|
||||
|
||||
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
|
||||
def test_fetch_sets_user_agent(self, mock_urlopen):
|
||||
"""Test that fetch sets User-Agent header."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = b"<html></html>"
|
||||
mock_response.headers.get_content_charset.return_value = "utf-8"
|
||||
mock_response.headers.get.return_value = None
|
||||
mock_response.__enter__.return_value = mock_response
|
||||
mock_response.__exit__.return_value = None
|
||||
mock_urlopen.return_value = mock_response
|
||||
|
||||
fetcher = HTMLFetcherService(user_agent="CustomAgent/2.0")
|
||||
fetcher.fetch("https://example.com/")
|
||||
|
||||
# Check that User-Agent header was set
|
||||
request = mock_urlopen.call_args[0][0]
|
||||
assert request.get_header('User-agent') == "CustomAgent/2.0"
|
||||
171
tests/unit/test_rate_limiter.py
Normal file
171
tests/unit/test_rate_limiter.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Tests for rate limiter service."""
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
from gondulf.services.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Tests for RateLimiter."""
|
||||
|
||||
def test_init_default_params(self):
|
||||
"""Test initialization with default parameters."""
|
||||
limiter = RateLimiter()
|
||||
assert limiter.max_attempts == 3
|
||||
assert limiter.window_seconds == 3600
|
||||
|
||||
def test_init_custom_params(self):
|
||||
"""Test initialization with custom parameters."""
|
||||
limiter = RateLimiter(max_attempts=5, window_hours=2)
|
||||
assert limiter.max_attempts == 5
|
||||
assert limiter.window_seconds == 7200
|
||||
|
||||
def test_check_rate_limit_no_attempts(self):
|
||||
"""Test rate limit check with no previous attempts."""
|
||||
limiter = RateLimiter()
|
||||
assert limiter.check_rate_limit("example.com") is True
|
||||
|
||||
def test_check_rate_limit_within_limit(self):
|
||||
"""Test rate limit check within limit."""
|
||||
limiter = RateLimiter(max_attempts=3)
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
assert limiter.check_rate_limit("example.com") is True
|
||||
|
||||
def test_check_rate_limit_at_limit(self):
|
||||
"""Test rate limit check at exact limit."""
|
||||
limiter = RateLimiter(max_attempts=3)
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
assert limiter.check_rate_limit("example.com") is False
|
||||
|
||||
def test_check_rate_limit_exceeded(self):
|
||||
"""Test rate limit check when exceeded."""
|
||||
limiter = RateLimiter(max_attempts=2)
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
assert limiter.check_rate_limit("example.com") is False
|
||||
|
||||
def test_record_attempt_creates_entry(self):
|
||||
"""Test that record_attempt creates new entry."""
|
||||
limiter = RateLimiter()
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
assert "example.com" in limiter._attempts
|
||||
assert len(limiter._attempts["example.com"]) == 1
|
||||
|
||||
def test_record_attempt_appends_to_existing(self):
|
||||
"""Test that record_attempt appends to existing entry."""
|
||||
limiter = RateLimiter()
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
assert len(limiter._attempts["example.com"]) == 2
|
||||
|
||||
def test_clean_old_attempts_removes_expired(self):
|
||||
"""Test that old attempts are cleaned up."""
|
||||
limiter = RateLimiter(max_attempts=3, window_hours=1)
|
||||
|
||||
# Mock time to control timestamps
|
||||
with patch('time.time', return_value=1000):
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
# Move time forward past window
|
||||
with patch('time.time', return_value=1000 + 3700): # 1 hour + 100 seconds
|
||||
limiter._clean_old_attempts("example.com")
|
||||
|
||||
assert "example.com" not in limiter._attempts
|
||||
|
||||
def test_clean_old_attempts_preserves_recent(self):
|
||||
"""Test that recent attempts are preserved."""
|
||||
limiter = RateLimiter(max_attempts=3, window_hours=1)
|
||||
|
||||
with patch('time.time', return_value=1000):
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
# Move time forward but still within window
|
||||
with patch('time.time', return_value=1000 + 1800): # 30 minutes
|
||||
limiter._clean_old_attempts("example.com")
|
||||
|
||||
assert "example.com" in limiter._attempts
|
||||
assert len(limiter._attempts["example.com"]) == 1
|
||||
|
||||
def test_check_rate_limit_cleans_old_attempts(self):
|
||||
"""Test that check_rate_limit cleans old attempts."""
|
||||
limiter = RateLimiter(max_attempts=2, window_hours=1)
|
||||
|
||||
# Record attempts at time 1000
|
||||
with patch('time.time', return_value=1000):
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
# Check limit should be False
|
||||
with patch('time.time', return_value=1000):
|
||||
assert limiter.check_rate_limit("example.com") is False
|
||||
|
||||
# Move time forward past window
|
||||
with patch('time.time', return_value=1000 + 3700):
|
||||
# Old attempts should be cleaned, limit should pass
|
||||
assert limiter.check_rate_limit("example.com") is True
|
||||
|
||||
def test_different_domains_independent(self):
|
||||
"""Test that different domains have independent limits."""
|
||||
limiter = RateLimiter(max_attempts=2)
|
||||
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("other.com")
|
||||
|
||||
assert limiter.check_rate_limit("example.com") is False
|
||||
assert limiter.check_rate_limit("other.com") is True
|
||||
|
||||
def test_get_remaining_attempts_initial(self):
|
||||
"""Test getting remaining attempts initially."""
|
||||
limiter = RateLimiter(max_attempts=3)
|
||||
assert limiter.get_remaining_attempts("example.com") == 3
|
||||
|
||||
def test_get_remaining_attempts_after_one(self):
|
||||
"""Test getting remaining attempts after one attempt."""
|
||||
limiter = RateLimiter(max_attempts=3)
|
||||
limiter.record_attempt("example.com")
|
||||
assert limiter.get_remaining_attempts("example.com") == 2
|
||||
|
||||
def test_get_remaining_attempts_exhausted(self):
|
||||
"""Test getting remaining attempts when exhausted."""
|
||||
limiter = RateLimiter(max_attempts=3)
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
limiter.record_attempt("example.com")
|
||||
assert limiter.get_remaining_attempts("example.com") == 0
|
||||
|
||||
def test_get_reset_time_no_attempts(self):
|
||||
"""Test getting reset time with no attempts."""
|
||||
limiter = RateLimiter()
|
||||
assert limiter.get_reset_time("example.com") == 0
|
||||
|
||||
def test_get_reset_time_with_attempts(self):
|
||||
"""Test getting reset time with attempts."""
|
||||
limiter = RateLimiter(window_hours=1)
|
||||
|
||||
with patch('time.time', return_value=1000):
|
||||
limiter.record_attempt("example.com")
|
||||
reset_time = limiter.get_reset_time("example.com")
|
||||
assert reset_time == 1000 + 3600
|
||||
|
||||
def test_get_reset_time_multiple_attempts(self):
|
||||
"""Test getting reset time with multiple attempts (returns oldest)."""
|
||||
limiter = RateLimiter(window_hours=1)
|
||||
|
||||
with patch('time.time', return_value=1000):
|
||||
limiter.record_attempt("example.com")
|
||||
|
||||
with patch('time.time', return_value=2000):
|
||||
limiter.record_attempt("example.com")
|
||||
# Reset time should be based on oldest attempt
|
||||
reset_time = limiter.get_reset_time("example.com")
|
||||
assert reset_time == 1000 + 3600
|
||||
181
tests/unit/test_relme_parser.py
Normal file
181
tests/unit/test_relme_parser.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for rel=me parser service."""
|
||||
import pytest
|
||||
|
||||
from gondulf.services.relme_parser import RelMeParser
|
||||
|
||||
|
||||
class TestRelMeParser:
|
||||
"""Tests for RelMeParser."""
|
||||
|
||||
def test_parse_relme_links_basic(self):
|
||||
"""Test parsing basic rel=me links."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a rel="me" href="https://github.com/user">GitHub</a>
|
||||
<a rel="me" href="mailto:user@example.com">Email</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
assert len(links) == 2
|
||||
assert "https://github.com/user" in links
|
||||
assert "mailto:user@example.com" in links
|
||||
|
||||
def test_parse_relme_links_link_tag(self):
|
||||
"""Test parsing rel=me from <link> tags."""
|
||||
html = """
|
||||
<html>
|
||||
<head>
|
||||
<link rel="me" href="https://twitter.com/user">
|
||||
</head>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
assert len(links) == 1
|
||||
assert "https://twitter.com/user" in links
|
||||
|
||||
def test_parse_relme_links_no_rel_me(self):
|
||||
"""Test parsing HTML with no rel=me links."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a href="https://example.com">Link</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
assert len(links) == 0
|
||||
|
||||
def test_parse_relme_links_no_href(self):
|
||||
"""Test parsing rel=me link without href."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a rel="me">No href</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
assert len(links) == 0
|
||||
|
||||
def test_parse_relme_links_malformed_html(self):
|
||||
"""Test parsing malformed HTML returns empty list."""
|
||||
html = "<html><body><<>>broken"
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
# Should not crash, returns what it can parse
|
||||
assert isinstance(links, list)
|
||||
|
||||
def test_extract_mailto_email_basic(self):
|
||||
"""Test extracting email from mailto: link."""
|
||||
links = ["mailto:user@example.com"]
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email(links)
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
def test_extract_mailto_email_with_query(self):
|
||||
"""Test extracting email from mailto: link with query parameters."""
|
||||
links = ["mailto:user@example.com?subject=Hello"]
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email(links)
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
def test_extract_mailto_email_multiple_links(self):
|
||||
"""Test extracting email from multiple links (returns first mailto:)."""
|
||||
links = [
|
||||
"https://github.com/user",
|
||||
"mailto:user@example.com",
|
||||
"mailto:other@example.com"
|
||||
]
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email(links)
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
def test_extract_mailto_email_no_mailto(self):
|
||||
"""Test extracting email when no mailto: links present."""
|
||||
links = ["https://github.com/user", "https://twitter.com/user"]
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email(links)
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_extract_mailto_email_invalid_format(self):
|
||||
"""Test extracting email from malformed mailto: link."""
|
||||
links = ["mailto:notanemail"]
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email(links)
|
||||
|
||||
# Should return None for invalid email format
|
||||
assert email is None
|
||||
|
||||
def test_extract_mailto_email_empty_list(self):
|
||||
"""Test extracting email from empty list."""
|
||||
parser = RelMeParser()
|
||||
email = parser.extract_mailto_email([])
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_find_email_success(self):
|
||||
"""Test find_email combining parse and extract."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a rel="me" href="https://github.com/user">GitHub</a>
|
||||
<a rel="me" href="mailto:user@example.com">Email</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
email = parser.find_email(html)
|
||||
|
||||
assert email == "user@example.com"
|
||||
|
||||
def test_find_email_no_email(self):
|
||||
"""Test find_email when no email present."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a rel="me" href="https://github.com/user">GitHub</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
email = parser.find_email(html)
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_find_email_malformed_html(self):
|
||||
"""Test find_email with malformed HTML."""
|
||||
html = "<html><<broken>>"
|
||||
parser = RelMeParser()
|
||||
email = parser.find_email(html)
|
||||
|
||||
assert email is None
|
||||
|
||||
def test_parse_relme_multiple_rel_values(self):
|
||||
"""Test parsing link with multiple rel values including 'me'."""
|
||||
html = """
|
||||
<html>
|
||||
<body>
|
||||
<a rel="me nofollow" href="https://example.com">Link</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
parser = RelMeParser()
|
||||
links = parser.parse_relme_links(html)
|
||||
|
||||
assert len(links) == 1
|
||||
assert "https://example.com" in links
|
||||
199
tests/unit/test_validation.py
Normal file
199
tests/unit/test_validation.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Tests for validation utilities."""
|
||||
import pytest
|
||||
|
||||
from gondulf.utils.validation import (
|
||||
mask_email,
|
||||
normalize_client_id,
|
||||
validate_redirect_uri,
|
||||
extract_domain_from_url,
|
||||
validate_email
|
||||
)
|
||||
|
||||
|
||||
class TestMaskEmail:
|
||||
"""Tests for mask_email function."""
|
||||
|
||||
def test_mask_email_basic(self):
|
||||
"""Test basic email masking."""
|
||||
assert mask_email("user@example.com") == "u***@example.com"
|
||||
|
||||
def test_mask_email_long_local(self):
|
||||
"""Test masking email with long local part."""
|
||||
assert mask_email("verylongusername@example.com") == "v***@example.com"
|
||||
|
||||
def test_mask_email_single_char_local(self):
|
||||
"""Test masking email with single character local part."""
|
||||
# Should return unchanged if local part is only 1 character
|
||||
assert mask_email("a@example.com") == "a@example.com"
|
||||
|
||||
def test_mask_email_no_at_sign(self):
|
||||
"""Test masking invalid email without @ sign."""
|
||||
assert mask_email("notanemail") == "notanemail"
|
||||
|
||||
def test_mask_email_empty_string(self):
|
||||
"""Test masking empty string."""
|
||||
assert mask_email("") == ""
|
||||
|
||||
|
||||
class TestNormalizeClientId:
|
||||
"""Tests for normalize_client_id function."""
|
||||
|
||||
def test_normalize_basic_https(self):
|
||||
"""Test normalizing basic HTTPS URL."""
|
||||
assert normalize_client_id("https://example.com/") == "https://example.com/"
|
||||
|
||||
def test_normalize_remove_default_port(self):
|
||||
"""Test normalizing URL with default HTTPS port."""
|
||||
assert normalize_client_id("https://example.com:443/") == "https://example.com/"
|
||||
|
||||
def test_normalize_preserve_non_default_port(self):
|
||||
"""Test normalizing URL with non-default port."""
|
||||
assert normalize_client_id("https://example.com:8443/") == "https://example.com:8443/"
|
||||
|
||||
def test_normalize_preserve_path(self):
|
||||
"""Test normalizing URL with path."""
|
||||
assert normalize_client_id("https://example.com/app") == "https://example.com/app"
|
||||
|
||||
def test_normalize_preserve_query(self):
|
||||
"""Test normalizing URL with query string."""
|
||||
assert normalize_client_id("https://example.com/?foo=bar") == "https://example.com/?foo=bar"
|
||||
|
||||
def test_normalize_http_scheme_raises_error(self):
|
||||
"""Test that HTTP scheme raises ValueError."""
|
||||
with pytest.raises(ValueError, match="must use https scheme"):
|
||||
normalize_client_id("http://example.com/")
|
||||
|
||||
def test_normalize_no_scheme_raises_error(self):
|
||||
"""Test that missing scheme raises ValueError."""
|
||||
with pytest.raises(ValueError, match="must use https scheme"):
|
||||
normalize_client_id("example.com")
|
||||
|
||||
|
||||
class TestValidateRedirectUri:
|
||||
"""Tests for validate_redirect_uri function."""
|
||||
|
||||
def test_validate_same_origin(self):
|
||||
"""Test redirect URI with same origin as client_id."""
|
||||
assert validate_redirect_uri(
|
||||
"https://example.com/callback",
|
||||
"https://example.com/"
|
||||
) is True
|
||||
|
||||
def test_validate_different_path_same_origin(self):
|
||||
"""Test redirect URI with different path but same origin."""
|
||||
assert validate_redirect_uri(
|
||||
"https://example.com/auth/callback",
|
||||
"https://example.com/"
|
||||
) is True
|
||||
|
||||
def test_validate_subdomain(self):
|
||||
"""Test redirect URI on subdomain of client_id."""
|
||||
assert validate_redirect_uri(
|
||||
"https://app.example.com/callback",
|
||||
"https://example.com/"
|
||||
) is True
|
||||
|
||||
def test_validate_different_domain_fails(self):
|
||||
"""Test redirect URI on completely different domain fails."""
|
||||
assert validate_redirect_uri(
|
||||
"https://evil.com/callback",
|
||||
"https://example.com/"
|
||||
) is False
|
||||
|
||||
def test_validate_localhost_http_allowed(self):
|
||||
"""Test that localhost can use HTTP."""
|
||||
assert validate_redirect_uri(
|
||||
"http://localhost/callback",
|
||||
"https://example.com/"
|
||||
) is True
|
||||
|
||||
def test_validate_127_0_0_1_http_allowed(self):
|
||||
"""Test that 127.0.0.1 can use HTTP."""
|
||||
assert validate_redirect_uri(
|
||||
"http://127.0.0.1:8000/callback",
|
||||
"https://example.com/"
|
||||
) is True
|
||||
|
||||
def test_validate_http_non_localhost_fails(self):
|
||||
"""Test that HTTP on non-localhost fails."""
|
||||
assert validate_redirect_uri(
|
||||
"http://example.com/callback",
|
||||
"https://example.com/"
|
||||
) is False
|
||||
|
||||
def test_validate_malformed_uri_fails(self):
|
||||
"""Test that malformed URI fails gracefully."""
|
||||
assert validate_redirect_uri(
|
||||
"not a url",
|
||||
"https://example.com/"
|
||||
) is False
|
||||
|
||||
|
||||
class TestExtractDomainFromUrl:
|
||||
"""Tests for extract_domain_from_url function."""
|
||||
|
||||
def test_extract_domain_basic(self):
|
||||
"""Test extracting domain from basic URL."""
|
||||
assert extract_domain_from_url("https://example.com/") == "example.com"
|
||||
|
||||
def test_extract_domain_with_path(self):
|
||||
"""Test extracting domain from URL with path."""
|
||||
assert extract_domain_from_url("https://example.com/path/to/page") == "example.com"
|
||||
|
||||
def test_extract_domain_with_port(self):
|
||||
"""Test extracting domain from URL with port."""
|
||||
assert extract_domain_from_url("https://example.com:8443/") == "example.com"
|
||||
|
||||
def test_extract_domain_subdomain(self):
|
||||
"""Test extracting subdomain."""
|
||||
assert extract_domain_from_url("https://blog.example.com/") == "blog.example.com"
|
||||
|
||||
def test_extract_domain_no_hostname_raises_error(self):
|
||||
"""Test that URL without hostname raises ValueError."""
|
||||
with pytest.raises(ValueError, match="URL has no hostname"):
|
||||
extract_domain_from_url("file:///path/to/file")
|
||||
|
||||
def test_extract_domain_invalid_url_raises_error(self):
|
||||
"""Test that invalid URL raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid URL"):
|
||||
extract_domain_from_url("not a url")
|
||||
|
||||
|
||||
class TestValidateEmail:
|
||||
"""Tests for validate_email function."""
|
||||
|
||||
def test_validate_email_basic(self):
|
||||
"""Test validating basic email."""
|
||||
assert validate_email("user@example.com") is True
|
||||
|
||||
def test_validate_email_with_plus(self):
|
||||
"""Test validating email with plus sign."""
|
||||
assert validate_email("user+tag@example.com") is True
|
||||
|
||||
def test_validate_email_with_dots(self):
|
||||
"""Test validating email with dots."""
|
||||
assert validate_email("first.last@example.com") is True
|
||||
|
||||
def test_validate_email_subdomain(self):
|
||||
"""Test validating email with subdomain."""
|
||||
assert validate_email("user@mail.example.com") is True
|
||||
|
||||
def test_validate_email_no_at_sign(self):
|
||||
"""Test that email without @ sign fails."""
|
||||
assert validate_email("notanemail") is False
|
||||
|
||||
def test_validate_email_no_domain(self):
|
||||
"""Test that email without domain fails."""
|
||||
assert validate_email("user@") is False
|
||||
|
||||
def test_validate_email_no_local_part(self):
|
||||
"""Test that email without local part fails."""
|
||||
assert validate_email("@example.com") is False
|
||||
|
||||
def test_validate_email_no_tld(self):
|
||||
"""Test that email without TLD fails."""
|
||||
assert validate_email("user@example") is False
|
||||
|
||||
def test_validate_email_empty_string(self):
|
||||
"""Test that empty string fails."""
|
||||
assert validate_email("") is False
|
||||
Reference in New Issue
Block a user