216 lines
6.3 KiB
Python
216 lines
6.3 KiB
Python
"""Path traversal prevention, input sanitization, sensitive content detection, directory access control."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
import unicodedata
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from obsidian_rag.config import ObsidianRagConfig
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Path traversal
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
def validate_path(requested: Path, vault_root: Path) -> Path:
    """Resolve *requested* inside *vault_root* and reject anything escaping the vault.

    Args:
        requested: Path to validate, joined onto the resolved vault root.
            An absolute path replaces the vault prefix when joined, so it is
            accepted only if it already points inside the vault.
        vault_root: Root directory of the vault.

    Returns:
        The resolved absolute path, which is located under the resolved
        vault root.

    Raises:
        ValueError: If the path cannot be resolved, resolves outside the
            vault, or contains a ``..`` component (even one that would stay
            inside the vault).
    """
    vault = vault_root.resolve()
    try:
        resolved = (vault / requested).resolve()
    except (OSError, ValueError) as e:
        raise ValueError(f"Cannot resolve path: {requested}") from e

    # relative_to() raises ValueError when `resolved` is not under `vault`.
    try:
        resolved.relative_to(vault)
    except ValueError:
        # `from None`: the internal relative_to() failure is an implementation
        # detail and should not show up as chained context in the traceback.
        raise ValueError(
            f"Path traversal attempt blocked: {requested} resolves outside vault"
        ) from None

    # Defence in depth: refuse any ".." component, even if it resolves inside.
    if ".." in requested.parts:
        raise ValueError(f"Path traversal attempt blocked: {requested}")

    return resolved
|
|
|
|
|
|
def is_symlink_outside_vault(path: Path, vault_root: Path) -> bool:
    """Report whether *path* resolves to a location outside the vault.

    Both paths are fully resolved before comparison. ``False`` is returned
    only when the resolved path lies under the resolved vault root; a path
    outside the vault, or any failure while resolving, yields ``True``.

    NOTE(review): despite the name, no ``is_symlink()`` check is performed --
    any path that resolves outside the vault returns ``True``.
    """
    try:
        # relative_to() raises ValueError when the target is not under the vault.
        path.resolve().relative_to(vault_root.resolve())
    except (OSError, ValueError):
        return True
    return False
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Input sanitization
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
# Pre-compiled patterns used when cleaning vault content.
HTML_TAG_RE = re.compile(r"<[^>]+>")
CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```", re.MULTILINE)
MULTI_WHITESPACE_RE = re.compile(r"\s+")

# Maximum number of characters kept per sanitized chunk.
MAX_CHUNK_LEN = 2000

# Regex source strings (not compiled objects): they are applied with
# ``re.sub(pattern, ..., flags=re.IGNORECASE)``, and ``re.sub`` rejects a
# ``flags`` argument for an already-compiled pattern.
INJECTION_PATTERNS = [
    r"\x00",  # Null bytes
    r"\x1a",  # EOF character
    r"--\s",  # SQL comment
    r"/\*[\s\S]*?\*/",  # SQL comment
    r"';",  # SQL injection
    r"\b(DROP|DELETE|INSERT|UPDATE|SELECT)\b",  # SQL keywords
    # ``[\s\S]`` rather than ``.`` so script bodies spanning several lines
    # match even though re.DOTALL is not set.
    r"<script[^>]*>[\s\S]*?</script>",  # XSS
    r"javascript:",  # JS injection
    r"\b(eval|exec|spawn|fork|system)\b",  # Code execution
]

# Maximum number of characters kept per sanitized query.
MAX_QUERY_LENGTH = 1000
|
|
|
|
|
|
def sanitize_text(raw: str) -> str:
    """Sanitize raw vault content before embedding.

    Steps, in order:

    - Replace fenced code blocks (``CODE_BLOCK_RE``) with a single space
    - Strip HTML tags (``HTML_TAG_RE``)
    - Trim the ends and collapse every whitespace run to one space
    - Cap length at ``MAX_CHUNK_LEN`` chars

    Args:
        raw: Unprocessed text.

    Returns:
        The cleaned text, at most ``MAX_CHUNK_LEN`` characters long.
    """
    # Code fences go first so markup inside them is dropped with the block.
    text = CODE_BLOCK_RE.sub(" ", raw)
    text = HTML_TAG_RE.sub("", text)
    text = MULTI_WHITESPACE_RE.sub(" ", text.strip())
    # Slicing is a no-op when the text is already within the limit.
    return text[:MAX_CHUNK_LEN]
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Query sanitization and sensitive content detection
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
def sanitize_query(query: str) -> str:
    """Clean a search query before it is used.

    Every match of each entry in ``INJECTION_PATTERNS`` (applied in list
    order, case-insensitively) is replaced by a space; the result is then
    trimmed, its whitespace runs are collapsed to single spaces, and it is
    truncated to ``MAX_QUERY_LENGTH`` characters.
    """
    cleaned = query
    for injection_pattern in INJECTION_PATTERNS:
        cleaned = re.sub(injection_pattern, " ", cleaned, flags=re.IGNORECASE)

    collapsed = MULTI_WHITESPACE_RE.sub(" ", cleaned.strip())

    # A slice leaves strings within the limit untouched.
    return collapsed[:MAX_QUERY_LENGTH]
|
|
|
|
def detect_sensitive(
    text: str,
    sensitive_sections: list[str],
    patterns: dict[str, list[str]],
) -> dict[str, bool]:
    """Detect sensitive content categories in text.

    Matching is case-insensitive substring search.

    Args:
        text: Content to inspect.
        sensitive_sections: Section markers to look for; of these, only
            ``#mentalhealth`` and ``#physicalhealth`` affect the result
            (they flag ``health`` when present in the text).
        patterns: Mapping from category name (``health``, ``financial``,
            ``relations``) to the substrings that flag that category.
            Missing or ``None`` entries are treated as empty.

    Returns:
        Dict with keys ``health``, ``financial`` and ``relations``, each
        ``True`` when the corresponding category was detected.
    """
    text_lower = text.lower()
    result: dict[str, bool] = {
        "health": False,
        "financial": False,
        "relations": False,
    }

    # Sensitive section markers: only the health-related ones map to a category.
    health_sections = ("#mentalhealth", "#physicalhealth")
    result["health"] = any(
        section.lower() in health_sections and section.lower() in text_lower
        for section in sensitive_sections
    )

    # Pattern matching for every category, including "relations", which was
    # previously declared in the result but never evaluated.
    for category in ("health", "financial", "relations"):
        if result[category]:
            continue
        result[category] = any(
            pat.lower() in text_lower for pat in (patterns.get(category) or [])
        )

    return result
|
|
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Directory access control
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
def should_index_dir(
    dir_name: str,
    config: "ObsidianRagConfig",
) -> bool:
    """Decide whether a directory may be indexed, based on allow/deny lists.

    Rules, evaluated in this order:

    1. Names starting with ``.`` (hidden directories) are never indexed.
    2. When ``config.indexing.allow_dirs`` is non-empty, the directory is
       indexed only if it is listed there; the deny list is not consulted.
    3. Otherwise the directory is indexed unless it appears in
       ``config.indexing.deny_dirs``.
    """
    if dir_name.startswith("."):
        return False

    indexing = config.indexing
    allow_list = indexing.allow_dirs
    if allow_list:
        return dir_name in allow_list

    return dir_name not in indexing.deny_dirs
|
|
|
|
|
|
def filter_tags(text: str) -> list[str]:
    """Return the distinct ``#hashtags`` in *text*, lowercased, in first-seen order."""
    # A dict keeps insertion order, so it doubles as an ordered set.
    unique_tags: dict[str, None] = {}
    for match in re.finditer(r"#\w+", text):
        unique_tags.setdefault(match.group().lower())
    return list(unique_tags)