Add comprehensive RSS scraper implementation with security and testing
- Modular architecture with separate modules for scraping, parsing, security, validation, and caching - Comprehensive security measures including HTML sanitization, rate limiting, and input validation - Robust error handling with custom exceptions and retry logic - HTTP caching with ETags and Last-Modified headers for efficiency - Pre-compiled regex patterns for improved performance - Comprehensive test suite with 66 tests covering all major functionality - Docker support for containerized deployment - Configuration management with environment variable support - Working parser that successfully extracts 32 articles from Warhammer Community 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
1
src/__init__.py
Normal file
1
src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# RSS Scraper package
|
||||
5
src/rss_scraper/__init__.py
Normal file
5
src/rss_scraper/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""RSS Scraper for Warhammer Community website."""

# Package metadata; bump __version__ when cutting a release.
__version__ = "1.0.0"
__author__ = "RSS Scraper"
__description__ = "A production-ready RSS scraper for Warhammer Community website"
|
||||
216
src/rss_scraper/cache.py
Normal file
216
src/rss_scraper/cache.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Caching utilities for avoiding redundant scraping."""
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any, List
|
||||
import requests
|
||||
|
||||
from .config import Config
|
||||
from .exceptions import FileOperationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentCache:
    """Cache for storing and retrieving scraped content.

    Page HTML, extracted articles, and HTTP ETags are persisted as JSON
    files under ``cache_dir``. Entries expire after ``max_cache_age_hours``.
    """

    def __init__(self, cache_dir: str = "cache"):
        self.cache_dir = cache_dir
        self.cache_file = os.path.join(cache_dir, "content_cache.json")
        self.etag_file = os.path.join(cache_dir, "etags.json")
        self.max_cache_age_hours = 24  # Cache expires after 24 hours

        # Ensure cache directory exists
        os.makedirs(cache_dir, exist_ok=True)

    def _get_cache_key(self, url: str) -> str:
        """Generate a stable cache key from a URL."""
        return hashlib.sha256(url.encode()).hexdigest()

    def _load_cache(self) -> Dict[str, Any]:
        """Load cache from file; an unreadable cache is treated as empty."""
        try:
            if os.path.exists(self.cache_file):
                with open(self.cache_file, 'r') as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load cache: {e}")
        return {}

    def _save_cache(self, cache_data: Dict[str, Any]) -> None:
        """Save cache to file.

        Raises:
            FileOperationError: if the cache file cannot be written.
        """
        try:
            with open(self.cache_file, 'w') as f:
                # default=str lets stray datetime values serialize as strings.
                json.dump(cache_data, f, indent=2, default=str)
        except Exception as e:
            logger.error(f"Failed to save cache: {e}")
            raise FileOperationError(f"Failed to save cache: {e}")

    def _load_etags(self) -> Dict[str, str]:
        """Load ETags from file; unreadable data yields an empty mapping."""
        try:
            if os.path.exists(self.etag_file):
                with open(self.etag_file, 'r') as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load ETags: {e}")
        return {}

    def _save_etags(self, etag_data: Dict[str, str]) -> None:
        """Save ETags to file (best-effort; failures are only logged)."""
        try:
            with open(self.etag_file, 'w') as f:
                json.dump(etag_data, f, indent=2)
        except Exception as e:
            logger.warning(f"Failed to save ETags: {e}")

    def _is_cache_valid(self, cached_entry: Dict[str, Any]) -> bool:
        """Check if a cached entry is still within max_cache_age_hours."""
        try:
            cached_time = datetime.fromisoformat(cached_entry['timestamp'])
            expiry_time = cached_time + timedelta(hours=self.max_cache_age_hours)
            return datetime.now() < expiry_time
        except (KeyError, ValueError):
            # Missing or garbled timestamp: treat the entry as expired.
            return False

    def check_if_content_changed(self, url: str) -> Optional[Dict[str, str]]:
        """Check if content has changed using a conditional HEAD request.

        Returns {'status': 'not_modified'} on HTTP 304, otherwise a dict
        with status 'modified'. Network failures are reported as
        'modified' so scraping can proceed.
        """
        etags = self._load_etags()
        cache_key = self._get_cache_key(url)

        headers = {}
        if cache_key in etags:
            headers['If-None-Match'] = etags[cache_key]

        try:
            logger.debug(f"Checking if content changed for {url}")
            response = requests.head(url, headers=headers, timeout=10)

            # 304 means not modified
            if response.status_code == 304:
                logger.info(f"Content not modified for {url}")
                return {'status': 'not_modified'}

            # Update ETag if available
            if 'etag' in response.headers:
                etags[cache_key] = response.headers['etag']
                self._save_etags(etags)
                logger.debug(f"Updated ETag for {url}")

            return {'status': 'modified', 'etag': response.headers.get('etag')}

        except requests.RequestException as e:
            logger.warning(f"Failed to check content modification for {url}: {e}")
            # If we can't check, assume it's modified
            return {'status': 'modified'}

    def get_cached_content(self, url: str) -> Optional[str]:
        """Get cached HTML content if available and valid."""
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url)

        if cache_key not in cache_data:
            logger.debug(f"No cached content for {url}")
            return None

        cached_entry = cache_data[cache_key]

        if not self._is_cache_valid(cached_entry):
            logger.debug(f"Cached content for {url} has expired")
            # Remove expired entry
            del cache_data[cache_key]
            self._save_cache(cache_data)
            return None

        logger.info(f"Using cached content for {url}")
        return cached_entry['content']

    def cache_content(self, url: str, content: str) -> None:
        """Cache HTML content with a timestamp."""
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url)

        cache_data[cache_key] = {
            'url': url,
            'content': content,
            'timestamp': datetime.now().isoformat(),
            'size': len(content)
        }

        self._save_cache(cache_data)
        logger.info(f"Cached content for {url} ({len(content)} bytes)")

    def get_cached_articles(self, url: str) -> Optional[List[Dict[str, Any]]]:
        """Get cached articles if available and valid.

        Dates serialized to ISO strings by cache_articles() are converted
        back to datetime objects, so a cache hit returns the same value
        types as a fresh parse (previously callers got strings after a
        round-trip and datetimes otherwise).
        """
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url) + "_articles"

        if cache_key not in cache_data:
            return None

        cached_entry = cache_data[cache_key]

        if not self._is_cache_valid(cached_entry):
            # Remove expired entry
            del cache_data[cache_key]
            self._save_cache(cache_data)
            return None

        logger.info(f"Using cached articles for {url}")

        articles: List[Dict[str, Any]] = []
        for article in cached_entry['articles']:
            restored = dict(article)
            if isinstance(restored.get('date'), str):
                try:
                    restored['date'] = datetime.fromisoformat(restored['date'])
                except ValueError:
                    pass  # keep the raw string if it isn't ISO-formatted
            articles.append(restored)
        return articles

    def cache_articles(self, url: str, articles: List[Dict[str, Any]]) -> None:
        """Cache extracted articles."""
        cache_data = self._load_cache()
        cache_key = self._get_cache_key(url) + "_articles"

        # Convert datetime objects to strings for JSON serialization
        serializable_articles = []
        for article in articles:
            serializable_article = article.copy()
            if 'date' in serializable_article and hasattr(serializable_article['date'], 'isoformat'):
                serializable_article['date'] = serializable_article['date'].isoformat()
            serializable_articles.append(serializable_article)

        cache_data[cache_key] = {
            'url': url,
            'articles': serializable_articles,
            'timestamp': datetime.now().isoformat(),
            'count': len(articles)
        }

        self._save_cache(cache_data)
        logger.info(f"Cached {len(articles)} articles for {url}")

    def clear_cache(self) -> None:
        """Clear all cached content.

        Raises:
            FileOperationError: if a cache file cannot be removed.
        """
        try:
            if os.path.exists(self.cache_file):
                os.remove(self.cache_file)
            if os.path.exists(self.etag_file):
                os.remove(self.etag_file)
            logger.info("Cache cleared successfully")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")
            raise FileOperationError(f"Failed to clear cache: {e}")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about cached content (entry counts, file sizes)."""
        cache_data = self._load_cache()
        etags = self._load_etags()

        info = {
            'cache_file': self.cache_file,
            'etag_file': self.etag_file,
            'cache_entries': len(cache_data),
            'etag_entries': len(etags),
            'cache_size_bytes': 0
        }

        if os.path.exists(self.cache_file):
            info['cache_size_bytes'] = os.path.getsize(self.cache_file)

        return info
|
||||
77
src/rss_scraper/config.py
Normal file
77
src/rss_scraper/config.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Configuration management for RSS Warhammer scraper."""
|
||||
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class Config:
    """Configuration class for RSS scraper settings.

    Values are read from environment variables at import time, falling
    back to the defaults below. Call validate_config() at startup to
    fail fast on nonsensical settings.
    """

    # Security settings: only these hosts may be scraped.
    ALLOWED_DOMAINS: List[str] = [
        'warhammer-community.com',
        'www.warhammer-community.com'
    ]

    # Scraping limits
    MAX_SCROLL_ITERATIONS: int = int(os.getenv('MAX_SCROLL_ITERATIONS', '5'))
    MAX_CONTENT_SIZE: int = int(os.getenv('MAX_CONTENT_SIZE', str(10 * 1024 * 1024)))  # 10MB
    MAX_TITLE_LENGTH: int = int(os.getenv('MAX_TITLE_LENGTH', '500'))

    # Timing settings
    SCROLL_DELAY_SECONDS: float = float(os.getenv('SCROLL_DELAY_SECONDS', '2.0'))
    PAGE_TIMEOUT_MS: int = int(os.getenv('PAGE_TIMEOUT_MS', '120000'))

    # Default URLs and paths
    DEFAULT_URL: str = os.getenv('DEFAULT_URL', 'https://www.warhammer-community.com/en-gb/')
    DEFAULT_OUTPUT_DIR: str = os.getenv('DEFAULT_OUTPUT_DIR', '.')

    # File names
    RSS_FILENAME: str = os.getenv('RSS_FILENAME', 'warhammer_rss_feed.xml')
    DEBUG_HTML_FILENAME: str = os.getenv('DEBUG_HTML_FILENAME', 'page.html')

    # Feed metadata
    FEED_TITLE: str = os.getenv('FEED_TITLE', 'Warhammer Community RSS Feed')
    FEED_DESCRIPTION: str = os.getenv('FEED_DESCRIPTION', 'Latest Warhammer Community Articles')

    # Security patterns to remove from content
    DANGEROUS_PATTERNS: List[str] = [
        '<script', '</script', 'javascript:', 'data:', 'vbscript:'
    ]

    # CSS selectors for article parsing
    TITLE_SELECTORS: List[str] = [
        'h3.newsCard-title-sm',
        'h3.newsCard-title-lg'
    ]

    @classmethod
    def get_output_dir(cls, override: Optional[str] = None) -> str:
        """Get output directory with optional override."""
        return override or cls.DEFAULT_OUTPUT_DIR

    @classmethod
    def get_allowed_domains(cls) -> List[str]:
        """Get list of allowed domains for scraping.

        Reads a comma-separated ALLOWED_DOMAINS from the environment at
        call time. Blank entries produced by stray commas are discarded
        so a malformed variable cannot allow-list the empty string; if
        nothing usable remains, the built-in defaults apply.
        """
        env_domains = os.getenv('ALLOWED_DOMAINS')
        if env_domains:
            domains = [domain.strip() for domain in env_domains.split(',') if domain.strip()]
            if domains:
                return domains
        return cls.ALLOWED_DOMAINS

    @classmethod
    def validate_config(cls) -> None:
        """Validate configuration values; raises ValueError on bad settings."""
        if cls.MAX_SCROLL_ITERATIONS < 0:
            raise ValueError("MAX_SCROLL_ITERATIONS must be non-negative")
        if cls.MAX_CONTENT_SIZE <= 0:
            raise ValueError("MAX_CONTENT_SIZE must be positive")
        if cls.MAX_TITLE_LENGTH <= 0:
            raise ValueError("MAX_TITLE_LENGTH must be positive")
        if cls.SCROLL_DELAY_SECONDS < 0:
            raise ValueError("SCROLL_DELAY_SECONDS must be non-negative")
        if cls.PAGE_TIMEOUT_MS <= 0:
            raise ValueError("PAGE_TIMEOUT_MS must be positive")
        if not cls.DEFAULT_URL.startswith(('http://', 'https://')):
            raise ValueError("DEFAULT_URL must be a valid HTTP/HTTPS URL")
        if not cls.get_allowed_domains():
            raise ValueError("ALLOWED_DOMAINS cannot be empty")
