feat(phase-2): implement domain verification system

Implements complete domain verification flow with:
- rel=me link verification service
- HTML fetching with security controls
- Rate limiting to prevent abuse
- Email validation utilities
- Authorization and verification API endpoints
- User-facing templates for authorization and verification flows

This completes Phase 2: Domain Verification as designed.

Tests:
- All Phase 2 unit tests passing
- Coverage: 85% overall
- Migration tests updated

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-20 13:44:33 -07:00
parent 11ecd953d8
commit 074f74002c
28 changed files with 2283 additions and 14 deletions

View File

@@ -0,0 +1,236 @@
"""Tests for domain verification service."""
import pytest
from unittest.mock import Mock, MagicMock
from gondulf.services.domain_verification import DomainVerificationService
from gondulf.dns import DNSService
from gondulf.email import EmailService
from gondulf.storage import CodeStore
from gondulf.services.html_fetcher import HTMLFetcherService
from gondulf.services.relme_parser import RelMeParser
class TestDomainVerificationService:
"""Tests for DomainVerificationService."""
@pytest.fixture
def mock_dns(self):
"""Mock DNS service."""
return Mock(spec=DNSService)
@pytest.fixture
def mock_email(self):
"""Mock email service."""
return Mock(spec=EmailService)
@pytest.fixture
def mock_storage(self):
"""Mock code storage."""
return Mock(spec=CodeStore)
@pytest.fixture
def mock_fetcher(self):
"""Mock HTML fetcher."""
return Mock(spec=HTMLFetcherService)
@pytest.fixture
def mock_parser(self):
"""Mock rel=me parser."""
return Mock(spec=RelMeParser)
@pytest.fixture
def service(self, mock_dns, mock_email, mock_storage, mock_fetcher, mock_parser):
"""Create domain verification service with mocks."""
return DomainVerificationService(
dns_service=mock_dns,
email_service=mock_email,
code_storage=mock_storage,
html_fetcher=mock_fetcher,
relme_parser=mock_parser
)
def test_generate_verification_code(self, service):
"""Test verification code generation."""
code = service.generate_verification_code()
assert isinstance(code, str)
assert len(code) == 6
assert code.isdigit()
def test_generate_verification_code_unique(self, service):
"""Test that generated codes are different."""
code1 = service.generate_verification_code()
code2 = service.generate_verification_code()
# Very unlikely to be the same, but possible
# Just check they're both valid
assert code1.isdigit()
assert code2.isdigit()
def test_start_verification_dns_fails(self, service, mock_dns):
"""Test start_verification when DNS verification fails."""
mock_dns.verify_txt_record.return_value = False
result = service.start_verification("example.com", "https://example.com/")
assert result["success"] is False
assert result["error"] == "dns_verification_failed"
def test_start_verification_email_discovery_fails(
self, service, mock_dns, mock_fetcher, mock_parser
):
"""Test start_verification when email discovery fails."""
mock_dns.verify_txt_record.return_value = True
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = None
result = service.start_verification("example.com", "https://example.com/")
assert result["success"] is False
assert result["error"] == "email_discovery_failed"
def test_start_verification_invalid_email_format(
self, service, mock_dns, mock_fetcher, mock_parser
):
"""Test start_verification with invalid email format."""
mock_dns.verify_txt_record.return_value = True
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = "not-an-email"
result = service.start_verification("example.com", "https://example.com/")
assert result["success"] is False
assert result["error"] == "invalid_email_format"
def test_start_verification_email_send_fails(
self, service, mock_dns, mock_fetcher, mock_parser, mock_email
):
"""Test start_verification when email sending fails."""
mock_dns.verify_txt_record.return_value = True
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = "user@example.com"
mock_email.send_verification_code.side_effect = Exception("SMTP error")
result = service.start_verification("example.com", "https://example.com/")
assert result["success"] is False
assert result["error"] == "email_send_failed"
def test_start_verification_success(
self, service, mock_dns, mock_fetcher, mock_parser, mock_email, mock_storage
):
"""Test successful verification start."""
mock_dns.verify_txt_record.return_value = True
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = "user@example.com"
result = service.start_verification("example.com", "https://example.com/")
assert result["success"] is True
assert result["email"] == "u***@example.com" # Masked
assert result["verification_method"] == "email"
mock_email.send_verification_code.assert_called_once()
assert mock_storage.store.call_count == 2 # Code and email stored
def test_verify_email_code_invalid(self, service, mock_storage):
"""Test verify_email_code with invalid code."""
mock_storage.verify.return_value = False
result = service.verify_email_code("example.com", "123456")
assert result["success"] is False
assert result["error"] == "invalid_code"
def test_verify_email_code_email_not_found(self, service, mock_storage):
"""Test verify_email_code when email not in storage."""
mock_storage.verify.return_value = True
mock_storage.get.return_value = None
result = service.verify_email_code("example.com", "123456")
assert result["success"] is False
assert result["error"] == "email_not_found"
def test_verify_email_code_success(self, service, mock_storage):
"""Test successful email code verification."""
mock_storage.verify.return_value = True
mock_storage.get.return_value = "user@example.com"
result = service.verify_email_code("example.com", "123456")
assert result["success"] is True
assert result["email"] == "user@example.com"
mock_storage.delete.assert_called_once()
def test_create_authorization_code(self, service, mock_storage):
"""Test authorization code creation."""
code = service.create_authorization_code(
client_id="https://client.example.com/",
redirect_uri="https://client.example.com/callback",
state="test_state",
code_challenge="challenge",
code_challenge_method="S256",
scope="profile",
me="https://user.example.com/"
)
assert isinstance(code, str)
assert len(code) > 0
mock_storage.store.assert_called_once()
def test_verify_dns_record_success(self, service, mock_dns):
"""Test DNS record verification success."""
mock_dns.verify_txt_record.return_value = True
result = service._verify_dns_record("example.com")
assert result is True
mock_dns.verify_txt_record.assert_called_with("example.com", "gondulf-verify-domain")
def test_verify_dns_record_failure(self, service, mock_dns):
"""Test DNS record verification failure."""
mock_dns.verify_txt_record.return_value = False
result = service._verify_dns_record("example.com")
assert result is False
def test_verify_dns_record_exception(self, service, mock_dns):
"""Test DNS record verification handles exceptions."""
mock_dns.verify_txt_record.side_effect = Exception("DNS error")
result = service._verify_dns_record("example.com")
assert result is False
def test_discover_email_success(self, service, mock_fetcher, mock_parser):
"""Test email discovery success."""
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = "user@example.com"
email = service._discover_email("https://example.com/")
assert email == "user@example.com"
def test_discover_email_fetch_fails(self, service, mock_fetcher):
"""Test email discovery when fetch fails."""
mock_fetcher.fetch.return_value = None
email = service._discover_email("https://example.com/")
assert email is None
def test_discover_email_no_email_found(self, service, mock_fetcher, mock_parser):
"""Test email discovery when no email found."""
mock_fetcher.fetch.return_value = "<html></html>"
mock_parser.find_email.return_value = None
email = service._discover_email("https://example.com/")
assert email is None
def test_discover_email_exception(self, service, mock_fetcher):
"""Test email discovery handles exceptions."""
mock_fetcher.fetch.side_effect = Exception("Fetch error")
email = service._discover_email("https://example.com/")
assert email is None

