Security review fixes
This commit is contained in:
149
python/tests/unit/test_security_fixes.py
Normal file
149
python/tests/unit/test_security_fixes.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Tests for security fixes."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from obsidian_rag.config import ObsidianRagConfig, SecurityConfig
|
||||
from obsidian_rag.embedder import OllamaEmbedder, SecurityError
|
||||
from obsidian_rag.indexer import Indexer, SensitiveContentError
|
||||
from obsidian_rag.security import sanitize_query
|
||||
|
||||
|
||||
def test_network_isolation_validation():
    """Check the embedder's network-isolation guard for local and remote URLs.

    Three scenarios run against one shared config object:

    1. localhost URL with ``local_only=True``  -> no ``SecurityError``
    2. remote URL with ``local_only=True``     -> ``SecurityError``
    3. remote URL with ``local_only=False``    -> no ``SecurityError``
    """
    local_url = "http://localhost:11434"
    remote_url = "http://example.com:11434"

    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=SecurityConfig(local_only=True),
    )

    def build_and_validate():
        # Construction and the explicit check are exercised together, so a
        # SecurityError from either step is seen by the caller.
        OllamaEmbedder(cfg)._validate_network_isolation()

    # Scenario 1: localhost is acceptable in local-only mode.
    cfg.embedding.base_url = local_url
    try:
        build_and_validate()
    except SecurityError:
        pytest.fail("Localhost should be allowed when local_only=True")

    # Scenario 2: a remote host must be refused in local-only mode.
    cfg.embedding.base_url = remote_url
    with pytest.raises(SecurityError, match="Remote embedding service not allowed"):
        build_and_validate()

    # Scenario 3: the same remote host is fine once local-only is switched off.
    cfg.security.local_only = False
    cfg.embedding.base_url = remote_url
    try:
        build_and_validate()
    except SecurityError:
        pytest.fail("Remote URLs should be allowed when local_only=False")
|
||||
|
||||
|
||||
def test_sensitive_content_enforcement():
    """Verify the indexer's approval check for chunks with health content.

    With ``auto_approve_sensitive=False`` the check must raise
    ``SensitiveContentError``; after flipping the flag to ``True`` and
    rebuilding the indexer, the same chunks must pass without raising.
    """
    from obsidian_rag.config import MemoryConfig

    security_cfg = SecurityConfig(
        require_confirmation_for=["health"],
        auto_approve_sensitive=False,
        sensitive_sections=["#mentalhealth", "#physicalhealth", "#Relations"],
    )
    memory_cfg = MemoryConfig(
        patterns={
            "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
            "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
            "commitments": ["shopping list", "costco", "amazon", "grocery"],
        }
    )
    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=security_cfg,
        memory=memory_cfg,
    )

    # A single chunk whose text contains two of the configured "health" patterns.
    health_chunk = {
        "chunk_id": "1",
        "chunk_text": "I have #mentalhealth issues and need therapy",
        "source_file": "/tmp/test/file.md",
        "source_directory": "/tmp/test",
        "section": "content",
        "date": "2024-01-01",
        "tags": ["mentalhealth"],
        "chunk_index": 0,
        "total_chunks": 1,
        "modified_at": "2024-01-01T00:00:00Z",
        "indexed_at": "2024-01-01T00:00:00Z",
    }
    chunks = [health_chunk]

    # Without auto-approval the check is expected to refuse the chunk.
    strict_indexer = Indexer(cfg)
    with pytest.raises(SensitiveContentError, match="Sensitive health content detected"):
        strict_indexer._check_sensitive_content_approval(chunks)

    # With auto-approval enabled, a freshly built indexer must accept it.
    cfg.security.auto_approve_sensitive = True
    permissive_indexer = Indexer(cfg)
    try:
        permissive_indexer._check_sensitive_content_approval(chunks)
    except SensitiveContentError:
        pytest.fail("Should not raise when auto_approve_sensitive=True")
|
||||
|
||||
|
||||
def test_query_sanitization():
    """Exercise ``sanitize_query`` on injection text, SQL, long and messy input.

    Covers four independent cases: a quote/semicolon injection string, a
    plain SQL statement, an over-long query, and irregular whitespace.
    """
    # Case 1: classic quote-and-semicolon injection attempt.
    injected = sanitize_query("test'; DROP TABLE users; --")
    # NOTE(review): the ``or`` means this passes as soon as EITHER character
    # is gone; it does not require both to be stripped. Confirm that is the
    # intended strength of the check.
    assert "'" not in injected or ";" not in injected

    # Case 2: the SELECT keyword must not survive sanitization.
    sanitized_sql = sanitize_query("SELECT * FROM users WHERE id = 1")
    assert "SELECT" not in sanitized_sql

    # Case 3: a 2000-character query must come back at 1000 characters or fewer.
    truncated = sanitize_query("a" * 2000)
    assert len(truncated) <= 1000

    # Case 4: leading/trailing and interior whitespace is normalized.
    normalized = sanitize_query(" test \n query \t")
    assert normalized == "test query"
|
||||
|
||||
|
||||
def test_audit_logging():
    """Write one sensitive-access entry and inspect the resulting log file.

    Checks that the log file is created, that its JSON content holds exactly
    one entry with the expected fields, and that the file mode is ``0o600``.
    """
    import json
    import stat
    import tempfile

    from obsidian_rag.audit_logger import AuditLogger

    with tempfile.TemporaryDirectory() as workdir:
        audit_file = Path(workdir) / "audit.log"
        audit = AuditLogger(audit_file)

        # Record a single "index" action on a health file.
        audit.log_sensitive_access(
            "/tmp/test/health.md",
            "health",
            "index",
            {"chunk_id": "123"},
        )

        # The logger must have produced the file on disk.
        assert audit_file.exists()

        # The file is parsed as one JSON document containing a single entry.
        entries = json.loads(audit_file.read_text())
        assert len(entries) == 1
        record = entries[0]
        assert record["content_type"] == "health"
        assert record["action"] == "index"
        assert record["metadata"]["chunk_id"] == "123"

        # Only the permission bits are compared: owner read/write, nothing else.
        file_mode = stat.S_IMODE(audit_file.stat().st_mode)
        assert file_mode == 0o600
|
||||
Reference in New Issue
Block a user