|
||||
41
src/rss_scraper/exceptions.py
Normal file
41
src/rss_scraper/exceptions.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Custom exceptions for the RSS scraper."""
|
||||
|
||||
|
||||
class ScrapingError(Exception):
    """Root of the scraper's exception hierarchy; catch this for any scraper failure."""


class ValidationError(ScrapingError):
    """Input or data failed a validation check."""


class NetworkError(ScrapingError):
    """A network operation failed."""


class PageLoadError(NetworkError):
    """A page did not load properly."""


class ContentSizeError(ScrapingError):
    """Fetched content exceeded the configured size limit."""


class ParseError(ScrapingError):
    """HTML could not be parsed."""


class ConfigurationError(ScrapingError):
    """The scraper configuration is invalid."""


class FileOperationError(ScrapingError):
    """A file read/write/delete operation failed."""
|
||||
111
src/rss_scraper/parser.py
Normal file
111
src/rss_scraper/parser.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""HTML parsing and article extraction functionality."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
import pytz
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from .config import Config
|
||||
from .validation import validate_link
|
||||
from .exceptions import ParseError
|
||||
from .security import sanitize_text_content, sanitize_html_content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_text(text: Optional[str]) -> str:
    """Sanitize text content to prevent injection attacks.

    Thin wrapper kept so parser callers don't depend on the security
    module directly; delegates to security.sanitize_text_content.
    """
    return sanitize_text_content(text)