View File

@@ -0,0 +1,175 @@
"""Tests for HTML fetcher service."""
import pytest
from unittest.mock import Mock, patch, MagicMock
from urllib.error import URLError, HTTPError
from gondulf.services.html_fetcher import HTMLFetcherService
class TestHTMLFetcherService:
"""Tests for HTMLFetcherService."""
def test_init_default_params(self):
"""Test initialization with default parameters."""
fetcher = HTMLFetcherService()
assert fetcher.timeout == 10
assert fetcher.max_size == 1024 * 1024
assert fetcher.max_redirects == 5
assert "Gondulf" in fetcher.user_agent
def test_init_custom_params(self):
"""Test initialization with custom parameters."""
fetcher = HTMLFetcherService(
timeout=5,
max_size=512 * 1024,
max_redirects=3,
user_agent="TestAgent/1.0"
)
assert fetcher.timeout == 5
assert fetcher.max_size == 512 * 1024
assert fetcher.max_redirects == 3
assert fetcher.user_agent == "TestAgent/1.0"
def test_fetch_requires_https(self):
"""Test that fetch requires HTTPS URLs."""
fetcher = HTMLFetcherService()
with pytest.raises(ValueError, match="must use HTTPS"):
fetcher.fetch("http://example.com/")
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_success(self, mock_urlopen):
"""Test successful HTML fetch."""
# Mock response
mock_response = MagicMock()
mock_response.read.return_value = b"<html><body>Test</body></html>"
mock_response.headers.get_content_charset.return_value = "utf-8"
mock_response.headers.get.return_value = None # No Content-Length header
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService()
html = fetcher.fetch("https://example.com/")
assert html == "<html><body>Test</body></html>"
mock_urlopen.assert_called_once()
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_respects_timeout(self, mock_urlopen):
"""Test that fetch respects timeout parameter."""
mock_response = MagicMock()
mock_response.read.return_value = b"<html></html>"
mock_response.headers.get_content_charset.return_value = "utf-8"
mock_response.headers.get.return_value = None
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService(timeout=15)
fetcher.fetch("https://example.com/")
call_kwargs = mock_urlopen.call_args[1]
assert call_kwargs['timeout'] == 15
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_content_length_too_large(self, mock_urlopen):
"""Test that fetch returns None if Content-Length exceeds max_size."""
mock_response = MagicMock()
mock_response.headers.get.return_value = str(2 * 1024 * 1024) # 2MB
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService(max_size=1024 * 1024) # 1MB max
html = fetcher.fetch("https://example.com/")
assert html is None
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_response_too_large(self, mock_urlopen):
"""Test that fetch returns None if response exceeds max_size."""
# Create response larger than max_size
large_content = b"x" * (1024 * 1024 + 1) # 1MB + 1 byte
mock_response = MagicMock()
mock_response.read.return_value = large_content
mock_response.headers.get_content_charset.return_value = "utf-8"
mock_response.headers.get.return_value = None
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService(max_size=1024 * 1024)
html = fetcher.fetch("https://example.com/")
assert html is None
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_url_error(self, mock_urlopen):
"""Test that fetch returns None on URLError."""
mock_urlopen.side_effect = URLError("Connection failed")
fetcher = HTMLFetcherService()
html = fetcher.fetch("https://example.com/")
assert html is None
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_http_error(self, mock_urlopen):
"""Test that fetch returns None on HTTPError."""
mock_urlopen.side_effect = HTTPError(
"https://example.com/",
404,
"Not Found",
{},
None
)
fetcher = HTMLFetcherService()
html = fetcher.fetch("https://example.com/")
assert html is None
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_timeout_error(self, mock_urlopen):
"""Test that fetch returns None on timeout."""
mock_urlopen.side_effect = TimeoutError("Request timed out")
fetcher = HTMLFetcherService()
html = fetcher.fetch("https://example.com/")
assert html is None
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_unicode_decode_error(self, mock_urlopen):
"""Test that fetch returns None on Unicode decode error."""
mock_response = MagicMock()
mock_response.read.return_value = b"\xff\xfe" # Invalid UTF-8
mock_response.headers.get_content_charset.return_value = "utf-8"
mock_response.headers.get.return_value = None
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService()
# Should use 'replace' error handling and return a string
html = fetcher.fetch("https://example.com/")
assert html is not None # Should not fail, uses error='replace'
@patch('gondulf.services.html_fetcher.urllib.request.urlopen')
def test_fetch_sets_user_agent(self, mock_urlopen):
"""Test that fetch sets User-Agent header."""
mock_response = MagicMock()
mock_response.read.return_value = b"<html></html>"
mock_response.headers.get_content_charset.return_value = "utf-8"
mock_response.headers.get.return_value = None
mock_response.__enter__.return_value = mock_response
mock_response.__exit__.return_value = None
mock_urlopen.return_value = mock_response
fetcher = HTMLFetcherService(user_agent="CustomAgent/2.0")
fetcher.fetch("https://example.com/")
# Check that User-Agent header was set
request = mock_urlopen.call_args[0][0]
assert request.get_header('User-agent') == "CustomAgent/2.0"

