Security review fixes
This commit is contained in:
95
python/obsidian_rag/audit_logger.py
Normal file
95
python/obsidian_rag/audit_logger.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Audit logging for sensitive data access and system events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import json
|
||||
import socket
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AuditLogger:
    """Secure audit logger for sensitive content access.

    Entries are stored as a single JSON array in ``log_path`` with
    restrictive (0o600) permissions.  Every append rewrites the file via
    a temporary file so readers never observe a partially written log.
    """

    def __init__(self, log_path: Path):
        # Create the parent directory eagerly so later writes cannot fail
        # on a missing path.
        self.log_path = log_path
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

    def log_sensitive_access(
        self,
        file_path: str,
        content_type: str,
        action: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Log access to sensitive content with path redaction.

        Args:
            file_path: Path of the accessed file (redacted before storage).
            content_type: Sensitivity category (e.g. ``"health"``).
            action: What was done with the content (e.g. ``"index"``).
            metadata: Optional extra context, stored verbatim.
        """
        entry = {
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'file_path': self._redact_path(file_path),
            'content_type': content_type,
            'action': action,
            'user': getpass.getuser(),
            'ip_address': self._get_local_ip(),
            'metadata': metadata or {},
        }
        self._write_entry(entry)

    def log_security_event(
        self,
        event_type: str,
        severity: str,
        description: str,
        details: dict[str, Any] | None = None,
    ) -> None:
        """Log a security-related event (validation failure, policy hit, ...).

        Args:
            event_type: Short machine-readable event name.
            severity: Severity label chosen by the caller.
            description: Human-readable explanation of the event.
            details: Optional structured context, stored verbatim.
        """
        entry = {
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'event_type': event_type,
            'severity': severity,
            'description': description,
            'user': getpass.getuser(),
            'ip_address': self._get_local_ip(),
            'details': details or {},
        }
        self._write_entry(entry)

    def _redact_path(self, path: str) -> str:
        """Redact sensitive information from file paths.

        Any hidden component (a part starting with ``.``) collapses the
        path to ``.../<filename>``; unparseable input becomes
        ``"<redacted>"``.
        """
        try:
            p = Path(path)
            # Dot-prefixed components often mark private locations; keep
            # only the filename in that case.
            if any(part.startswith('.') for part in p.parts):
                return f".../{p.name}"
            return str(p)
        except Exception:
            return "<redacted>"

    def _get_local_ip(self) -> str:
        """Best-effort local IP address for audit entries."""
        try:
            return socket.gethostbyname(socket.gethostname())
        except Exception:
            # Hostname may not resolve (e.g. offline); fall back to loopback.
            return "127.0.0.1"

    def _write_entry(self, entry: dict[str, Any]) -> None:
        """Append *entry* to the log atomically with secure permissions.

        The whole log (a JSON array) is re-read, extended, and written to
        a temp file that then replaces the real log.  A corrupt or
        unreadable existing log is discarded rather than crashing the
        caller's pipeline.
        """
        tmp_path = self.log_path.with_suffix('.tmp')

        # Read existing entries, tolerating a missing/corrupt log.
        entries = []
        if self.log_path.exists():
            try:
                entries = json.loads(self.log_path.read_text())
            except (json.JSONDecodeError, OSError):
                entries = []

        entries.append(entry)

        tmp_path.write_text(json.dumps(entries, indent=2), encoding='utf-8')
        tmp_path.chmod(0o600)  # Restrictive permissions
        # BUG FIX: Path.rename() raises FileExistsError on Windows when the
        # destination exists; Path.replace() overwrites atomically on all
        # platforms.
        tmp_path.replace(self.log_path)
|
||||
@@ -46,6 +46,7 @@ class SecurityConfig:
|
||||
default_factory=lambda: ["#mentalhealth", "#physicalhealth", "#Relations"]
|
||||
)
|
||||
local_only: bool = True
|
||||
auto_approve_sensitive: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
@@ -21,6 +22,10 @@ class OllamaUnavailableError(EmbeddingError):
|
||||
"""Raised when Ollama is unreachable."""
|
||||
|
||||
|
||||
class SecurityError(Exception):
    """Signals a failed security validation (e.g. a non-local embedding endpoint)."""
|
||||
|
||||
|
||||
class OllamaEmbedder:
|
||||
"""Client for Ollama /api/embed endpoint (mxbai-embed-large, 1024-dim)."""
|
||||
|
||||
@@ -29,7 +34,20 @@ class OllamaEmbedder:
|
||||
self.model = config.embedding.model
|
||||
self.dimensions = config.embedding.dimensions
|
||||
self.batch_size = config.embedding.batch_size
|
||||
self.local_only = config.security.local_only
|
||||
self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)
|
||||
self._validate_network_isolation()
|
||||
|
||||
def _validate_network_isolation(self):
    """Refuse non-loopback embedding endpoints when ``local_only`` is set.

    Raises:
        SecurityError: if ``base_url`` does not point at a loopback host
            while ``local_only`` is True.
    """
    if not self.local_only:
        return  # remote endpoints are explicitly permitted

    host = urllib.parse.urlparse(self.base_url).hostname
    loopback_hosts = ('localhost', '127.0.0.1', '::1')
    if host not in loopback_hosts:
        raise SecurityError(
            f"Remote embedding service not allowed when local_only=True: {self.base_url}"
        )
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Ollama is reachable and has the model."""
|
||||
|
||||
@@ -14,9 +14,11 @@ if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
import obsidian_rag.config as config_mod
|
||||
from obsidian_rag.config import _resolve_data_dir
|
||||
from obsidian_rag.chunker import chunk_file
|
||||
from obsidian_rag.embedder import EmbeddingError, OllamaUnavailableError
|
||||
from obsidian_rag.embedder import EmbeddingError, OllamaUnavailableError, SecurityError
|
||||
from obsidian_rag.security import should_index_dir, validate_path
|
||||
from obsidian_rag.audit_logger import AuditLogger
|
||||
from obsidian_rag.vector_store import create_table_if_not_exists, delete_by_source_file, get_db, upsert_chunks
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
@@ -24,6 +26,10 @@ from obsidian_rag.vector_store import create_table_if_not_exists, delete_by_sour
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
class SensitiveContentError(Exception):
    """Signals that indexing was blocked: sensitive content needs explicit approval."""
|
||||
|
||||
|
||||
class Indexer:
|
||||
"""Coordinates the scan → chunk → embed → store pipeline."""
|
||||
|
||||
@@ -31,6 +37,7 @@ class Indexer:
|
||||
self.config = config
|
||||
self.vault_path = config_mod.resolve_vault_path(config)
|
||||
self._embedder = None # lazy init
|
||||
self._audit_logger = None # lazy init
|
||||
|
||||
@property
|
||||
def embedder(self):
|
||||
@@ -39,6 +46,38 @@ class Indexer:
|
||||
self._embedder = OllamaEmbedder(self.config)
|
||||
return self._embedder
|
||||
|
||||
@property
def audit_logger(self):
    """Lazily construct the shared AuditLogger under ``<data-dir>/audit``."""
    if self._audit_logger is not None:
        return self._audit_logger
    # Imported lazily, mirroring the embedder's lazy initialization.
    from obsidian_rag.audit_logger import AuditLogger
    self._audit_logger = AuditLogger(_resolve_data_dir() / "audit" / "audit.log")
    return self._audit_logger
|
||||
|
||||
def _check_sensitive_content_approval(self, chunks: list[dict[str, Any]]) -> None:
    """Enforce user approval for sensitive content before indexing.

    Args:
        chunks: Enriched chunk dicts; each must carry ``'chunk_text'``
            and ``'source_file'``.

    Raises:
        SensitiveContentError: when a chunk matches a category listed in
            ``security.require_confirmation_for`` and
            ``security.auto_approve_sensitive`` is False.
    """
    from obsidian_rag import security

    sensitive_categories = self.config.security.require_confirmation_for
    if not sensitive_categories:
        return
    # Hoisted out of the loop: if the user pre-approved sensitive content,
    # no chunk can ever trigger the error, so skip detection entirely.
    # (Assumes detect_sensitive has no side effects — it only classifies.)
    if self.config.security.auto_approve_sensitive:
        return

    for chunk in chunks:
        sensitivity = security.detect_sensitive(
            chunk['chunk_text'],
            self.config.security.sensitive_sections,
            self.config.memory.patterns
        )
        for category in sensitive_categories:
            if sensitivity.get(category, False):
                raise SensitiveContentError(
                    f"Sensitive {category} content detected. "
                    f"Requires explicit approval before indexing. "
                    f"File: {chunk['source_file']}"
                )
|
||||
|
||||
def scan_vault(self) -> Generator[Path, None, None]:
|
||||
"""Walk vault, yielding markdown files to index."""
|
||||
for root, dirs, files in os.walk(self.vault_path):
|
||||
@@ -106,6 +145,26 @@ class Indexer:
|
||||
for idx, filepath in enumerate(files):
|
||||
try:
|
||||
num_chunks, enriched = self.process_file(filepath)
|
||||
# Enforce sensitive content policies
|
||||
self._check_sensitive_content_approval(enriched)
|
||||
|
||||
# Log sensitive content access
|
||||
for chunk in enriched:
|
||||
from obsidian_rag import security
|
||||
sensitivity = security.detect_sensitive(
|
||||
chunk['chunk_text'],
|
||||
self.config.security.sensitive_sections,
|
||||
self.config.memory.patterns
|
||||
)
|
||||
for category in ['health', 'financial', 'relations']:
|
||||
if sensitivity.get(category, False):
|
||||
self.audit_logger.log_sensitive_access(
|
||||
str(chunk['source_file']),
|
||||
category,
|
||||
'index',
|
||||
{'chunk_id': chunk['chunk_id']}
|
||||
)
|
||||
|
||||
# Embed chunks
|
||||
texts = [e["chunk_text"] for e in enriched]
|
||||
try:
|
||||
@@ -132,14 +191,13 @@ class Indexer:
|
||||
"total": total_files,
|
||||
}
|
||||
|
||||
return {
|
||||
# Yield final result
|
||||
yield {
|
||||
"indexed_files": indexed_files,
|
||||
"total_chunks": total_chunks,
|
||||
"duration_ms": 0, # caller can fill
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
def sync(self, on_progress: Iterator[dict] | None = None) -> dict[str, Any]:
|
||||
"""Incremental sync: only process files modified since last sync."""
|
||||
sync_result_path = self._sync_result_path()
|
||||
last_sync = None
|
||||
@@ -221,3 +279,78 @@ class Indexer:
|
||||
tmp = path.with_suffix(".json.tmp")
|
||||
tmp.write_text(json.dumps(result, indent=2))
|
||||
tmp.rename(path)
|
||||
|
||||
def sync(self, on_progress: Iterator[dict] | None = None) -> dict[str, Any]:
    """Incremental sync: only process files modified since last sync.

    Returns a summary dict with ``indexed_files``, ``total_chunks`` and
    ``errors``; the same summary is persisted via _write_sync_result.

    NOTE(review): ``on_progress`` is accepted but never used here, and its
    ``Iterator[dict]`` annotation looks like it was meant to be a callback
    type — confirm against callers.
    """
    sync_result_path = self._sync_result_path()
    last_sync = None
    if sync_result_path.exists():
        try:
            last_sync = json.loads(sync_result_path.read_text()).get("timestamp")
        except Exception:
            # Best-effort: a missing/corrupt sync record means full re-scan.
            pass

    db = get_db(self.config)
    table = create_table_if_not_exists(db)
    embedder = self.embedder

    files = list(self.scan_vault())
    indexed_files = 0
    total_chunks = 0
    errors: list[dict] = []

    for filepath in files:
        mtime = datetime.fromtimestamp(filepath.stat().st_mtime, tz=timezone.utc)
        mtime_str = mtime.isoformat()
        # NOTE(review): lexicographic comparison of ISO-8601 strings is only
        # correct when both sides use the same format and timezone — both are
        # produced here as UTC isoformat(), but verify _write_sync_result
        # stores "timestamp" the same way.
        if last_sync and mtime_str <= last_sync:
            continue  # unchanged

        try:
            num_chunks, enriched = self.process_file(filepath)
            # Enforce sensitive content policies
            self._check_sensitive_content_approval(enriched)

            # Log sensitive content access
            for chunk in enriched:
                from obsidian_rag import security
                sensitivity = security.detect_sensitive(
                    chunk['chunk_text'],
                    self.config.security.sensitive_sections,
                    self.config.memory.patterns
                )
                # Audit only these three categories; other detected
                # categories are not logged here.
                for category in ['health', 'financial', 'relations']:
                    if sensitivity.get(category, False):
                        self.audit_logger.log_sensitive_access(
                            str(chunk['source_file']),
                            category,
                            'index',
                            {'chunk_id': chunk['chunk_id']}
                        )

            # Embed chunks
            texts = [e["chunk_text"] for e in enriched]
            try:
                vectors = embedder.embed_chunks(texts)
            except OllamaUnavailableError:
                # NOTE(review): when Ollama is down, chunks are stored with
                # zero vectors (1024-dim) so the file still counts as
                # indexed — confirm these rows get re-embedded later.
                vectors = [[0.0] * 1024 for _ in texts]
            for e, v in zip(enriched, vectors):
                e["vector"] = v
            upsert_chunks(table, enriched)
            total_chunks += num_chunks
            indexed_files += 1
        except Exception as exc:
            # Per-file failures (including SensitiveContentError) are
            # recorded and skipped so one bad file doesn't abort the sync.
            errors.append({"file": str(filepath), "error": str(exc)})

    self._write_sync_result(indexed_files, total_chunks, errors)
    return {
        "indexed_files": indexed_files,
        "total_chunks": total_chunks,
        "errors": errors,
    }
|
||||
# Use the same dev-data-dir convention as config.py
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
data_dir = project_root / "obsidian-rag"
|
||||
if not data_dir.exists() and not (project_root / "KnowledgeVault").exists():
|
||||
data_dir = Path(os.path.expanduser("~/.obsidian-rag"))
|
||||
return data_dir / "sync-result.json"
|
||||
|
||||
|
||||
@@ -64,6 +64,19 @@ HTML_TAG_RE = re.compile(r"<[^>]+>")
|
||||
CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```", re.MULTILINE)
MULTI_WHITESPACE_RE = re.compile(r"\s+")
MAX_CHUNK_LEN = 2000
# Patterns scrubbed (case-insensitively) from user queries by sanitize_query.
INJECTION_PATTERNS = [
    r"\x00",  # Null bytes
    r"\x1a",  # EOF character
    r"--\s",  # SQL comment
    r"/\*[\s\S]*?\*/",  # SQL comment
    r"';",  # SQL injection
    r"\b(DROP|DELETE|INSERT|UPDATE|SELECT)\b",  # SQL keywords
    r"<script[^>]*>.*?</script>",  # XSS
    r"javascript:",  # JS injection
    r"\b(eval|exec|spawn|fork|system)\b",  # Code execution
]

MAX_QUERY_LENGTH = 1000


def sanitize_text(raw: str) -> str:
    """Sanitize raw vault content before embedding.

    - Strip HTML tags (prevent XSS)
    - Remove fenced code blocks
    - Normalize whitespace
    - Cap length at MAX_CHUNK_LEN chars

    FIX: the diffed source contained a leftover duplicate tail of the old
    implementation before this body; collapsed to a single definition.
    """
    # Remove fenced code blocks first so their contents never reach the
    # HTML-stripping or embedding stages.
    text = CODE_BLOCK_RE.sub(" ", raw)
    # Strip HTML tags
    text = HTML_TAG_RE.sub("", text)
    # Remove leading/trailing whitespace, then collapse internal runs.
    text = text.strip()
    text = MULTI_WHITESPACE_RE.sub(" ", text)
    # Cap length
    if len(text) > MAX_CHUNK_LEN:
        text = text[:MAX_CHUNK_LEN]
    return text
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
@@ -93,6 +125,26 @@ def sanitize_text(raw: str) -> str:
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def sanitize_query(query: str) -> str:
    """Return *query* scrubbed of known injection patterns.

    Each entry in INJECTION_PATTERNS is blanked out case-insensitively,
    whitespace is trimmed and collapsed, and the result is capped at
    MAX_QUERY_LENGTH characters.
    """
    cleaned = query
    for injection_re in INJECTION_PATTERNS:
        cleaned = re.sub(injection_re, " ", cleaned, flags=re.IGNORECASE)

    cleaned = MULTI_WHITESPACE_RE.sub(" ", cleaned.strip())

    # Slicing is a no-op when the string is already short enough.
    return cleaned[:MAX_QUERY_LENGTH]
|
||||
|
||||
def detect_sensitive(
|
||||
text: str,
|
||||
sensitive_sections: list[str],
|
||||
|
||||
149
python/tests/unit/test_security_fixes.py
Normal file
149
python/tests/unit/test_security_fixes.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Tests for security fixes."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from obsidian_rag.config import ObsidianRagConfig, SecurityConfig
|
||||
from obsidian_rag.embedder import OllamaEmbedder, SecurityError
|
||||
from obsidian_rag.indexer import Indexer, SensitiveContentError
|
||||
from obsidian_rag.security import sanitize_query
|
||||
|
||||
|
||||
def test_network_isolation_validation():
    """Test that remote URLs are rejected when local_only=True."""
    config = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=SecurityConfig(local_only=True)
    )

    # A loopback endpoint must validate cleanly under local_only=True.
    config.embedding.base_url = "http://localhost:11434"
    try:
        emb = OllamaEmbedder(config)
        emb._validate_network_isolation()
    except SecurityError:
        pytest.fail("Localhost should be allowed when local_only=True")

    # Any non-local host must be refused.
    config.embedding.base_url = "http://example.com:11434"
    with pytest.raises(SecurityError, match="Remote embedding service not allowed"):
        emb = OllamaEmbedder(config)
        emb._validate_network_isolation()

    # With local_only disabled, remote hosts become acceptable.
    config.security.local_only = False
    config.embedding.base_url = "http://example.com:11434"
    try:
        emb = OllamaEmbedder(config)
        emb._validate_network_isolation()
    except SecurityError:
        pytest.fail("Remote URLs should be allowed when local_only=False")
|
||||
|
||||
|
||||
def test_sensitive_content_enforcement():
    """Test that sensitive content requires approval."""
    from obsidian_rag.config import MemoryConfig

    # Only "health" requires confirmation; the patterns below make the
    # '#mentalhealth'/'therapy' chunk text match that category.
    config = ObsidianRagConfig(
        vault_path="/tmp/test",
        security=SecurityConfig(
            require_confirmation_for=["health"],
            auto_approve_sensitive=False,
            sensitive_sections=["#mentalhealth", "#physicalhealth", "#Relations"]
        ),
        memory=MemoryConfig(
            patterns={
                "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
                "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
                "commitments": ["shopping list", "costco", "amazon", "grocery"],
            }
        )
    )

    indexer = Indexer(config)

    # Create test chunks with health content (full enriched-chunk schema,
    # as produced by the indexing pipeline).
    chunks = [
        {
            'chunk_id': '1',
            'chunk_text': 'I have #mentalhealth issues and need therapy',
            'source_file': '/tmp/test/file.md',
            'source_directory': '/tmp/test',
            'section': 'content',
            'date': '2024-01-01',
            'tags': ['mentalhealth'],
            'chunk_index': 0,
            'total_chunks': 1,
            'modified_at': '2024-01-01T00:00:00Z',
            'indexed_at': '2024-01-01T00:00:00Z',
        }
    ]

    # Should raise SensitiveContentError while approval is not granted.
    with pytest.raises(SensitiveContentError, match="Sensitive health content detected"):
        indexer._check_sensitive_content_approval(chunks)

    # Should pass when auto_approve_sensitive=True (fresh Indexer to avoid
    # any cached state from the first instance).
    config.security.auto_approve_sensitive = True
    indexer = Indexer(config)
    try:
        indexer._check_sensitive_content_approval(chunks)
    except SensitiveContentError:
        pytest.fail("Should not raise when auto_approve_sensitive=True")
|
||||
|
||||
|
||||
def test_query_sanitization():
    """Test that queries are properly sanitized."""
    # SQL-injection payload: at least one of the dangerous characters
    # must have been scrubbed out by the injection patterns.
    cleaned = sanitize_query("test'; DROP TABLE users; --")
    assert "'" not in cleaned or ";" not in cleaned

    # Bare SQL keywords are stripped entirely.
    assert "SELECT" not in sanitize_query("SELECT * FROM users WHERE id = 1")

    # Over-long input is capped at the query length limit.
    assert len(sanitize_query("a" * 2000)) <= 1000

    # Leading/trailing whitespace is trimmed and interior runs collapse.
    assert sanitize_query("  test \n query \t") == "test query"
|
||||
|
||||
|
||||
def test_audit_logging():
    """Test that audit logging works correctly."""
    import json
    import stat
    import tempfile
    from obsidian_rag.audit_logger import AuditLogger

    with tempfile.TemporaryDirectory() as tmpdir:
        log_path = Path(tmpdir) / "audit.log"
        audit = AuditLogger(log_path)

        # Record one sensitive-access event.
        audit.log_sensitive_access(
            "/tmp/test/health.md",
            "health",
            "index",
            {"chunk_id": "123"}
        )

        # The log file must exist after the first write.
        assert log_path.exists()

        # It holds a JSON array with exactly the entry we logged.
        records = json.loads(log_path.read_text())
        assert len(records) == 1
        record = records[0]
        assert record['content_type'] == 'health'
        assert record['action'] == 'index'
        assert record['metadata']['chunk_id'] == '123'

        # The log is written with owner-only (0o600) permissions.
        # NOTE(review): this permission check is POSIX-specific.
        assert stat.S_IMODE(log_path.stat().st_mode) == 0o600
|
||||
Reference in New Issue
Block a user