Add comprehensive RSS scraper implementation with security and testing

- Modular architecture with separate modules for scraping, parsing, security, validation, and caching
- Comprehensive security measures including HTML sanitization, rate limiting, and input validation
- Robust error handling with custom exceptions and retry logic
- HTTP caching with ETags and Last-Modified headers for efficiency
- Pre-compiled regex patterns for improved performance
- Comprehensive test suite with 66 tests covering all major functionality
- Docker support for containerized deployment
- Configuration management with environment variable support
- Working parser that successfully extracts 32 articles from Warhammer Community

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-06-06 09:15:06 -06:00
parent e0647325ff
commit 25086fc01b
26 changed files with 15226 additions and 280 deletions

1
src/__init__.py Normal file
View File

@@ -0,0 +1 @@
# RSS Scraper package

View File

@@ -0,0 +1,5 @@
"""RSS Scraper for Warhammer Community website."""
__version__ = "1.0.0"
__author__ = "RSS Scraper"
__description__ = "A production-ready RSS scraper for Warhammer Community website"

216
src/rss_scraper/cache.py Normal file
View File

@@ -0,0 +1,216 @@
"""Caching utilities for avoiding redundant scraping."""
import os
import json
import hashlib
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
import requests
from .config import Config
from .exceptions import FileOperationError
logger = logging.getLogger(__name__)
class ContentCache:
    """Cache for storing and retrieving scraped content.

    Page HTML and extracted article lists are stored as JSON files under
    ``cache_dir``, alongside HTTP ETags used for conditional requests.
    Entries expire after ``max_cache_age_hours`` and are pruned lazily
    when read. Pruning on the read path is best-effort: a failure to
    rewrite the cache file must not break callers that only wanted to
    read (previously `get_cached_*` could raise FileOperationError from
    `_save_cache` while merely evicting an expired entry).
    """

    def __init__(self, cache_dir: str = "cache"):
        """Initialize cache file paths under cache_dir (created if missing)."""
        self.cache_dir = cache_dir
        self.cache_file = os.path.join(cache_dir, "content_cache.json")
        self.etag_file = os.path.join(cache_dir, "etags.json")
        self.max_cache_age_hours = 24  # Cache expires after 24 hours

        # Ensure cache directory exists
        os.makedirs(cache_dir, exist_ok=True)

    def _get_cache_key(self, url: str) -> str:
        """Generate a stable cache key (SHA-256 hex digest) from a URL."""
        return hashlib.sha256(url.encode()).hexdigest()

    def _load_cache(self) -> Dict[str, Any]:
        """Load the content cache from disk; return {} on any failure."""
        try:
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'r') as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load cache: {e}")
        return {}

    def _save_cache(self, cache_data: Dict[str, Any]) -> None:
        """Persist the content cache to disk.

        Raises:
            FileOperationError: if the cache file cannot be written.
        """
        try:
            with open(self.cache_file, 'w') as f:
                # default=str serializes datetimes that slipped through.
                json.dump(cache_data, f, indent=2, default=str)
        except Exception as e:
            logger.error(f"Failed to save cache: {e}")
            raise FileOperationError(f"Failed to save cache: {e}")

    def _prune_expired_entry(self, cache_data: Dict[str, Any], cache_key: str) -> None:
        """Drop an expired entry and persist the cache, best-effort.

        Called from read paths, so write failures are logged instead of
        raised — the caller only wanted to read.
        """
        del cache_data[cache_key]
        try:
            self._save_cache(cache_data)
        except FileOperationError as e:
            logger.warning(f"Could not persist cache after pruning expired entry: {e}")

    def _load_etags(self) -> Dict[str, str]:
        """Load stored ETags from disk; return {} on any failure."""
        try:
            if os.path.exists(self.etag_file):
                with open(self.etag_file, 'r') as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load ETags: {e}")
        return {}

    def _save_etags(self, etag_data: Dict[str, str]) -> None:
        """Persist ETags to disk; ETag loss is non-fatal, so only warn."""
        try:
            with open(self.etag_file, 'w') as f:
                json.dump(etag_data, f, indent=2)
        except Exception as e:
            logger.warning(f"Failed to save ETags: {e}")

    def _is_cache_valid(self, cached_entry: Dict[str, Any]) -> bool:
        """Return True if the entry's timestamp is within the max age."""
        try:
            cached_time = datetime.fromisoformat(cached_entry['timestamp'])
            expiry_time = cached_time + timedelta(hours=self.max_cache_age_hours)
            return datetime.now() < expiry_time
        except (KeyError, ValueError):
            # Missing or malformed timestamp: treat as expired.
            return False

    def check_if_content_changed(self, url: str) -> Optional[Dict[str, str]]:
        """Check if remote content changed using a conditional HEAD request.

        Returns {'status': 'not_modified'} on HTTP 304, otherwise
        {'status': 'modified', ...}. Network failures are treated as
        'modified' so the scraper falls back to a full fetch.
        """
        etags = self._load_etags()
        cache_key = self._get_cache_key(url)

        headers = {}
        if cache_key in etags:
            headers['If-None-Match'] = etags[cache_key]

        try:
            logger.debug(f"Checking if content changed for {url}")
            response = requests.head(url, headers=headers, timeout=10)

            # 304 means not modified
            if response.status_code == 304:
                logger.info(f"Content not modified for {url}")
                return {'status': 'not_modified'}

            # Update ETag if available so the next check can use it.
            if 'etag' in response.headers:
                etags[cache_key] = response.headers['etag']
                self._save_etags(etags)
                logger.debug(f"Updated ETag for {url}")

            return {'status': 'modified', 'etag': response.headers.get('etag')}
        except requests.RequestException as e:
            logger.warning(f"Failed to check content modification for {url}: {e}")
            # If we can't check, assume it's modified
            return {'status': 'modified'}

    def get_cached_content(self, url: str) -> Optional[str]:
        """Get cached HTML content if available and not expired."""
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url)

        if cache_key not in cache_data:
            logger.debug(f"No cached content for {url}")
            return None

        cached_entry = cache_data[cache_key]
        if not self._is_cache_valid(cached_entry):
            logger.debug(f"Cached content for {url} has expired")
            self._prune_expired_entry(cache_data, cache_key)
            return None

        logger.info(f"Using cached content for {url}")
        return cached_entry['content']

    def cache_content(self, url: str, content: str) -> None:
        """Cache HTML content with the current timestamp.

        Raises:
            FileOperationError: if the cache file cannot be written.
        """
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url)

        cache_data[cache_key] = {
            'url': url,
            'content': content,
            'timestamp': datetime.now().isoformat(),
            'size': len(content)
        }
        self._save_cache(cache_data)
        logger.info(f"Cached content for {url} ({len(content)} bytes)")

    def get_cached_articles(self, url: str) -> Optional[List[Dict[str, Any]]]:
        """Get cached extracted articles if available and not expired."""
        cache_data = self._load_cache()
        # Articles live under a distinct key derived from the URL key.
        cache_key = self._get_cache_key(url) + "_articles"

        if cache_key not in cache_data:
            return None

        cached_entry = cache_data[cache_key]
        if not self._is_cache_valid(cached_entry):
            self._prune_expired_entry(cache_data, cache_key)
            return None

        logger.info(f"Using cached articles for {url}")
        return cached_entry['articles']

    def cache_articles(self, url: str, articles: List[Dict[str, Any]]) -> None:
        """Cache extracted articles, converting datetimes to ISO strings.

        Raises:
            FileOperationError: if the cache file cannot be written.
        """
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url) + "_articles"

        # Convert datetime objects to strings for JSON serialization.
        serializable_articles = []
        for article in articles:
            serializable_article = article.copy()
            if 'date' in serializable_article and hasattr(serializable_article['date'], 'isoformat'):
                serializable_article['date'] = serializable_article['date'].isoformat()
            serializable_articles.append(serializable_article)

        cache_data[cache_key] = {
            'url': url,
            'articles': serializable_articles,
            'timestamp': datetime.now().isoformat(),
            'count': len(articles)
        }
        self._save_cache(cache_data)
        logger.info(f"Cached {len(articles)} articles for {url}")

    def clear_cache(self) -> None:
        """Delete the cache and ETag files.

        Raises:
            FileOperationError: if a file cannot be removed.
        """
        try:
            if os.path.exists(self.cache_file):
                os.remove(self.cache_file)
            if os.path.exists(self.etag_file):
                os.remove(self.etag_file)
            logger.info("Cache cleared successfully")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")
            raise FileOperationError(f"Failed to clear cache: {e}")

    def get_cache_info(self) -> Dict[str, Any]:
        """Return entry counts, file paths and on-disk cache size."""
        cache_data = self._load_cache()
        etags = self._load_etags()

        info = {
            'cache_file': self.cache_file,
            'etag_file': self.etag_file,
            'cache_entries': len(cache_data),
            'etag_entries': len(etags),
            'cache_size_bytes': 0
        }
        if os.path.exists(self.cache_file):
            info['cache_size_bytes'] = os.path.getsize(self.cache_file)
        return info