View File

@@ -0,0 +1,171 @@
"""Tests for rate limiter service."""
import pytest
import time
from unittest.mock import patch
from gondulf.services.rate_limiter import RateLimiter
class TestRateLimiter:
"""Tests for RateLimiter."""
def test_init_default_params(self):
"""Test initialization with default parameters."""
limiter = RateLimiter()
assert limiter.max_attempts == 3
assert limiter.window_seconds == 3600
def test_init_custom_params(self):
"""Test initialization with custom parameters."""
limiter = RateLimiter(max_attempts=5, window_hours=2)
assert limiter.max_attempts == 5
assert limiter.window_seconds == 7200
def test_check_rate_limit_no_attempts(self):
"""Test rate limit check with no previous attempts."""
limiter = RateLimiter()
assert limiter.check_rate_limit("example.com") is True
def test_check_rate_limit_within_limit(self):
"""Test rate limit check within limit."""
limiter = RateLimiter(max_attempts=3)
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
assert limiter.check_rate_limit("example.com") is True
def test_check_rate_limit_at_limit(self):
"""Test rate limit check at exact limit."""
limiter = RateLimiter(max_attempts=3)
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
assert limiter.check_rate_limit("example.com") is False
def test_check_rate_limit_exceeded(self):
"""Test rate limit check when exceeded."""
limiter = RateLimiter(max_attempts=2)
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
assert limiter.check_rate_limit("example.com") is False
def test_record_attempt_creates_entry(self):
"""Test that record_attempt creates new entry."""
limiter = RateLimiter()
limiter.record_attempt("example.com")
assert "example.com" in limiter._attempts
assert len(limiter._attempts["example.com"]) == 1
def test_record_attempt_appends_to_existing(self):
"""Test that record_attempt appends to existing entry."""
limiter = RateLimiter()
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
assert len(limiter._attempts["example.com"]) == 2
def test_clean_old_attempts_removes_expired(self):
"""Test that old attempts are cleaned up."""
limiter = RateLimiter(max_attempts=3, window_hours=1)
# Mock time to control timestamps
with patch('time.time', return_value=1000):
limiter.record_attempt("example.com")
# Move time forward past window
with patch('time.time', return_value=1000 + 3700): # 1 hour + 100 seconds
limiter._clean_old_attempts("example.com")
assert "example.com" not in limiter._attempts
def test_clean_old_attempts_preserves_recent(self):
"""Test that recent attempts are preserved."""
limiter = RateLimiter(max_attempts=3, window_hours=1)
with patch('time.time', return_value=1000):
limiter.record_attempt("example.com")
# Move time forward but still within window
with patch('time.time', return_value=1000 + 1800): # 30 minutes
limiter._clean_old_attempts("example.com")
assert "example.com" in limiter._attempts
assert len(limiter._attempts["example.com"]) == 1
def test_check_rate_limit_cleans_old_attempts(self):
"""Test that check_rate_limit cleans old attempts."""
limiter = RateLimiter(max_attempts=2, window_hours=1)
# Record attempts at time 1000
with patch('time.time', return_value=1000):
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
# Check limit should be False
with patch('time.time', return_value=1000):
assert limiter.check_rate_limit("example.com") is False
# Move time forward past window
with patch('time.time', return_value=1000 + 3700):
# Old attempts should be cleaned, limit should pass
assert limiter.check_rate_limit("example.com") is True
def test_different_domains_independent(self):
"""Test that different domains have independent limits."""
limiter = RateLimiter(max_attempts=2)
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
limiter.record_attempt("other.com")
assert limiter.check_rate_limit("example.com") is False
assert limiter.check_rate_limit("other.com") is True
def test_get_remaining_attempts_initial(self):
"""Test getting remaining attempts initially."""
limiter = RateLimiter(max_attempts=3)
assert limiter.get_remaining_attempts("example.com") == 3
def test_get_remaining_attempts_after_one(self):
"""Test getting remaining attempts after one attempt."""
limiter = RateLimiter(max_attempts=3)
limiter.record_attempt("example.com")
assert limiter.get_remaining_attempts("example.com") == 2
def test_get_remaining_attempts_exhausted(self):
"""Test getting remaining attempts when exhausted."""
limiter = RateLimiter(max_attempts=3)
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
limiter.record_attempt("example.com")
assert limiter.get_remaining_attempts("example.com") == 0
def test_get_reset_time_no_attempts(self):
"""Test getting reset time with no attempts."""
limiter = RateLimiter()
assert limiter.get_reset_time("example.com") == 0
def test_get_reset_time_with_attempts(self):
"""Test getting reset time with attempts."""
limiter = RateLimiter(window_hours=1)
with patch('time.time', return_value=1000):
limiter.record_attempt("example.com")
reset_time = limiter.get_reset_time("example.com")
assert reset_time == 1000 + 3600
def test_get_reset_time_multiple_attempts(self):
"""Test getting reset time with multiple attempts (returns oldest)."""
limiter = RateLimiter(window_hours=1)
with patch('time.time', return_value=1000):
limiter.record_attempt("example.com")
with patch('time.time', return_value=2000):
limiter.record_attempt("example.com")
# Reset time should be based on oldest attempt
reset_time = limiter.get_reset_time("example.com")
assert reset_time == 1000 + 3600