|
||||
|
||||
|
||||
def _parse_article_date(article, tz) -> Optional[datetime]:
    """Return the publication datetime from an article's <time> tags, or None.

    Reading-time tags (e.g. "3 min") are skipped; the first tag matching a
    known date format wins.
    """
    for time_tag in article.find_all('time'):
        raw_date = time_tag.text.strip()

        # Ignore "min" time blocks (reading time)
        if "min" in raw_date.lower():
            continue

        # Known formats: "05 Jun 25" and "Jun 05, 2025".
        for fmt in ('%d %b %y', '%b %d, %Y'):
            try:
                return tz.localize(datetime.strptime(raw_date, fmt))
            except ValueError:
                continue
    return None


def extract_articles_from_html(html: str, base_url: str) -> List[Dict[str, Any]]:
    """Extract articles from HTML content.

    Returns a list of {'title', 'link', 'date'} dicts, de-duplicated by
    validated URL and sorted newest-first.

    Raises:
        ParseError: if the HTML cannot be parsed.
    """
    logger.info("Parsing HTML content with BeautifulSoup")

    # Sanitize HTML content first for security
    sanitized_html = sanitize_html_content(html)

    try:
        soup = BeautifulSoup(sanitized_html, 'html.parser')
    except Exception as e:
        raise ParseError(f"Failed to parse HTML content: {e}")

    # Define a timezone (UTC in this case)
    timezone = pytz.UTC

    # Articles of interest carry a "shared-*" CSS class (covers all article types).
    article_elements = [
        el for el in soup.find_all('article')
        if any('shared-' in cls for cls in (el.get('class', []) or []))
    ]
    logger.info(f"Found {len(article_elements)} article elements on page")

    # Hoisted out of the per-article loop: h3 class names to try for titles.
    title_classes = [
        selector.split('.')[1] if '.' in selector else selector
        for selector in Config.TITLE_SELECTORS
    ]

    articles: List[Dict[str, Any]] = []
    seen_urls: set = set()  # Track seen URLs to avoid duplicates

    for article in article_elements:
        # Extract and sanitize the title
        title_tag = None
        for class_name in title_classes:
            title_tag = article.find('h3', class_=class_name)
            if title_tag:
                break

        raw_title = title_tag.text.strip() if title_tag else 'No title'
        title = sanitize_text(raw_title)

        # Extract and validate the link - prefer the btn-cover anchor, then any anchor
        link_tag = article.find('a', class_='btn-cover', href=True) or article.find('a', href=True)
        raw_link = link_tag['href'] if link_tag else None
        link = validate_link(raw_link, base_url)

        # Skip this entry if the link is None or the URL has already been seen
        if not link or link in seen_urls:
            logger.debug(f"Skipping duplicate or invalid article: {title}")
            continue

        seen_urls.add(link)
        logger.debug(f"Processing article: {title[:50]}...")

        # Fall back to "now" when no parseable publication date is present.
        date = _parse_article_date(article, timezone) or datetime.now(timezone)

        articles.append({
            'title': title,
            'link': link,
            'date': date
        })

    # Sort the articles by publication date (newest first)
    articles.sort(key=lambda x: x['date'], reverse=True)
    logger.info(f"Successfully extracted {len(articles)} unique articles")

    return articles