77
src/rss_scraper/config.py Normal file
View File

@@ -0,0 +1,77 @@
"""Configuration management for RSS Warhammer scraper."""
import os
from typing import List, Optional
class Config:
    """Configuration class for RSS scraper settings.

    Values are class attributes read once at import time; most can be
    overridden through environment variables of the same name.
    """

    # Security settings: only these domains may be scraped by default.
    ALLOWED_DOMAINS: List[str] = [
        'warhammer-community.com',
        'www.warhammer-community.com'
    ]

    # Scraping limits
    MAX_SCROLL_ITERATIONS: int = int(os.getenv('MAX_SCROLL_ITERATIONS', '5'))
    MAX_CONTENT_SIZE: int = int(os.getenv('MAX_CONTENT_SIZE', str(10 * 1024 * 1024)))  # 10MB
    MAX_TITLE_LENGTH: int = int(os.getenv('MAX_TITLE_LENGTH', '500'))

    # Timing settings
    SCROLL_DELAY_SECONDS: float = float(os.getenv('SCROLL_DELAY_SECONDS', '2.0'))
    PAGE_TIMEOUT_MS: int = int(os.getenv('PAGE_TIMEOUT_MS', '120000'))

    # Default URLs and paths
    DEFAULT_URL: str = os.getenv('DEFAULT_URL', 'https://www.warhammer-community.com/en-gb/')
    DEFAULT_OUTPUT_DIR: str = os.getenv('DEFAULT_OUTPUT_DIR', '.')

    # File names
    RSS_FILENAME: str = os.getenv('RSS_FILENAME', 'warhammer_rss_feed.xml')
    DEBUG_HTML_FILENAME: str = os.getenv('DEBUG_HTML_FILENAME', 'page.html')

    # Feed metadata
    FEED_TITLE: str = os.getenv('FEED_TITLE', 'Warhammer Community RSS Feed')
    FEED_DESCRIPTION: str = os.getenv('FEED_DESCRIPTION', 'Latest Warhammer Community Articles')

    # Security patterns to remove from content
    DANGEROUS_PATTERNS: List[str] = [
        '<script', '</script', 'javascript:', 'data:', 'vbscript:'
    ]

    # CSS selectors for article parsing
    TITLE_SELECTORS: List[str] = [
        'h3.newsCard-title-sm',
        'h3.newsCard-title-lg'
    ]

    @classmethod
    def get_output_dir(cls, override: Optional[str] = None) -> str:
        """Get output directory with optional override."""
        return override or cls.DEFAULT_OUTPUT_DIR

    @classmethod
    def get_allowed_domains(cls) -> List[str]:
        """Get list of allowed domains for scraping.

        Reads the ALLOWED_DOMAINS environment variable (comma-separated)
        when set. Blank items produced by stray commas or whitespace are
        dropped so an entry like "a.com,," cannot leave empty strings in
        the whitelist.
        """
        env_domains = os.getenv('ALLOWED_DOMAINS')
        if env_domains:
            domains = [domain.strip() for domain in env_domains.split(',')]
            return [domain for domain in domains if domain]
        return cls.ALLOWED_DOMAINS

    @classmethod
    def validate_config(cls) -> None:
        """Validate configuration values.

        Raises:
            ValueError: if any setting is outside its legal range.
        """
        if cls.MAX_SCROLL_ITERATIONS < 0:
            raise ValueError("MAX_SCROLL_ITERATIONS must be non-negative")
        if cls.MAX_CONTENT_SIZE <= 0:
            raise ValueError("MAX_CONTENT_SIZE must be positive")
        if cls.MAX_TITLE_LENGTH <= 0:
            raise ValueError("MAX_TITLE_LENGTH must be positive")
        if cls.SCROLL_DELAY_SECONDS < 0:
            raise ValueError("SCROLL_DELAY_SECONDS must be non-negative")
        if cls.PAGE_TIMEOUT_MS <= 0:
            raise ValueError("PAGE_TIMEOUT_MS must be positive")
        if not cls.DEFAULT_URL.startswith(('http://', 'https://')):
            raise ValueError("DEFAULT_URL must be a valid HTTP/HTTPS URL")
        if not cls.get_allowed_domains():
            raise ValueError("ALLOWED_DOMAINS cannot be empty")