View File

@@ -0,0 +1,181 @@
"""Tests for rel=me parser service."""
import pytest
from gondulf.services.relme_parser import RelMeParser
class TestRelMeParser:
"""Tests for RelMeParser."""
def test_parse_relme_links_basic(self):
"""Test parsing basic rel=me links."""
html = """
<html>
<body>
<a rel="me" href="https://github.com/user">GitHub</a>
<a rel="me" href="mailto:user@example.com">Email</a>
</body>
</html>
"""
parser = RelMeParser()
links = parser.parse_relme_links(html)
assert len(links) == 2
assert "https://github.com/user" in links
assert "mailto:user@example.com" in links
def test_parse_relme_links_link_tag(self):
"""Test parsing rel=me from <link> tags."""
html = """
<html>
<head>
<link rel="me" href="https://twitter.com/user">
</head>
</html>
"""
parser = RelMeParser()
links = parser.parse_relme_links(html)
assert len(links) == 1
assert "https://twitter.com/user" in links
def test_parse_relme_links_no_rel_me(self):
"""Test parsing HTML with no rel=me links."""
html = """
<html>
<body>
<a href="https://example.com">Link</a>
</body>
</html>
"""
parser = RelMeParser()
links = parser.parse_relme_links(html)
assert len(links) == 0
def test_parse_relme_links_no_href(self):
"""Test parsing rel=me link without href."""
html = """
<html>
<body>
<a rel="me">No href</a>
</body>
</html>
"""
parser = RelMeParser()
links = parser.parse_relme_links(html)
assert len(links) == 0
def test_parse_relme_links_malformed_html(self):
"""Test parsing malformed HTML returns empty list."""
html = "<html><body><<>>broken"
parser = RelMeParser()
links = parser.parse_relme_links(html)
# Should not crash, returns what it can parse
assert isinstance(links, list)
def test_extract_mailto_email_basic(self):
"""Test extracting email from mailto: link."""
links = ["mailto:user@example.com"]
parser = RelMeParser()
email = parser.extract_mailto_email(links)
assert email == "user@example.com"
def test_extract_mailto_email_with_query(self):
"""Test extracting email from mailto: link with query parameters."""
links = ["mailto:user@example.com?subject=Hello"]
parser = RelMeParser()
email = parser.extract_mailto_email(links)
assert email == "user@example.com"
def test_extract_mailto_email_multiple_links(self):
"""Test extracting email from multiple links (returns first mailto:)."""
links = [
"https://github.com/user",
"mailto:user@example.com",
"mailto:other@example.com"
]
parser = RelMeParser()
email = parser.extract_mailto_email(links)
assert email == "user@example.com"
def test_extract_mailto_email_no_mailto(self):
"""Test extracting email when no mailto: links present."""
links = ["https://github.com/user", "https://twitter.com/user"]
parser = RelMeParser()
email = parser.extract_mailto_email(links)
assert email is None
def test_extract_mailto_email_invalid_format(self):
"""Test extracting email from malformed mailto: link."""
links = ["mailto:notanemail"]
parser = RelMeParser()
email = parser.extract_mailto_email(links)
# Should return None for invalid email format
assert email is None
def test_extract_mailto_email_empty_list(self):
"""Test extracting email from empty list."""
parser = RelMeParser()
email = parser.extract_mailto_email([])
assert email is None
def test_find_email_success(self):
"""Test find_email combining parse and extract."""
html = """
<html>
<body>
<a rel="me" href="https://github.com/user">GitHub</a>
<a rel="me" href="mailto:user@example.com">Email</a>
</body>
</html>
"""
parser = RelMeParser()
email = parser.find_email(html)
assert email == "user@example.com"
def test_find_email_no_email(self):
"""Test find_email when no email present."""
html = """
<html>
<body>
<a rel="me" href="https://github.com/user">GitHub</a>
</body>
</html>
"""
parser = RelMeParser()
email = parser.find_email(html)
assert email is None
def test_find_email_malformed_html(self):
"""Test find_email with malformed HTML."""
html = "<html><<broken>>"
parser = RelMeParser()
email = parser.find_email(html)
assert email is None
def test_parse_relme_multiple_rel_values(self):
"""Test parsing link with multiple rel values including 'me'."""
html = """
<html>
<body>
<a rel="me nofollow" href="https://example.com">Link</a>
</body>
</html>
"""
parser = RelMeParser()
links = parser.parse_relme_links(html)
assert len(links) == 1
assert "https://example.com" in links