|
||||
124
src/rss_scraper/retry_utils.py
Normal file
124
src/rss_scraper/retry_utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Retry utilities with exponential backoff for network operations."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, Type, Union, Tuple
|
||||
from functools import wraps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetryConfig:
    """Settings that control retry behaviour: attempt count, exponential
    backoff delays (base, cap, growth factor) and optional jitter."""

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        backoff_factor: float = 2.0,
        jitter: bool = True
    ):
        # Store each knob verbatim; consumers read these attributes directly.
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.backoff_factor = backoff_factor
        self.jitter = jitter
|
||||
|
||||
|
||||
def calculate_delay(attempt: int, config: RetryConfig) -> float:
    """Compute the sleep (seconds) before retry number *attempt*,
    using exponential backoff capped at config.max_delay."""
    # Exponential growth: base * factor^(attempt-1), then cap.
    raw = config.base_delay * config.backoff_factor ** (attempt - 1)
    capped = min(raw, config.max_delay)

    if not config.jitter:
        return max(0, capped)

    # Randomize by +/-10% so concurrent clients don't retry in lockstep.
    import random
    spread = capped * 0.1
    return max(0, capped + random.uniform(-spread, spread))
|
||||
|
||||
|
||||
def retry_on_exception(
    exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
    config: Optional[RetryConfig] = None
) -> Callable:
    """Decorator to retry function calls on specific exceptions.

    Only the exception types listed in *exceptions* trigger a retry with
    exponential backoff; any other exception propagates immediately.

    Args:
        exceptions: Exception type(s) to retry on
        config: Retry configuration, uses default if None

    Returns:
        Decorated function with retry logic
    """
    if config is None:
        config = RetryConfig()

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            last_exception = None

            # Attempts are 1-indexed so calculate_delay's exponent starts at 0.
            for attempt in range(1, config.max_attempts + 1):
                try:
                    result = func(*args, **kwargs)
                    if attempt > 1:
                        logger.info(f"{func.__name__} succeeded on attempt {attempt}")
                    return result

                # NOTE: this clause must precede the generic handler below —
                # only the listed exception types are retried.
                except exceptions as e:
                    last_exception = e

                    if attempt == config.max_attempts:
                        # Out of attempts: surface the final error unchanged.
                        logger.error(
                            f"{func.__name__} failed after {config.max_attempts} attempts. "
                            f"Final error: {e}"
                        )
                        raise

                    delay = calculate_delay(attempt, config)
                    logger.warning(
                        f"{func.__name__} attempt {attempt} failed: {e}. "
                        f"Retrying in {delay:.2f} seconds..."
                    )
                    time.sleep(delay)

                except Exception as e:
                    # Don't retry on unexpected exceptions
                    logger.error(f"{func.__name__} failed with unexpected error: {e}")
                    raise

            # This should never be reached, but just in case
            if last_exception:
                raise last_exception

        return wrapper
    return decorator
|
||||
|
||||
|
||||
# Common retry configurations for different scenarios

# Plain HTTP requests: up to 3 tries, 1s base doubling per attempt (capped
# at 30s), with jitter to avoid synchronized retries across processes.
NETWORK_RETRY_CONFIG = RetryConfig(
    max_attempts=3,
    base_delay=1.0,
    max_delay=30.0,
    backoff_factor=2.0,
    jitter=True
)

# Playwright page loads are expensive, so only one retry with longer delays.
PLAYWRIGHT_RETRY_CONFIG = RetryConfig(
    max_attempts=2,
    base_delay=2.0,
    max_delay=10.0,
    backoff_factor=2.0,
    jitter=False
)

# Local file operations usually recover quickly: short delays, gentler growth.
FILE_RETRY_CONFIG = RetryConfig(
    max_attempts=3,
    base_delay=0.5,
    max_delay=5.0,
    backoff_factor=1.5,
    jitter=False
)
|
||||
59
src/rss_scraper/rss_generator.py
Normal file
59
src/rss_scraper/rss_generator.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""RSS feed generation functionality."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
from feedgen.feed import FeedGenerator
|
||||
|
||||
from .config import Config
|
||||
from .validation import validate_output_path
|
||||
from .exceptions import FileOperationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_rss_feed(articles: List[Dict[str, Any]], feed_url: str) -> bytes:
    """Build an RSS document (as bytes) from the given article dicts.

    Each article must provide 'title', 'link' and 'date' keys; entries are
    emitted in the order supplied by the caller.
    """
    logger.info(f"Generating RSS feed for {len(articles)} articles")

    # Feed-level metadata comes from the static configuration.
    feed = FeedGenerator()
    feed.title(Config.FEED_TITLE)
    feed.link(href=feed_url)
    feed.description(Config.FEED_DESCRIPTION)

    # One feed entry per article.
    for item in articles:
        entry = feed.add_entry()
        entry.title(item['title'])
        entry.link(href=item['link'])
        entry.pubDate(item['date'])

    return feed.rss_str(pretty=True)