View File

@@ -0,0 +1,41 @@
"""Custom exceptions for the RSS scraper."""
class ScrapingError(Exception):
    """Base exception for all scraper errors."""


class ValidationError(ScrapingError):
    """Raised when URL or path validation fails."""


class NetworkError(ScrapingError):
    """Raised for network-related failures."""


class PageLoadError(NetworkError):
    """Raised when a page fails to load properly."""


class ContentSizeError(ScrapingError):
    """Raised when fetched content exceeds configured size limits."""


class ParseError(ScrapingError):
    """Raised when HTML parsing fails."""


class ConfigurationError(ScrapingError):
    """Raised for configuration-related problems."""


class FileOperationError(ScrapingError):
    """Raised when a file read/write/delete operation fails."""

111
src/rss_scraper/parser.py Normal file
View File

@@ -0,0 +1,111 @@
"""HTML parsing and article extraction functionality."""
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional
import pytz
from bs4 import BeautifulSoup
from .config import Config
from .validation import validate_link
from .exceptions import ParseError
from .security import sanitize_text_content, sanitize_html_content
logger = logging.getLogger(__name__)
def sanitize_text(text: Optional[str]) -> str:
    """Sanitize text content to prevent injection attacks.

    Thin wrapper kept for backward compatibility; delegates to
    security.sanitize_text_content.
    """
    return sanitize_text_content(text)
def _find_title_tag(article):
    """Return the first <h3> matching a configured title selector, or None."""
    for selector in Config.TITLE_SELECTORS:
        # Selectors look like 'h3.className'; keep only the class part.
        class_name = selector.split('.')[1] if '.' in selector else selector
        title_tag = article.find('h3', class_=class_name)
        if title_tag:
            return title_tag
    return None


def _extract_article_date(article, timezone):
    """Extract a timezone-aware publication date from an article element.

    <time> blocks containing "min" are reading-time badges and are
    skipped. Known formats are tried in order ("05 Jun 25", then
    "Jun 05, 2025"). Returns None when no date could be parsed.
    """
    for time_tag in article.find_all('time'):
        raw_date = time_tag.text.strip()
        if "min" in raw_date.lower():
            continue  # Ignore "min" time blocks (reading time)
        for date_format in ('%d %b %y', '%b %d, %Y'):
            try:
                return timezone.localize(datetime.strptime(raw_date, date_format))
            except ValueError:
                continue
    return None


