150 lines
5.0 KiB
Python
150 lines
5.0 KiB
Python
"""Tests for security fixes."""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from obsidian_rag.config import ObsidianRagConfig, SecurityConfig
|
|
from obsidian_rag.embedder import OllamaEmbedder, SecurityError
|
|
from obsidian_rag.indexer import Indexer, SensitiveContentError
|
|
from obsidian_rag.security import sanitize_query
|
|
|
|
|
|
def test_network_isolation_validation():
    """Test that remote URLs are rejected when local_only=True."""
    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=SecurityConfig(local_only=True),
    )

    # A loopback endpoint must pass validation while local_only is on.
    cfg.embedding.base_url = "http://localhost:11434"
    try:
        OllamaEmbedder(cfg)._validate_network_isolation()
    except SecurityError:
        pytest.fail("Localhost should be allowed when local_only=True")

    # A non-local endpoint must be refused while local_only is on.
    cfg.embedding.base_url = "http://example.com:11434"
    with pytest.raises(SecurityError, match="Remote embedding service not allowed"):
        OllamaEmbedder(cfg)._validate_network_isolation()

    # Once local_only is disabled, the same remote endpoint is acceptable.
    cfg.security.local_only = False
    cfg.embedding.base_url = "http://example.com:11434"
    try:
        OllamaEmbedder(cfg)._validate_network_isolation()
    except SecurityError:
        pytest.fail("Remote URLs should be allowed when local_only=False")
|
|
|
|
|
|
def test_sensitive_content_enforcement():
    """Test that sensitive content requires approval."""
    from obsidian_rag.config import MemoryConfig

    security_cfg = SecurityConfig(
        require_confirmation_for=["health"],
        auto_approve_sensitive=False,
        sensitive_sections=["#mentalhealth", "#physicalhealth", "#Relations"],
    )
    memory_cfg = MemoryConfig(
        patterns={
            "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
            "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
            "commitments": ["shopping list", "costco", "amazon", "grocery"],
        }
    )
    cfg = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=security_cfg,
        memory=memory_cfg,
    )

    indexer = Indexer(cfg)

    # A single chunk whose text and tags match the configured health patterns.
    health_chunk = {
        'chunk_id': '1',
        'chunk_text': 'I have #mentalhealth issues and need therapy',
        'source_file': '/tmp/test/file.md',
        'source_directory': '/tmp/test',
        'section': 'content',
        'date': '2024-01-01',
        'tags': ['mentalhealth'],
        'chunk_index': 0,
        'total_chunks': 1,
        'modified_at': '2024-01-01T00:00:00Z',
        'indexed_at': '2024-01-01T00:00:00Z',
    }
    sample_chunks = [health_chunk]

    # Without auto-approval the indexer must refuse the health content.
    with pytest.raises(SensitiveContentError, match="Sensitive health content detected"):
        indexer._check_sensitive_content_approval(sample_chunks)

    # Flipping auto_approve_sensitive on lets the same chunks through.
    cfg.security.auto_approve_sensitive = True
    indexer = Indexer(cfg)
    try:
        indexer._check_sensitive_content_approval(sample_chunks)
    except SensitiveContentError:
        pytest.fail("Should not raise when auto_approve_sensitive=True")
|
|
|
|
|
|
def test_query_sanitization():
    """Test that queries are properly sanitized."""
    # SQL-injection payload: the sanitizer should strip the quote/comment pattern.
    injected = sanitize_query("test'; DROP TABLE users; --")
    assert not ("'" in injected and ";" in injected)

    # Bare SQL keywords should also be stripped from the query.
    keyword_result = sanitize_query("SELECT * FROM users WHERE id = 1")
    assert "SELECT" not in keyword_result

    # Overlong input is truncated to the 1000-character cap.
    truncated = sanitize_query("a" * 2000)
    assert len(truncated) <= 1000

    # Leading/trailing/internal whitespace collapses to single spaces.
    normalized = sanitize_query("  test   \n  query  \t")
    assert normalized == "test query"
|
|
|
|
|
|
def test_audit_logging():
    """Test that audit logging works correctly."""
    import json
    import stat
    import tempfile

    from obsidian_rag.audit_logger import AuditLogger

    with tempfile.TemporaryDirectory() as tmpdir:
        audit_path = Path(tmpdir) / "audit.log"
        audit = AuditLogger(audit_path)

        # Record one sensitive-access event.
        audit.log_sensitive_access(
            "/tmp/test/health.md",
            "health",
            "index",
            {"chunk_id": "123"},
        )

        # The log file must exist after the first write.
        assert audit_path.exists()

        # The file holds a JSON array with exactly the one entry we logged.
        entries = json.loads(audit_path.read_text())
        assert len(entries) == 1
        record = entries[0]
        assert record['content_type'] == 'health'
        assert record['action'] == 'index'
        assert record['metadata']['chunk_id'] == '123'

        # The log must be owner-read/write only (0o600).
        assert stat.S_IMODE(audit_path.stat().st_mode) == 0o600
|