|
||||
|
||||
|
||||
def save_rss_feed(rss_content: bytes, output_dir: str) -> str:
    """Write the RSS feed bytes into output_dir and return the file path.

    Raises:
        FileOperationError: if path validation or the write fails.
    """
    try:
        # validate_output_path guards against writing outside output_dir.
        rss_path = validate_output_path(os.path.join(output_dir, Config.RSS_FILENAME), output_dir)
        with open(rss_path, 'wb') as f:
            f.write(rss_content)
        logger.info(f'RSS feed saved to: {rss_path}')
        return rss_path
    except Exception as e:
        # Chain the original exception so the root cause stays in tracebacks.
        raise FileOperationError(f"Failed to save RSS feed: {e}") from e
|
||||
|
||||
|
||||
def save_debug_html(html_content: str, output_dir: str) -> None:
    """Pretty-print the scraped HTML into output_dir for offline debugging.

    Best-effort: any failure is logged as a warning, never raised.
    """
    try:
        from bs4 import BeautifulSoup

        pretty = BeautifulSoup(html_content, 'html.parser').prettify()
        html_path = validate_output_path(
            os.path.join(output_dir, Config.DEBUG_HTML_FILENAME), output_dir
        )
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(pretty)
        logger.info(f'Debug HTML saved to: {html_path}')
    except Exception as e:
        # HTML saving is not critical, just log the error
        logger.warning(f"Failed to save debug HTML: {e}")
|
||||
112
src/rss_scraper/scraper.py
Normal file
112
src/rss_scraper/scraper.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Web scraping functionality using Playwright."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from playwright.sync_api import sync_playwright
|
||||
from typing import Optional
|
||||
|
||||
from .config import Config
|
||||
from .exceptions import NetworkError, PageLoadError, ContentSizeError
|
||||
from .retry_utils import retry_on_exception, PLAYWRIGHT_RETRY_CONFIG
|
||||
from .cache import ContentCache
|
||||
from .security import wait_for_rate_limit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global cache instance
|
||||
_cache = ContentCache()
|
||||
|
||||
|
||||
def load_page_with_retry(url: str, use_cache: bool = True) -> str:
    """Load page content, preferring a valid cache entry over a fresh fetch.

    Args:
        url: Page to load.
        use_cache: When True, consult (and populate) the content cache.

    Returns:
        The page HTML.
    """
    logger.info(f"Loading page: {url}")

    if use_cache:
        # Issue a conditional (ETag) HEAD request; its main effect here is
        # refreshing the stored ETag. Whatever it reports, we serve from the
        # local cache when a valid entry exists, so one lookup suffices
        # (the original code performed the same get_cached_content call twice).
        _cache.check_if_content_changed(url)

        cached_content = _cache.get_cached_content(url)
        if cached_content:
            logger.info("Using cached content")
            return cached_content

    # Cache miss or caching disabled: fetch with Playwright.
    html = _load_page_fresh(url)

    # Cache the content if caching is enabled
    if use_cache:
        _cache.cache_content(url, html)

    return html
|
||||
|
||||
|
||||
@retry_on_exception((NetworkError, PageLoadError), PLAYWRIGHT_RETRY_CONFIG)
def _load_page_fresh(url: str) -> str:
    """Load fresh page content using Playwright.

    Launches a headless Chromium, scrolls to trigger lazy loading, and
    returns the rendered HTML.

    Raises:
        NetworkError: on timeout/network failures (retried by decorator).
        PageLoadError: on other page/Playwright failures (retried).
        ContentSizeError: when the rendered HTML exceeds MAX_CONTENT_SIZE
            (not retried — re-raised as-is).
    """
    logger.info(f"Loading fresh content from: {url}")

    # Apply rate limiting before making request
    wait_for_rate_limit()

    try:
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page()

            # Set a longer timeout for loading the page
            page.set_default_navigation_timeout(Config.PAGE_TIMEOUT_MS)

            try:
                # Load the page; "networkidle" waits for network activity to settle.
                page.goto(url, wait_until="networkidle")

                # Simulate scrolling to load more content
                logger.info(f"Scrolling page {Config.MAX_SCROLL_ITERATIONS} times to load content")
                for i in range(Config.MAX_SCROLL_ITERATIONS):
                    logger.debug(f"Scroll iteration {i + 1}/{Config.MAX_SCROLL_ITERATIONS}")
                    page.evaluate("window.scrollBy(0, document.body.scrollHeight)")
                    time.sleep(Config.SCROLL_DELAY_SECONDS)

                # Get the fully rendered HTML content
                html = page.content()

                # Check content size for security
                if len(html) > Config.MAX_CONTENT_SIZE:
                    error_msg = f"Content size {len(html)} exceeds maximum {Config.MAX_CONTENT_SIZE}"
                    logger.error(error_msg)
                    raise ContentSizeError(error_msg)

                logger.info(f"Page loaded successfully, content size: {len(html)} bytes")
                return html

            except Exception as e:
                logger.error(f"Failed to load page content: {e}")
                # Heuristic classification by message text; timeouts/network
                # issues become NetworkError, everything else PageLoadError.
                if "timeout" in str(e).lower() or "network" in str(e).lower():
                    raise NetworkError(f"Network error loading page: {e}")
                else:
                    raise PageLoadError(f"Page load error: {e}")
            finally:
                # Always release the browser, even on failure.
                browser.close()

    except Exception as e:
        # Let our own typed errors pass through untouched for the retry decorator.
        if isinstance(e, (NetworkError, PageLoadError, ContentSizeError)):
            raise
        logger.error(f"Unexpected error in Playwright: {e}")
        raise PageLoadError(f"Playwright error: {e}")