def extract_articles_from_html(html: str, base_url: str) -> List[Dict[str, Any]]:
    """Extract articles from HTML content.

    The HTML is sanitized before parsing. Articles are <article>
    elements carrying a 'shared-*' class; entries with duplicate or
    invalid links are dropped and results are sorted newest-first.

    Raises:
        ParseError: if BeautifulSoup cannot parse the sanitized HTML.
    """
    logger.info("Parsing HTML content with BeautifulSoup")

    # Sanitize HTML content first for security
    sanitized_html = sanitize_html_content(html)

    try:
        soup = BeautifulSoup(sanitized_html, 'html.parser')
    except Exception as e:
        raise ParseError(f"Failed to parse HTML content: {e}")

    # Define a timezone (UTC in this case)
    timezone = pytz.UTC

    # Articles of every type carry a 'shared-' prefixed class.
    article_elements = [
        article for article in soup.find_all('article')
        if any('shared-' in cls for cls in (article.get('class') or []))
    ]
    logger.info(f"Found {len(article_elements)} article elements on page")

    articles: List[Dict[str, Any]] = []
    seen_urls: set = set()  # Track seen URLs to avoid duplicates.

    for article in article_elements:
        # Extract and sanitize the title.
        title_tag = _find_title_tag(article)
        title = sanitize_text(title_tag.text.strip() if title_tag else 'No title')

        # Extract and validate the link - prefer the btn-cover anchor.
        link_tag = article.find('a', class_='btn-cover', href=True) or article.find('a', href=True)
        link = validate_link(link_tag['href'] if link_tag else None, base_url)

        # Skip this entry if the link is None or the URL was already seen.
        if not link or link in seen_urls:
            logger.debug(f"Skipping duplicate or invalid article: {title}")
            continue
        seen_urls.add(link)

        logger.debug(f"Processing article: {title[:50]}...")

        # Fall back to "now" when no publication date could be parsed.
        date = _extract_article_date(article, timezone) or datetime.now(timezone)

        articles.append({
            'title': title,
            'link': link,
            'date': date
        })

    # Sort the articles by publication date (newest first)
    articles.sort(key=lambda x: x['date'], reverse=True)
    logger.info(f"Successfully extracted {len(articles)} unique articles")
    return articles

View File

@@ -0,0 +1,124 @@
"""Retry utilities with exponential backoff for network operations."""
import time
import logging
from typing import Any, Callable, Optional, Type, Union, Tuple
from functools import wraps
logger = logging.getLogger(__name__)
class RetryConfig:
    """Configuration for retry behavior.

    Attributes mirror the constructor arguments and are consumed by
    calculate_delay / retry_on_exception.
    """

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        jitter: bool = True
    ):
        # Settings are stored verbatim; validation is left to callers.
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor
        self.jitter = jitter
def calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Calculate delay for a retry attempt with exponential backoff.

    The delay grows geometrically with the attempt number, is capped at
    config.max_delay, and optionally receives +/-10% random jitter to
    avoid thundering-herd retries. Never returns a negative value.
    """
    capped = min(config.base_delay * (config.backoff_factor ** (attempt - 1)),
                 config.max_delay)
    if not config.jitter:
        return max(0, capped)
    import random  # local import: only needed when jitter is enabled
    spread = capped * 0.1
    return max(0, capped + random.uniform(-spread, spread))
def retry_on_exception(
    exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
    config: Optional[RetryConfig] = None
) -> Callable:
    """Decorator to retry function calls on specific exceptions.

    Only the listed exception types trigger a retry; any other exception
    propagates immediately. Between attempts the wrapper sleeps for an
    exponential-backoff delay computed by calculate_delay.

    Args:
        exceptions: Exception type(s) to retry on
        config: Retry configuration, uses default if None

    Returns:
        Decorated function with retry logic
    """
    if config is None:
        config = RetryConfig()
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            last_exception = None
            for attempt in range(1, config.max_attempts + 1):
                try:
                    result = func(*args, **kwargs)
                    # Log recoveries so intermittent failures stay visible.
                    if attempt > 1:
                        logger.info(f"{func.__name__} succeeded on attempt {attempt}")
                    return result
                except exceptions as e:
                    last_exception = e
                    # Out of attempts: surface the final error unchanged.
                    if attempt == config.max_attempts:
                        logger.error(
                            f"{func.__name__} failed after {config.max_attempts} attempts. "
                            f"Final error: {e}"
                        )
                        raise
                    delay = calculate_delay(attempt, config)
                    logger.warning(
                        f"{func.__name__} attempt {attempt} failed: {e}. "
                        f"Retrying in {delay:.2f} seconds..."
                    )
                    time.sleep(delay)
                except Exception as e:
                    # Don't retry on unexpected exceptions
                    logger.error(f"{func.__name__} failed with unexpected error: {e}")
                    raise
            # This should never be reached, but just in case
            if last_exception:
                raise last_exception
        return wrapper
    return decorator
# Common retry configurations for different scenarios
# Plain network (HTTP) operations: a few attempts with generous backoff
# and jitter to avoid synchronized retries.
NETWORK_RETRY_CONFIG = RetryConfig(
    max_attempts=3,
    base_delay=1.0,
    max_delay=30.0,
    backoff_factor=2.0,
    jitter=True
)
# Playwright page loads are expensive: retry only once, deterministic delay.
PLAYWRIGHT_RETRY_CONFIG = RetryConfig(
    max_attempts=2,
    base_delay=2.0,
    max_delay=10.0,
    backoff_factor=2.0,
    jitter=False
)
# Local file operations: quick retries with mild backoff.
FILE_RETRY_CONFIG = RetryConfig(
    max_attempts=3,
    base_delay=0.5,
    max_delay=5.0,
    backoff_factor=1.5,
    jitter=False
)

View File

@@ -0,0 +1,59 @@
"""RSS feed generation functionality."""
import os
import logging
from typing import List, Dict, Any
from feedgen.feed import FeedGenerator
from .config import Config
from .validation import validate_output_path
from .exceptions import FileOperationError
logger = logging.getLogger(__name__)
def generate_rss_feed(articles: List[Dict[str, Any]], feed_url: str) -> bytes:
    """Generate an RSS feed from a list of article dicts.

    Each article must provide 'title', 'link' and 'date' keys. Returns
    the rendered feed as pretty-printed XML bytes.
    """
    logger.info(f"Generating RSS feed for {len(articles)} articles")

    # Feed-level metadata comes from the central configuration.
    feed = FeedGenerator()
    feed.title(Config.FEED_TITLE)
    feed.link(href=feed_url)
    feed.description(Config.FEED_DESCRIPTION)

    # Entries are appended in the order the (already sorted) list provides.
    for item in articles:
        entry = feed.add_entry()
        entry.title(item['title'])
        entry.link(href=item['link'])
        entry.pubDate(item['date'])

    return feed.rss_str(pretty=True)
def save_rss_feed(rss_content: bytes, output_dir: str) -> str:
    """Write RSS feed bytes into output_dir and return the final path.

    The target path is validated (and its filename sanitized) before
    writing.

    Raises:
        FileOperationError: if validation or the write fails.
    """
    try:
        target = os.path.join(output_dir, Config.RSS_FILENAME)
        rss_path = validate_output_path(target, output_dir)
        with open(rss_path, 'wb') as f:
            f.write(rss_content)
        logger.info(f'RSS feed saved to: {rss_path}')
        return rss_path
    except Exception as e:
        raise FileOperationError(f"Failed to save RSS feed: {e}")
def save_debug_html(html_content: str, output_dir: str) -> None:
    """Write prettified page HTML into output_dir for debugging.

    Best-effort: any failure is logged as a warning and swallowed, since
    debug output must never abort the scrape.
    """
    try:
        from bs4 import BeautifulSoup
        pretty = BeautifulSoup(html_content, 'html.parser').prettify()
        target = validate_output_path(
            os.path.join(output_dir, Config.DEBUG_HTML_FILENAME), output_dir
        )
        with open(target, 'w', encoding='utf-8') as f:
            f.write(pretty)
        logger.info(f'Debug HTML saved to: {target}')
    except Exception as e:
        # HTML saving is not critical, just log the error
        logger.warning(f"Failed to save debug HTML: {e}")

112
src/rss_scraper/scraper.py Normal file
View File

@@ -0,0 +1,112 @@
"""Web scraping functionality using Playwright."""
import time
import logging
from playwright.sync_api import sync_playwright
from typing import Optional
from .config import Config
from .exceptions import NetworkError, PageLoadError, ContentSizeError
from .retry_utils import retry_on_exception, PLAYWRIGHT_RETRY_CONFIG
from .cache import ContentCache
from .security import wait_for_rate_limit
logger = logging.getLogger(__name__)
# Global cache instance
# Shared by all load_page_with_retry calls so HTML and ETag state persist
# across requests within one process.
_cache = ContentCache()
def load_page_with_retry(url: str, use_cache: bool = True) -> str:
    """Load page HTML, preferring cached copies when use_cache is True.

    A conditional request is issued first; on a 304 the cached HTML is
    returned without launching a browser. Otherwise any still-valid
    cached copy is used. Fresh content is fetched via Playwright (with
    retries) and written back to the cache before being returned.
    """
    logger.info(f"Loading page: {url}")

    if use_cache:
        # Conditional request: a 304 lets us reuse the cached copy.
        change_check = _cache.check_if_content_changed(url)
        if change_check and change_check['status'] == 'not_modified':
            cached = _cache.get_cached_content(url)
            if cached:
                logger.info("Using cached content (not modified)")
                return cached

        # Fall back to any cached copy that has not expired yet.
        cached = _cache.get_cached_content(url)
        if cached:
            logger.info("Using cached content")
            return cached

    # Nothing usable in the cache: fetch fresh content.
    html = _load_page_fresh(url)
    if use_cache:
        _cache.cache_content(url, html)
    return html
@retry_on_exception((NetworkError, PageLoadError), PLAYWRIGHT_RETRY_CONFIG)
def _load_page_fresh(url: str) -> str:
    """Load fresh page content using Playwright.

    Launches headless Chromium, scrolls to trigger lazy loading, and
    returns the fully rendered HTML. Rate limiting is applied before the
    request.

    Raises:
        NetworkError: for timeout/network failures (retried by decorator).
        PageLoadError: for other page-load failures (retried by decorator).
        ContentSizeError: when the HTML exceeds MAX_CONTENT_SIZE
            (deliberately NOT retried — refetching cannot shrink the page).
    """
    logger.info(f"Loading fresh content from: {url}")
    # Apply rate limiting before making request
    wait_for_rate_limit()
    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page()
            # Set a longer timeout for loading the page
            page.set_default_navigation_timeout(Config.PAGE_TIMEOUT_MS)
            try:
                # Load the page
                page.goto(url, wait_until="networkidle")
                # Simulate scrolling to load more content
                logger.info(f"Scrolling page {Config.MAX_SCROLL_ITERATIONS} times to load content")
                for i in range(Config.MAX_SCROLL_ITERATIONS):
                    logger.debug(f"Scroll iteration {i + 1}/{Config.MAX_SCROLL_ITERATIONS}")
                    page.evaluate("window.scrollBy(0, document.body.scrollHeight)")
                    time.sleep(Config.SCROLL_DELAY_SECONDS)
                # Get the fully rendered HTML content
                html = page.content()
                # Check content size for security
                if len(html) > Config.MAX_CONTENT_SIZE:
                    error_msg = f"Content size {len(html)} exceeds maximum {Config.MAX_CONTENT_SIZE}"
                    logger.error(error_msg)
                    raise ContentSizeError(error_msg)
                logger.info(f"Page loaded successfully, content size: {len(html)} bytes")
                return html
            except ContentSizeError:
                # Bug fix: previously this was swallowed by the generic
                # handler below and re-wrapped as a retryable
                # PageLoadError. Keep its type and skip the retry.
                raise
            except Exception as e:
                logger.error(f"Failed to load page content: {e}")
                # Classify by message: timeouts/network issues are retryable.
                if "timeout" in str(e).lower() or "network" in str(e).lower():
                    raise NetworkError(f"Network error loading page: {e}")
                else:
                    raise PageLoadError(f"Page load error: {e}")
            finally:
                browser.close()
    except Exception as e:
        # Re-raise our own exception types untouched; wrap anything else.
        if isinstance(e, (NetworkError, PageLoadError, ContentSizeError)):
            raise
        logger.error(f"Unexpected error in Playwright: {e}")
        raise PageLoadError(f"Playwright error: {e}")
def clear_cache() -> None:
    """Clear the content cache.

    Delegates to the module-level ContentCache instance.
    """
    _cache.clear_cache()
def get_cache_info() -> dict:
    """Get information about the cache.

    Returns the dict produced by ContentCache.get_cache_info().
    """
    return _cache.get_cache_info()

236
src/rss_scraper/security.py Normal file
View File

@@ -0,0 +1,236 @@
"""Security utilities for content sanitization and rate limiting."""
import time
import logging
import re
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import bleach
from .config import Config
logger = logging.getLogger(__name__)
class RateLimiter:
    """Rate limiter to prevent excessive requests.

    Tracks request timestamps in a sliding 60-second window and also
    enforces a minimum delay between consecutive requests.
    """

    def __init__(self, requests_per_minute: int = 30):
        self.requests_per_minute = requests_per_minute
        self.request_times: list = []  # timestamps of requests in the last minute
        self.min_delay_seconds = 60.0 / requests_per_minute
        self.last_request_time: Optional[float] = None

    def wait_if_needed(self) -> None:
        """Block (sleep) as needed to respect both rate limits, then
        record the current request."""
        current_time = time.time()

        # Drop request timestamps older than the 60-second window.
        cutoff_time = current_time - 60
        self.request_times = [t for t in self.request_times if t > cutoff_time]

        # Sliding-window limit: sleep until the oldest request ages out.
        if len(self.request_times) >= self.requests_per_minute:
            sleep_time = 60 - (current_time - self.request_times[0])
            if sleep_time > 0:
                logger.info(f"Rate limit reached, sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)
                # Bug fix: refresh the clock after sleeping so the
                # minimum-delay check below does not measure against a
                # stale pre-sleep timestamp (which caused a spurious
                # second sleep).
                current_time = time.time()

        # Enforce the minimum spacing between consecutive requests.
        if self.last_request_time:
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.min_delay_seconds:
                sleep_time = self.min_delay_seconds - time_since_last
                logger.debug(f"Enforcing minimum delay, sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)

        # Record this request.
        self.request_times.append(time.time())
        self.last_request_time = time.time()
class ContentSanitizer:
    """Enhanced content sanitization for security.

    Combines a regex pre-pass (stripping script/iframe/object/embed/
    applet/form blocks, scriptable URL schemes and inline event
    handlers) with bleach whitelist-based sanitization, plus helpers
    for plain text, URLs and filenames. Pattern order in sanitize_html
    is significant: the regex pass runs before bleach.
    """
    def __init__(self):
        # Allowed HTML tags for RSS content (including structural elements for parsing)
        self.allowed_tags = [
            'p', 'br', 'strong', 'em', 'b', 'i', 'u', 'span',
            'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
            'ul', 'ol', 'li', 'blockquote',
            'div', 'article', 'section', 'header', 'footer', 'main', 'nav',
            'a', 'img', 'figure', 'figcaption', 'time'
        ]
        # Allowed attributes per tag ('*' applies to every allowed tag).
        self.allowed_attributes = {
            '*': ['class', 'id'],
            'a': ['href', 'title', 'class'],
            'img': ['src', 'alt', 'title', 'width', 'height', 'class'],
            'time': ['datetime', 'class'],
            'div': ['class', 'id'],
            'article': ['class', 'id'],
            'section': ['class', 'id']
        }
        # Protocols allowed in URLs
        self.allowed_protocols = ['http', 'https']
        # Dangerous patterns to remove (pre-compiled for performance)
        self.dangerous_patterns = [
            re.compile(r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<iframe[^>]*>.*?</iframe>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<object[^>]*>.*?</object>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<embed[^>]*>.*?</embed>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<applet[^>]*>.*?</applet>', re.IGNORECASE | re.DOTALL),
            re.compile(r'<form[^>]*>.*?</form>', re.IGNORECASE | re.DOTALL),
            re.compile(r'javascript:', re.IGNORECASE),
            re.compile(r'vbscript:', re.IGNORECASE),
            re.compile(r'data:', re.IGNORECASE),
            re.compile(r'on\w+\s*=', re.IGNORECASE),  # event handlers like onclick, onload, etc.
        ]
    def sanitize_html(self, html_content: str) -> str:
        """Sanitize HTML content using bleach library.

        Runs the regex pre-pass, then bleach with the configured tag/
        attribute/protocol whitelists. Returns "" on empty input or if
        sanitization itself raises (fail closed).
        """
        if not html_content:
            return ""
        try:
            # First pass: remove obviously dangerous patterns
            cleaned = html_content
            for pattern in self.dangerous_patterns:
                cleaned = pattern.sub('', cleaned)
            # Second pass: use bleach for comprehensive sanitization
            sanitized = bleach.clean(
                cleaned,
                tags=self.allowed_tags,
                attributes=self.allowed_attributes,
                protocols=self.allowed_protocols,
                strip=True,
                strip_comments=True
            )
            return sanitized
        except Exception as e:
            logger.error(f"Error sanitizing HTML: {e}")
            # If sanitization fails, return empty string for safety
            return ""
    def sanitize_text(self, text: Optional[str]) -> str:
        """Sanitize plain text for safe use as a title.

        Strips control characters and the literal substrings listed in
        Config.DANGEROUS_PATTERNS, collapses whitespace and truncates to
        Config.MAX_TITLE_LENGTH. Falls back to "No title" when nothing
        survives.
        """
        if not text:
            return "No title"
        # Basic cleaning
        sanitized = text.strip()
        # Remove null bytes and other control characters
        sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', sanitized)
        # Remove dangerous patterns (case insensitive)
        for pattern in Config.DANGEROUS_PATTERNS:
            sanitized = re.sub(re.escape(pattern), '', sanitized, flags=re.IGNORECASE)
        # Limit length
        sanitized = sanitized[:Config.MAX_TITLE_LENGTH]
        # Remove excessive whitespace
        sanitized = re.sub(r'\s+', ' ', sanitized).strip()
        return sanitized if sanitized else "No title"
    def validate_url_security(self, url: str) -> bool:
        """Return True only if the URL passes all security screens.

        Rejects scriptable/local protocols, traversal and injection
        patterns, and URLs longer than 2048 characters. Each rejection
        is logged with the offending URL.
        """
        if not url:
            return False
        # Check for dangerous protocols
        dangerous_protocols = ['javascript:', 'vbscript:', 'data:', 'file:', 'ftp:']
        url_lower = url.lower()
        for protocol in dangerous_protocols:
            if url_lower.startswith(protocol):
                logger.warning(f"Blocked dangerous protocol in URL: {url}")
                return False
        # Check for suspicious patterns
        suspicious_patterns = [
            r'\.\./',  # Path traversal
            r'%2e%2e%2f',  # Encoded path traversal
            r'<script',  # Script injection
            r'javascript:',  # JavaScript protocol
            r'vbscript:',  # VBScript protocol
        ]
        for pattern in suspicious_patterns:
            if re.search(pattern, url, re.IGNORECASE):
                logger.warning(f"Blocked suspicious pattern in URL: {url}")
                return False
        # Check URL length (prevent buffer overflow attacks)
        if len(url) > 2048:
            logger.warning(f"Blocked excessively long URL (length: {len(url)})")
            return False
        return True
    def sanitize_filename(self, filename: str) -> str:
        """Sanitize filenames to prevent directory traversal and injection.

        Replaces path separators and shell-special characters with '_',
        strips control characters and leading/trailing dots/spaces,
        prefixes reserved Windows device names, and truncates to 255
        characters. Returns "default" when nothing survives.
        """
        if not filename:
            return "default"
        # Remove path separators and dangerous characters
        sanitized = re.sub(r'[<>:"|?*\\/]', '_', filename)
        # Remove null bytes and control characters
        sanitized = re.sub(r'[\x00-\x1F\x7F]', '', sanitized)
        # Remove leading/trailing dots and spaces
        sanitized = sanitized.strip('. ')
        # Prevent reserved Windows filenames
        reserved_names = [
            'CON', 'PRN', 'AUX', 'NUL',
            'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
            'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'
        ]
        if sanitized.upper() in reserved_names:
            sanitized = f"file_{sanitized}"
        # Limit length
        sanitized = sanitized[:255]
        return sanitized if sanitized else "default"
# Global instances
# Module-level singletons so every caller shares one rate-limit window
# and one sanitizer configuration.
_rate_limiter = RateLimiter(requests_per_minute=30)
_sanitizer = ContentSanitizer()
def wait_for_rate_limit() -> None:
    """Apply rate limiting.

    Blocks until the shared rate limiter allows another request.
    """
    _rate_limiter.wait_if_needed()
def sanitize_html_content(html: str) -> str:
    """Sanitize HTML content via the shared ContentSanitizer."""
    return _sanitizer.sanitize_html(html)
def sanitize_text_content(text: Optional[str]) -> str:
    """Sanitize text content via the shared ContentSanitizer."""
    return _sanitizer.sanitize_text(text)
def validate_url_security(url: str) -> bool:
    """Validate URL for security via the shared ContentSanitizer."""
    return _sanitizer.validate_url_security(url)
def sanitize_filename(filename: str) -> str:
    """Sanitize filename via the shared ContentSanitizer."""
    return _sanitizer.sanitize_filename(filename)

View File

@@ -0,0 +1,113 @@
"""URL and path validation utilities."""
import os
import urllib.parse
import logging
from typing import Optional
from .config import Config
from .exceptions import ValidationError, FileOperationError
from .security import validate_url_security, sanitize_filename
logger = logging.getLogger(__name__)
def validate_url(url: str) -> bool:
    """Validate a URL against the whitelist of allowed domains.

    Returns True on success. Raises ValidationError when the URL fails
    security screening, is malformed, or points at a domain outside the
    configured whitelist; unexpected errors are wrapped in
    ValidationError as well.
    """
    try:
        logger.debug(f"Validating URL: {url}")

        # Security screening (protocols, traversal patterns, length) first.
        if not validate_url_security(url):
            raise ValidationError(f"URL failed security validation: {url}")

        parts = urllib.parse.urlparse(url)
        if not (parts.scheme and parts.netloc):
            raise ValidationError("Invalid URL format")

        # Whitelist check against the configured domains.
        host = parts.netloc.lower()
        allowed = Config.get_allowed_domains()
        if host not in allowed:
            raise ValidationError(f"Domain {host} not in allowed list: {allowed}")

        logger.debug(f"URL validation successful for domain: {host}")
        return True
    except ValidationError:
        raise
    except Exception as e:
        logger.error(f"URL validation failed for {url}: {e}")
        raise ValidationError(f"URL validation failed: {e}")
def validate_output_path(path: str, base_dir: str) -> str:
    """Validate and sanitize an output file path.

    The filename component is sanitized, the path is resolved to an
    absolute path, and it must remain inside base_dir. The base
    directory is created if it does not already exist.

    Raises:
        ValidationError: on directory traversal or escape from base_dir.
        FileOperationError: when the directory cannot be created or
            accessed, or on unexpected errors.
    """
    logger.debug(f"Validating output path: {path} in base directory: {base_dir}")
    try:
        # Sanitize the filename component
        dir_part, filename = os.path.split(path)
        if filename:
            sanitized_filename = sanitize_filename(filename)
            path = os.path.join(dir_part, sanitized_filename)
            # Bug fix: the original log printed a literal "(unknown)"
            # instead of the original filename.
            logger.debug(f"Sanitized filename: {filename} -> {sanitized_filename}")

        # Resolve to absolute paths for the containment check.
        abs_path = os.path.abspath(path)
        abs_base = os.path.abspath(base_dir)

        # Ensure the path is within the allowed directory. A bare
        # startswith() prefix check would wrongly accept siblings such
        # as "/out2/x" for base "/out", so require equality or a
        # separator boundary after the base path.
        if abs_path != abs_base and not abs_path.startswith(abs_base + os.sep):
            error_msg = f"Output path {abs_path} is outside allowed directory {abs_base}"
            logger.error(error_msg)
            raise ValidationError(error_msg)

        # Explicit directory-traversal check on the (pre-resolution) path.
        # Absolute paths are allowed since they are resolved safely above.
        if '..' in path:
            error_msg = f"Directory traversal detected in path: {path}"
            logger.error(error_msg)
            raise ValidationError(error_msg)

        # Ensure output directory exists
        os.makedirs(abs_base, exist_ok=True)
        logger.debug(f"Output path validated: {abs_path}")
        return abs_path
    except OSError as e:
        raise FileOperationError(f"Failed to create or access directory {base_dir}: {e}")
    except ValidationError:
        raise
    except Exception as e:
        raise FileOperationError(f"Unexpected error validating path: {e}")
def validate_link(link: Optional[str], base_url: str) -> Optional[str]:
    """Validate and sanitize an article link.

    Site-relative links (starting with '/') are resolved against
    base_url's scheme and host. Returns the absolute link when it passes
    security screening and belongs to an allowed domain; otherwise
    returns None. Never raises.
    """
    if not link:
        return None
    try:
        # Resolve site-relative links against the page's scheme and host.
        if link.startswith('/'):
            base_parts = urllib.parse.urlparse(base_url)
            link = f"{base_parts.scheme}://{base_parts.netloc}{link}"

        # Security screening: protocols, traversal/injection patterns, length.
        if not validate_url_security(link):
            logger.warning(f"Link failed security validation: {link}")
            return None

        parts = urllib.parse.urlparse(link)
        if not (parts.scheme and parts.netloc):
            return None

        # Only keep links that stay on a whitelisted domain.
        if parts.netloc.lower() not in Config.get_allowed_domains():
            return None

        return link
    except Exception:
        # Any parsing surprise means the link is unusable, not fatal.
        return None