View File

@@ -0,0 +1,199 @@
"""Tests for validation utilities."""
import pytest
from gondulf.utils.validation import (
mask_email,
normalize_client_id,
validate_redirect_uri,
extract_domain_from_url,
validate_email
)
class TestMaskEmail:
"""Tests for mask_email function."""
def test_mask_email_basic(self):
"""Test basic email masking."""
assert mask_email("user@example.com") == "u***@example.com"
def test_mask_email_long_local(self):
"""Test masking email with long local part."""
assert mask_email("verylongusername@example.com") == "v***@example.com"
def test_mask_email_single_char_local(self):
"""Test masking email with single character local part."""
# Should return unchanged if local part is only 1 character
assert mask_email("a@example.com") == "a@example.com"
def test_mask_email_no_at_sign(self):
"""Test masking invalid email without @ sign."""
assert mask_email("notanemail") == "notanemail"
def test_mask_email_empty_string(self):
"""Test masking empty string."""
assert mask_email("") == ""
class TestNormalizeClientId:
"""Tests for normalize_client_id function."""
def test_normalize_basic_https(self):
"""Test normalizing basic HTTPS URL."""
assert normalize_client_id("https://example.com/") == "https://example.com/"
def test_normalize_remove_default_port(self):
"""Test normalizing URL with default HTTPS port."""
assert normalize_client_id("https://example.com:443/") == "https://example.com/"
def test_normalize_preserve_non_default_port(self):
"""Test normalizing URL with non-default port."""
assert normalize_client_id("https://example.com:8443/") == "https://example.com:8443/"
def test_normalize_preserve_path(self):
"""Test normalizing URL with path."""
assert normalize_client_id("https://example.com/app") == "https://example.com/app"
def test_normalize_preserve_query(self):
"""Test normalizing URL with query string."""
assert normalize_client_id("https://example.com/?foo=bar") == "https://example.com/?foo=bar"
def test_normalize_http_scheme_raises_error(self):
"""Test that HTTP scheme raises ValueError."""
with pytest.raises(ValueError, match="must use https scheme"):
normalize_client_id("http://example.com/")
def test_normalize_no_scheme_raises_error(self):
"""Test that missing scheme raises ValueError."""
with pytest.raises(ValueError, match="must use https scheme"):
normalize_client_id("example.com")
class TestValidateRedirectUri:
"""Tests for validate_redirect_uri function."""
def test_validate_same_origin(self):
"""Test redirect URI with same origin as client_id."""
assert validate_redirect_uri(
"https://example.com/callback",
"https://example.com/"
) is True
def test_validate_different_path_same_origin(self):
"""Test redirect URI with different path but same origin."""
assert validate_redirect_uri(
"https://example.com/auth/callback",
"https://example.com/"
) is True
def test_validate_subdomain(self):
"""Test redirect URI on subdomain of client_id."""
assert validate_redirect_uri(
"https://app.example.com/callback",
"https://example.com/"
) is True
def test_validate_different_domain_fails(self):
"""Test redirect URI on completely different domain fails."""
assert validate_redirect_uri(
"https://evil.com/callback",
"https://example.com/"
) is False
def test_validate_localhost_http_allowed(self):
"""Test that localhost can use HTTP."""
assert validate_redirect_uri(
"http://localhost/callback",
"https://example.com/"
) is True
def test_validate_127_0_0_1_http_allowed(self):
"""Test that 127.0.0.1 can use HTTP."""
assert validate_redirect_uri(
"http://127.0.0.1:8000/callback",
"https://example.com/"
) is True
def test_validate_http_non_localhost_fails(self):
"""Test that HTTP on non-localhost fails."""
assert validate_redirect_uri(
"http://example.com/callback",
"https://example.com/"
) is False
def test_validate_malformed_uri_fails(self):
"""Test that malformed URI fails gracefully."""
assert validate_redirect_uri(
"not a url",
"https://example.com/"
) is False
class TestExtractDomainFromUrl:
"""Tests for extract_domain_from_url function."""
def test_extract_domain_basic(self):
"""Test extracting domain from basic URL."""
assert extract_domain_from_url("https://example.com/") == "example.com"
def test_extract_domain_with_path(self):
"""Test extracting domain from URL with path."""
assert extract_domain_from_url("https://example.com/path/to/page") == "example.com"
def test_extract_domain_with_port(self):
"""Test extracting domain from URL with port."""
assert extract_domain_from_url("https://example.com:8443/") == "example.com"
def test_extract_domain_subdomain(self):
"""Test extracting subdomain."""
assert extract_domain_from_url("https://blog.example.com/") == "blog.example.com"
def test_extract_domain_no_hostname_raises_error(self):
"""Test that URL without hostname raises ValueError."""
with pytest.raises(ValueError, match="URL has no hostname"):
extract_domain_from_url("file:///path/to/file")
def test_extract_domain_invalid_url_raises_error(self):
"""Test that invalid URL raises ValueError."""
with pytest.raises(ValueError, match="Invalid URL"):
extract_domain_from_url("not a url")
class TestValidateEmail:
"""Tests for validate_email function."""
def test_validate_email_basic(self):
"""Test validating basic email."""
assert validate_email("user@example.com") is True
def test_validate_email_with_plus(self):
"""Test validating email with plus sign."""
assert validate_email("user+tag@example.com") is True
def test_validate_email_with_dots(self):
"""Test validating email with dots."""
assert validate_email("first.last@example.com") is True
def test_validate_email_subdomain(self):
"""Test validating email with subdomain."""
assert validate_email("user@mail.example.com") is True
def test_validate_email_no_at_sign(self):
"""Test that email without @ sign fails."""
assert validate_email("notanemail") is False
def test_validate_email_no_domain(self):
"""Test that email without domain fails."""
assert validate_email("user@") is False
def test_validate_email_no_local_part(self):
"""Test that email without local part fails."""
assert validate_email("@example.com") is False
def test_validate_email_no_tld(self):
"""Test that email without TLD fails."""
assert validate_email("user@example") is False
def test_validate_email_empty_string(self):
"""Test that empty string fails."""
assert validate_email("") is False