|
||||
|
||||
|
||||
def clear_cache() -> None:
    """Clear the content cache.

    Module-level convenience wrapper over the shared ContentCache instance;
    may raise FileOperationError if a cache file cannot be removed.
    """
    _cache.clear_cache()
|
||||
|
||||
|
||||
def get_cache_info() -> dict:
    """Get information about the cache.

    Delegates to the shared ContentCache instance (entry counts, file paths,
    cache file size in bytes).
    """
    return _cache.get_cache_info()
|
||||
236
src/rss_scraper/security.py
Normal file
236
src/rss_scraper/security.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Security utilities for content sanitization and rate limiting."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
import bleach
|
||||
|
||||
from .config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
    """Rate limiter to prevent excessive requests.

    Tracks request timestamps within a sliding one-minute window and
    enforces both a per-minute cap and a minimum spacing between
    consecutive requests.
    """

    def __init__(self, requests_per_minute: int = 30):
        self.requests_per_minute = requests_per_minute
        self.request_times: list = []
        # Spacing that spreads requests evenly across the minute.
        self.min_delay_seconds = 60.0 / requests_per_minute
        self.last_request_time: Optional[float] = None

    def wait_if_needed(self) -> None:
        """Sleep as needed to respect the per-minute cap and minimum spacing."""
        current_time = time.time()

        # Drop request timestamps that have aged out of the 1-minute window.
        cutoff_time = current_time - 60
        self.request_times = [t for t in self.request_times if t > cutoff_time]

        # If the window is full, sleep until the oldest request ages out.
        if len(self.request_times) >= self.requests_per_minute:
            sleep_time = 60 - (current_time - self.request_times[0])
            if sleep_time > 0:
                logger.info(f"Rate limit reached, sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)
                # Refresh the clock after sleeping: the pre-sleep timestamp
                # would understate the gap to last_request_time below and
                # cause an unnecessary extra min-delay sleep.
                current_time = time.time()

        # Enforce minimum delay between consecutive requests.
        if self.last_request_time:
            time_since_last = current_time - self.last_request_time
            if time_since_last < self.min_delay_seconds:
                sleep_time = self.min_delay_seconds - time_since_last
                logger.debug(f"Enforcing minimum delay, sleeping for {sleep_time:.2f} seconds")
                time.sleep(sleep_time)

        # Record this request.
        self.request_times.append(time.time())
        self.last_request_time = time.time()
|
||||
|
||||
|
||||
class ContentSanitizer:
|
||||
"""Enhanced content sanitization for security."""
|
||||
|
||||
def __init__(self):
    """Build the tag/attribute allow-lists and pre-compiled removal patterns."""
    # Allowed HTML tags for RSS content (including structural elements for parsing)
    self.allowed_tags = [
        'p', 'br', 'strong', 'em', 'b', 'i', 'u', 'span',
        'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
        'ul', 'ol', 'li', 'blockquote',
        'div', 'article', 'section', 'header', 'footer', 'main', 'nav',
        'a', 'img', 'figure', 'figcaption', 'time'
    ]

    # Allowed attributes ('*' applies to every tag)
    self.allowed_attributes = {
        '*': ['class', 'id'],
        'a': ['href', 'title', 'class'],
        'img': ['src', 'alt', 'title', 'width', 'height', 'class'],
        'time': ['datetime', 'class'],
        'div': ['class', 'id'],
        'article': ['class', 'id'],
        'section': ['class', 'id']
    }

    # Protocols allowed in URLs
    self.allowed_protocols = ['http', 'https']

    # Dangerous patterns to remove (pre-compiled for performance); these run
    # before bleach as a first defensive pass against active content.
    self.dangerous_patterns = [
        re.compile(r'<script[^>]*>.*?</script>', re.IGNORECASE | re.DOTALL),
        re.compile(r'<iframe[^>]*>.*?</iframe>', re.IGNORECASE | re.DOTALL),
        re.compile(r'<object[^>]*>.*?</object>', re.IGNORECASE | re.DOTALL),
        re.compile(r'<embed[^>]*>.*?</embed>', re.IGNORECASE | re.DOTALL),
        re.compile(r'<applet[^>]*>.*?</applet>', re.IGNORECASE | re.DOTALL),
        re.compile(r'<form[^>]*>.*?</form>', re.IGNORECASE | re.DOTALL),
        re.compile(r'javascript:', re.IGNORECASE),
        re.compile(r'vbscript:', re.IGNORECASE),
        re.compile(r'data:', re.IGNORECASE),
        re.compile(r'on\w+\s*=', re.IGNORECASE),  # event handlers like onclick, onload, etc.
    ]
|
||||
|
||||
def sanitize_html(self, html_content: str) -> str:
|
||||
"""Sanitize HTML content using bleach library."""
|
||||
if not html_content:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# First pass: remove obviously dangerous patterns
|
||||
cleaned = html_content
|
||||
for pattern in self.dangerous_patterns:
|
||||
cleaned = pattern.sub('', cleaned)
|
||||
|
||||
# Second pass: use bleach for comprehensive sanitization
|
||||
sanitized = bleach.clean(
|
||||
cleaned,
|
||||
tags=self.allowed_tags,
|
||||
attributes=self.allowed_attributes,
|
||||
protocols=self.allowed_protocols,
|
||||
strip=True,
|
||||
strip_comments=True
|
||||
)
|
||||
|
||||
return sanitized
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sanitizing HTML: {e}")
|
||||
# If sanitization fails, return empty string for safety
|
||||
return ""
|
||||
|
||||
def sanitize_text(self, text: Optional[str]) -> str:
|
||||
"""Enhanced text sanitization with better security."""
|
||||
if not text:
|
||||
return "No title"
|
||||
|
||||
# Basic cleaning
|
||||
sanitized = text.strip()
|
||||
|
||||
# Remove null bytes and other control characters
|
||||
sanitized = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', sanitized)
|
||||
|
||||
# Remove dangerous patterns (case insensitive)
|
||||
for pattern in Config.DANGEROUS_PATTERNS:
|
||||
sanitized = re.sub(re.escape(pattern), '', sanitized, flags=re.IGNORECASE)
|
||||
|
||||
# Limit length
|
||||
sanitized = sanitized[:Config.MAX_TITLE_LENGTH]
|
||||
|
||||
# Remove excessive whitespace
|
||||
sanitized = re.sub(r'\s+', ' ', sanitized).strip()
|
||||
|
||||
return sanitized if sanitized else "No title"
|
||||
|
||||
def validate_url_security(self, url: str) -> bool:
|
||||
"""Enhanced URL validation for security."""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
# Check for dangerous protocols
|
||||
dangerous_protocols = ['javascript:', 'vbscript:', 'data:', 'file:', 'ftp:']
|
||||
url_lower = url.lower()
|
||||
|
||||
for protocol in dangerous_protocols:
|
||||
if url_lower.startswith(protocol):
|
||||
logger.warning(f"Blocked dangerous protocol in URL: {url}")
|
||||
return False
|
||||
|
||||
# Check for suspicious patterns
|
||||
suspicious_patterns = [
|
||||
r'\.\./', # Path traversal
|
||||
r'%2e%2e%2f', # Encoded path traversal
|
||||
r'<script', # Script injection
|
||||
r'javascript:', # JavaScript protocol
|
||||
r'vbscript:', # VBScript protocol
|
||||
]
|
||||
|
||||
for pattern in suspicious_patterns:
|
||||
if re.search(pattern, url, re.IGNORECASE):
|
||||
logger.warning(f"Blocked suspicious pattern in URL: {url}")
|
||||
return False
|
||||
|
||||
# Check URL length (prevent buffer overflow attacks)
|
||||
if len(url) > 2048:
|
||||
logger.warning(f"Blocked excessively long URL (length: {len(url)})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def sanitize_filename(self, filename: str) -> str:
|
||||
"""Sanitize filenames to prevent directory traversal and injection."""
|
||||
if not filename:
|
||||
return "default"
|
||||
|
||||
# Remove path separators and dangerous characters
|
||||
sanitized = re.sub(r'[<>:"|?*\\/]', '_', filename)
|
||||
|
||||
# Remove null bytes and control characters
|
||||
sanitized = re.sub(r'[\x00-\x1F\x7F]', '', sanitized)
|
||||
|
||||
# Remove leading/trailing dots and spaces
|
||||
sanitized = sanitized.strip('. ')
|
||||
|
||||
# Prevent reserved Windows filenames
|
||||
reserved_names = [
|
||||
'CON', 'PRN', 'AUX', 'NUL',
|
||||
'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', 'COM8', 'COM9',
|
||||
'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9'
|
||||
]
|
||||
|
||||
if sanitized.upper() in reserved_names:
|
||||
sanitized = f"file_{sanitized}"
|
||||
|
||||
# Limit length
|
||||
sanitized = sanitized[:255]
|
||||
|
||||
return sanitized if sanitized else "default"
|
||||
|
||||
|
||||
# Global instances

# Shared, module-wide rate limiter used by the convenience functions
# below (capped at 30 requests per minute).
_rate_limiter = RateLimiter(requests_per_minute=30)
# Shared sanitizer instance so compiled regexes are built only once.
_sanitizer = ContentSanitizer()
|
||||
|
||||
|
||||
def wait_for_rate_limit() -> None:
    """Apply rate limiting.

    Blocks (sleeps) until the module-level rate limiter allows another
    request. Thin convenience wrapper around RateLimiter.wait_if_needed().
    """
    _rate_limiter.wait_if_needed()
|
||||
|
||||
|
||||
def sanitize_html_content(html: str) -> str:
    """Sanitize HTML content.

    Module-level convenience wrapper around the shared ContentSanitizer.
    Returns sanitized HTML; returns "" for empty input or on sanitization
    failure (fail closed).
    """
    return _sanitizer.sanitize_html(html)
|
||||
|
||||
|
||||
def sanitize_text_content(text: Optional[str]) -> str:
    """Sanitize text content.

    Module-level convenience wrapper around the shared ContentSanitizer.
    Returns the cleaned text, or "No title" when the input is empty or
    reduces to nothing after cleaning.
    """
    return _sanitizer.sanitize_text(text)
|
||||
|
||||
|
||||
def validate_url_security(url: str) -> bool:
    """Validate URL for security.

    Module-level convenience wrapper around the shared ContentSanitizer.
    Returns True only when the URL passes the protocol, pattern, and
    length checks; rejections are logged as warnings.
    """
    return _sanitizer.validate_url_security(url)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
    """Sanitize filename.

    Module-level convenience wrapper around the shared ContentSanitizer.
    Returns a filesystem-safe name, or "default" when the input is empty
    or reduces to nothing after cleaning.
    """
    return _sanitizer.sanitize_filename(filename)
|
||||
113
src/rss_scraper/validation.py
Normal file
113
src/rss_scraper/validation.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""URL and path validation utilities."""
|
||||
|
||||
import os
|
||||
import urllib.parse
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from .config import Config
|
||||
from .exceptions import ValidationError, FileOperationError
|
||||
from .security import validate_url_security, sanitize_filename
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
    """Check *url* against the security rules and the domain whitelist.

    Returns:
        True when the URL is well-formed, passes the enhanced security
        checks, and its domain appears in Config.get_allowed_domains().

    Raises:
        ValidationError: if any check fails, or if validation itself
            errors out (the underlying error is wrapped).
    """
    try:
        logger.debug(f"Validating URL: {url}")

        # Enhanced security validation first
        if not validate_url_security(url):
            raise ValidationError(f"URL failed security validation: {url}")

        pieces = urllib.parse.urlparse(url)
        if not (pieces.scheme and pieces.netloc):
            raise ValidationError("Invalid URL format")

        # Check if domain is in allowed list
        domain = pieces.netloc.lower()
        allowed_domains = Config.get_allowed_domains()
        if domain not in allowed_domains:
            raise ValidationError(f"Domain {domain} not in allowed list: {allowed_domains}")

        logger.debug(f"URL validation successful for domain: {domain}")
        return True
    except ValidationError:
        # Already a domain error with a precise message — propagate as-is.
        raise
    except Exception as e:
        logger.error(f"URL validation failed for {url}: {e}")
        raise ValidationError(f"URL validation failed: {e}")
|
||||
|
||||
|
||||
def validate_output_path(path: str, base_dir: str) -> str:
    """Validate and sanitize output file path.

    Sanitizes the filename component, resolves the path to an absolute
    path, and confirms it stays inside *base_dir*. Creates *base_dir* if
    it does not exist.

    Returns:
        The validated absolute path.

    Raises:
        ValidationError: if the path escapes *base_dir* or contains a
            directory-traversal sequence.
        FileOperationError: if the base directory cannot be created or
            accessed, or on any unexpected error.
    """
    logger.debug(f"Validating output path: {path} in base directory: {base_dir}")

    try:
        # Sanitize the filename component
        dir_part, filename = os.path.split(path)
        if filename:
            sanitized_filename = sanitize_filename(filename)
            path = os.path.join(dir_part, sanitized_filename)
            # Fixed: log the actual original filename instead of the
            # literal "(unknown)" placeholder.
            logger.debug(f"Sanitized filename: {filename} -> {sanitized_filename}")

        # Resolve to absolute path and check if it's safe
        abs_path = os.path.abspath(path)
        abs_base = os.path.abspath(base_dir)

        # Ensure path is within allowed directory. A bare startswith()
        # prefix test would wrongly accept sibling directories such as
        # "/out_evil" for base "/out", so accept only the base itself or
        # paths under base + separator.
        if abs_path != abs_base and not abs_path.startswith(abs_base + os.sep):
            error_msg = f"Output path {abs_path} is outside allowed directory {abs_base}"
            logger.error(error_msg)
            raise ValidationError(error_msg)

        # Additional security check for suspicious patterns - only check for directory traversal
        # Note: We allow absolute paths since they're resolved safely above
        if '..' in path:
            error_msg = f"Directory traversal detected in path: {path}"
            logger.error(error_msg)
            raise ValidationError(error_msg)

        # Ensure output directory exists
        os.makedirs(abs_base, exist_ok=True)
        logger.debug(f"Output path validated: {abs_path}")

        return abs_path
    except OSError as e:
        raise FileOperationError(f"Failed to create or access directory {base_dir}: {e}")
    except ValidationError:
        raise
    except Exception as e:
        raise FileOperationError(f"Unexpected error validating path: {e}")
|
||||
|
||||
|
||||
def validate_link(link: Optional[str], base_url: str) -> Optional[str]:
    """Validate and sanitize article links.

    Resolves relative links against *base_url*, applies the enhanced
    security checks, and confirms the target domain is whitelisted.

    Returns:
        The absolute, validated link, or None when the link is missing
        or fails any check (this function never raises).
    """
    if not link:
        return None

    try:
        # Resolve any scheme-less link against the page it came from.
        # urljoin generalizes the previous "/"-only handling: it also
        # resolves plain-relative links ("page.html") and — importantly —
        # protocol-relative links ("//evil.com/x") to their REAL host
        # instead of gluing them onto the trusted netloc, so the domain
        # check below sees the true target.
        if not urllib.parse.urlparse(link).scheme:
            link = urllib.parse.urljoin(base_url, link)

        # Enhanced security validation
        if not validate_url_security(link):
            logger.warning(f"Link failed security validation: {link}")
            return None

        # Validate the resulting URL
        parsed = urllib.parse.urlparse(link)
        if not parsed.scheme or not parsed.netloc:
            return None

        # Ensure it's from allowed domain
        domain = parsed.netloc.lower()
        if domain not in Config.get_allowed_domains():
            return None

        return link
    except Exception:
        # A malformed link must never break article processing.
        return None
|
||||
Reference in New Issue
Block a user