"""Tests for security fixes.""" import pytest from pathlib import Path from obsidian_rag.config import ObsidianRagConfig, SecurityConfig from obsidian_rag.embedder import OllamaEmbedder, SecurityError from obsidian_rag.indexer import Indexer, SensitiveContentError from obsidian_rag.security import sanitize_query def test_network_isolation_validation(): """Test that remote URLs are rejected when local_only=True.""" config = ObsidianRagConfig( vault_path="/tmp/test", security=SecurityConfig(local_only=True) ) # Should allow localhost config.embedding.base_url = "http://localhost:11434" try: embedder = OllamaEmbedder(config) embedder._validate_network_isolation() except SecurityError: pytest.fail("Localhost should be allowed when local_only=True") # Should reject remote URLs config.embedding.base_url = "http://example.com:11434" with pytest.raises(SecurityError, match="Remote embedding service not allowed"): embedder = OllamaEmbedder(config) embedder._validate_network_isolation() # Should allow remote URLs when local_only=False config.security.local_only = False config.embedding.base_url = "http://example.com:11434" try: embedder = OllamaEmbedder(config) embedder._validate_network_isolation() except SecurityError: pytest.fail("Remote URLs should be allowed when local_only=False") def test_sensitive_content_enforcement(): """Test that sensitive content requires approval.""" from obsidian_rag.config import MemoryConfig config = ObsidianRagConfig( vault_path="/tmp/test", security=SecurityConfig( require_confirmation_for=["health"], auto_approve_sensitive=False, sensitive_sections=["#mentalhealth", "#physicalhealth", "#Relations"] ), memory=MemoryConfig( patterns={ "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"], "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"], "commitments": ["shopping list", "costco", "amazon", "grocery"], } ) ) indexer = Indexer(config) # Create test chunks with health content chunks = [ { 'chunk_id': '1', 'chunk_text': 'I have #mentalhealth issues and need therapy', 'source_file': '/tmp/test/file.md', 'source_directory': '/tmp/test', 'section': 'content', 'date': '2024-01-01', 'tags': ['mentalhealth'], 'chunk_index': 0, 'total_chunks': 1, 'modified_at': '2024-01-01T00:00:00Z', 'indexed_at': '2024-01-01T00:00:00Z', } ] # Should raise SensitiveContentError with pytest.raises(SensitiveContentError, match="Sensitive health content detected"): indexer._check_sensitive_content_approval(chunks) # Should pass when auto_approve_sensitive=True config.security.auto_approve_sensitive = True indexer = Indexer(config) try: indexer._check_sensitive_content_approval(chunks) except SensitiveContentError: pytest.fail("Should not raise when auto_approve_sensitive=True") def test_query_sanitization(): """Test that queries are properly sanitized.""" # Test injection patterns dirty_query = "test'; DROP TABLE users; --" clean_query = sanitize_query(dirty_query) # The regex should remove the SQL injection pattern assert "'" not in clean_query or ";" not in clean_query # Test that SQL keywords are removed sql_query = "SELECT * FROM users WHERE id = 1" clean_sql = sanitize_query(sql_query) assert "SELECT" not in clean_sql # Test length limiting long_query = "a" * 2000 short_query = sanitize_query(long_query) assert len(short_query) <= 1000 # Test whitespace normalization messy_query = " test \n query \t" clean_query = sanitize_query(messy_query) assert clean_query == "test query" def test_audit_logging(): """Test that audit logging works correctly.""" from obsidian_rag.audit_logger import AuditLogger import tempfile import json with tempfile.TemporaryDirectory() as tmpdir: log_path = Path(tmpdir) / "audit.log" logger = AuditLogger(log_path) # Log sensitive access logger.log_sensitive_access( "/tmp/test/health.md", "health", "index", {"chunk_id": "123"} ) # Verify log was created assert log_path.exists() # Verify log content logs = json.loads(log_path.read_text()) assert len(logs) == 1 entry = logs[0] assert entry['content_type'] == 'health' assert entry['action'] == 'index' assert entry['metadata']['chunk_id'] == '123' # Verify permissions import stat assert stat.S_IMODE(log_path.stat().st_mode) == 0o600