Files
obsidian-rag/python/tests/unit/test_security_fixes.py

150 lines
5.0 KiB
Python

"""Tests for security fixes."""
import pytest
from pathlib import Path
from obsidian_rag.config import ObsidianRagConfig, SecurityConfig
from obsidian_rag.embedder import OllamaEmbedder, SecurityError
from obsidian_rag.indexer import Indexer, SensitiveContentError
from obsidian_rag.security import sanitize_query
def test_network_isolation_validation():
    """Exercise the embedder's network-isolation check under both settings.

    Three scenarios are covered, all sharing one mutable config object:
    a localhost URL with ``local_only=True`` (must pass), a remote URL with
    ``local_only=True`` (must raise ``SecurityError``), and the same remote
    URL with ``local_only=False`` (must pass).
    """
    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=SecurityConfig(local_only=True),
    )

    def build_and_validate():
        # Build a fresh embedder from the current config and run its check.
        OllamaEmbedder(cfg)._validate_network_isolation()

    # Scenario 1: localhost is acceptable while local_only is on.
    cfg.embedding.base_url = "http://localhost:11434"
    try:
        build_and_validate()
    except SecurityError:
        pytest.fail("Localhost should be allowed when local_only=True")

    # Scenario 2: a remote host is refused while local_only is on.
    cfg.embedding.base_url = "http://example.com:11434"
    with pytest.raises(SecurityError, match="Remote embedding service not allowed"):
        build_and_validate()

    # Scenario 3: the same remote host is fine once local_only is switched off.
    cfg.security.local_only = False
    cfg.embedding.base_url = "http://example.com:11434"
    try:
        build_and_validate()
    except SecurityError:
        pytest.fail("Remote URLs should be allowed when local_only=False")
def test_sensitive_content_enforcement():
    """Check that health-tagged chunks are blocked unless auto-approval is on.

    With ``auto_approve_sensitive=False`` the indexer's approval check must
    raise ``SensitiveContentError``; after flipping the flag to ``True`` and
    rebuilding the indexer, the same chunks must pass without raising.
    """
    from obsidian_rag.config import MemoryConfig

    # Keyword patterns per memory category, passed to MemoryConfig.
    category_patterns = {
        "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
        "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
        "commitments": ["shopping list", "costco", "amazon", "grocery"],
    }
    security_settings = SecurityConfig(
        require_confirmation_for=["health"],
        auto_approve_sensitive=False,
        sensitive_sections=["#mentalhealth", "#physicalhealth", "#Relations"],
    )
    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=security_settings,
        memory=MemoryConfig(patterns=category_patterns),
    )

    # A single chunk whose text and tags mention health topics.
    health_chunk = {
        "chunk_id": "1",
        "chunk_text": "I have #mentalhealth issues and need therapy",
        "source_file": "/tmp/test/file.md",
        "source_directory": "/tmp/test",
        "section": "content",
        "date": "2024-01-01",
        "tags": ["mentalhealth"],
        "chunk_index": 0,
        "total_chunks": 1,
        "modified_at": "2024-01-01T00:00:00Z",
        "indexed_at": "2024-01-01T00:00:00Z",
    }
    chunk_batch = [health_chunk]

    # Without auto-approval the check must refuse the batch.
    strict_indexer = Indexer(cfg)
    with pytest.raises(SensitiveContentError, match="Sensitive health content detected"):
        strict_indexer._check_sensitive_content_approval(chunk_batch)

    # With auto-approval enabled, a freshly built indexer must accept it.
    cfg.security.auto_approve_sensitive = True
    lenient_indexer = Indexer(cfg)
    try:
        lenient_indexer._check_sensitive_content_approval(chunk_batch)
    except SensitiveContentError:
        pytest.fail("Should not raise when auto_approve_sensitive=True")
def test_query_sanitization():
    """Check ``sanitize_query`` on injection text, SQL keywords, length, and whitespace."""
    # Injection-style input: quote and semicolon must not both survive.
    # NOTE(review): this passes if only one of the two characters is removed;
    # confirm whether the sanitizer is meant to strip both and tighten if so.
    sanitized_injection = sanitize_query("test'; DROP TABLE users; --")
    assert not ("'" in sanitized_injection and ";" in sanitized_injection)

    # Upper-case SQL keyword must be gone from the output.
    sanitized_sql = sanitize_query("SELECT * FROM users WHERE id = 1")
    assert "SELECT" not in sanitized_sql

    # Over-long input must come back at 1000 characters or fewer.
    oversized = "a" * 2000
    truncated = sanitize_query(oversized)
    assert len(truncated) <= 1000

    # Leading/trailing and mixed whitespace collapses to single spaces.
    normalized = sanitize_query(" test \n query \t")
    assert normalized == "test query"
def test_audit_logging():
    """Check that one sensitive-access event is written, parseable, and private.

    Writes a single event through ``AuditLogger`` into a temporary directory,
    then verifies that the log file exists, that it parses as JSON holding
    exactly one entry with the expected fields, and that the file's
    permission bits are exactly ``0o600``.
    """
    import json
    import stat
    import tempfile

    from obsidian_rag.audit_logger import AuditLogger

    with tempfile.TemporaryDirectory() as scratch_dir:
        audit_file = Path(scratch_dir) / "audit.log"
        audit = AuditLogger(audit_file)

        # Record one access event for a health-related file.
        audit.log_sensitive_access(
            "/tmp/test/health.md",
            "health",
            "index",
            {"chunk_id": "123"},
        )

        # The logger must have created the file.
        assert audit_file.exists()

        # The file content must parse as JSON containing exactly one entry.
        entries = json.loads(audit_file.read_text())
        assert len(entries) == 1
        record = entries[0]
        expected_fields = (("content_type", "health"), ("action", "index"))
        for field_name, expected_value in expected_fields:
            assert record[field_name] == expected_value
        assert record["metadata"]["chunk_id"] == "123"

        # Permission bits must be exactly owner read/write.
        permission_bits = stat.S_IMODE(audit_file.stat().st_mode)
        assert permission_bits == 0o600