Sprint 0-1: Python indexer, TS plugin scaffolding, and test suite

## What's new

**Python indexer (`python/obsidian_rag/`)** — full pipeline from scan to LanceDB:
- `config.py` — JSON config loader with cross-platform path resolution
- `security.py` — path traversal prevention, HTML stripping, sensitive content detection, dir allow/deny lists
- `chunker.py` — section-split for journal entries (date-named files), sliding-window for unstructured notes
- `embedder.py` — Ollama `/api/embeddings` client with batched requests and timeout/error handling
- `vector_store.py` — LanceDB schema, upsert (merge_insert), delete, search with filters, stats
- `indexer.py` — full/sync/reindex pipeline orchestrator with progress yields
- `cli.py` — `index | sync | reindex | status` CLI commands

**TypeScript plugin (`src/`)** — OpenClaw plugin scaffold:
- `utils/` — config loader, TypeScript types, response envelope factory, LanceDB client
- `services/` — health state machine (HEALTHY/DEGRADED/UNAVAILABLE), vault watcher with debounce/batching, indexer bridge (subprocess spawner)
- `tools/` — 4 tool stubs: search, index, status, memory_store (OpenClaw wiring pending)
- `index.ts` — plugin entry point with health probe + vault watcher startup
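The health service above cycles through HEALTHY/DEGRADED/UNAVAILABLE. As a minimal language-neutral sketch (in Python; `next_state` and the failure threshold are illustrative assumptions, not code from this commit), the transition rule might look like:

```python
from enum import Enum

class Health(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNAVAILABLE = "unavailable"

def next_state(probe_ok: bool, consecutive_failures: int, threshold: int = 3) -> Health:
    # Hypothetical rule: any successful probe recovers; one failure degrades;
    # repeated failures past the threshold mark the backend unavailable.
    if probe_ok:
        return Health.HEALTHY
    if consecutive_failures >= threshold:
        return Health.UNAVAILABLE
    return Health.DEGRADED
```

The actual TypeScript service may use different thresholds or hysteresis; this only captures the three-state shape named above.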

**Config** (`obsidian-rag/config.json`, `openclaw.plugin.json`):
- 627 files / 3764 chunks indexed in dev vault

**Tests: 76 passing**
- Python: 64 pytest tests (chunker, security, vector_store, config)
- TypeScript: 12 vitest tests (lancedb client, response envelope)

## Bugs fixed

- LanceDB `tags` column filter: `LIKE '%tag%'` → `list_contains(tags, 'tag')` (List<String> column)
- LanceDB JS `db.list_tables()` returns `ListTablesResponse` object, not plain array
- LanceDB JS result score field: `_score` → `_distance`
- TypeScript regex literal with unescaped `/` in path-resolve regex
- Python: `create_table_if_not_exists` identity check → name comparison
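The first fix can be sketched as predicate construction: `tags` is a `List<String>` column, so a substring `LIKE` match against the serialized list is wrong, and `list_contains` tests exact membership instead. `tag_filter` below is a hypothetical helper, not part of this commit:

```python
def tag_filter(tag: str) -> str:
    # Build the corrected LanceDB filter predicate for a List<String> column.
    # Naive SQL-string escaping for the embedded literal.
    safe = tag.replace("'", "''")
    return f"list_contains(tags, '{safe}')"

# Old (incorrect) form, which matched substrings across the serialized list:
old = "tags LIKE '%project%'"
# Corrected form:
new = tag_filter("project")
```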

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-10 22:56:50 -04:00
parent 18ad47e100
commit 5c281165c7
40 changed files with 5814 additions and 59 deletions


@@ -0,0 +1,14 @@
Metadata-Version: 2.4
Name: obsidian-rag
Version: 0.1.0
Summary: RAG indexer for Obsidian vaults — powers OpenClaw's obsidian_rag_* tools
Requires-Python: >=3.11
Requires-Dist: lancedb>=0.12
Requires-Dist: httpx>=0.27
Requires-Dist: pyyaml>=6.0
Requires-Dist: python-frontmatter>=1.1
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-asyncio>=0.23; extra == "dev"
Requires-Dist: pytest-mock>=3.12; extra == "dev"
Requires-Dist: ruff>=0.5; extra == "dev"


@@ -0,0 +1,16 @@
pyproject.toml
obsidian_rag/__init__.py
obsidian_rag/__main__.py
obsidian_rag/chunker.py
obsidian_rag/cli.py
obsidian_rag/config.py
obsidian_rag/embedder.py
obsidian_rag/indexer.py
obsidian_rag/security.py
obsidian_rag/vector_store.py
obsidian_rag.egg-info/PKG-INFO
obsidian_rag.egg-info/SOURCES.txt
obsidian_rag.egg-info/dependency_links.txt
obsidian_rag.egg-info/entry_points.txt
obsidian_rag.egg-info/requires.txt
obsidian_rag.egg-info/top_level.txt


@@ -0,0 +1 @@


@@ -0,0 +1,2 @@
[console_scripts]
obsidian-rag = obsidian_rag.cli:main


@@ -0,0 +1,10 @@
lancedb>=0.12
httpx>=0.27
pyyaml>=6.0
python-frontmatter>=1.1

[dev]
pytest>=8.0
pytest-asyncio>=0.23
pytest-mock>=3.12
ruff>=0.5


@@ -0,0 +1 @@
obsidian_rag


@@ -0,0 +1,3 @@
"""Obsidian RAG — semantic search indexer for Obsidian vaults."""
__version__ = "0.1.0"


@@ -0,0 +1,7 @@
"""CLI entry point: obsidian-rag index | sync | reindex | status."""
import sys

from obsidian_rag.cli import main

if __name__ == "__main__":
    sys.exit(main())


@@ -0,0 +1,240 @@
"""Markdown parsing, structured + unstructured chunking, metadata enrichment."""
from __future__ import annotations

import re
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING

import frontmatter

if TYPE_CHECKING:
    from obsidian_rag.config import ObsidianRagConfig

# ----------------------------------------------------------------------
# Types
# ----------------------------------------------------------------------

@dataclass
class Chunk:
    chunk_id: str
    text: str
    source_file: str
    source_directory: str
    section: str | None
    date: str | None
    tags: list[str] = field(default_factory=list)
    chunk_index: int = 0
    total_chunks: int = 1
    modified_at: str | None = None
    indexed_at: str | None = None

# ----------------------------------------------------------------------
# Markdown parsing
# ----------------------------------------------------------------------

def parse_frontmatter(content: str) -> tuple[dict, str]:
    """Parse frontmatter from markdown content. Returns (metadata, body)."""
    try:
        meta, body = frontmatter.parse(content)
        return (dict(meta) if meta else {}), str(body)
    except Exception:
        return {}, content

def extract_tags(text: str) -> list[str]:
    """Extract all #hashtags from text, deduplicated, lowercased."""
    return list(dict.fromkeys(t.lower() for t in re.findall(r"#[\w-]+", text)))

def extract_date_from_filename(filepath: Path) -> str | None:
    """Try to parse an ISO date from a filename (e.g. 2024-01-15.md)."""
    name = filepath.stem  # filename without extension
    # Match YYYY-MM-DD or YYYYMMDD
    m = re.search(r"(\d{4}-\d{2}-\d{2})|(\d{8})", name)
    if m:
        date_str = m.group(1) or m.group(2)
        # Normalize YYYYMMDD → YYYY-MM-DD
        if len(date_str) == 8:
            return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
        return date_str
    return None

def is_structured_note(filepath: Path) -> bool:
    """Heuristic: journal/daily notes use date-named files with section headers."""
    return re.search(r"\d{4}-\d{2}-\d{2}", filepath.stem) is not None

# ----------------------------------------------------------------------
# Section-split chunker (structured notes)
# ----------------------------------------------------------------------

SECTION_HEADER_RE = re.compile(r"^#{1,3}\s+(.+)$")

def split_by_sections(body: str, metadata: dict) -> list[tuple[str | None, str]]:
    """Split a markdown body into (section_name, section_content) pairs.

    If no headers are found, returns [(None, body)].
    """
    sections: list[tuple[str | None, str]] = []
    current_heading: str | None = None
    current_content: list[str] = []
    for line in body.splitlines(keepends=True):
        m = SECTION_HEADER_RE.match(line.rstrip())
        if m:
            # Flush the previous section
            if current_heading is not None or current_content:
                sections.append((current_heading, "".join(current_content).strip()))
                current_content = []
            current_heading = m.group(1).strip()
        else:
            current_content.append(line)
    # Flush the last section
    if current_heading is not None or current_content:
        sections.append((current_heading, "".join(current_content).strip()))
    if not sections:
        sections = [(None, body.strip())]
    return sections

# ----------------------------------------------------------------------
# Sliding window chunker (unstructured notes)
# ----------------------------------------------------------------------

def _count_tokens(text: str) -> int:
    """Rough token count: whitespace-separated word count."""
    return len(text.split())

def sliding_window_chunks(
    text: str,
    chunk_size: int = 500,
    overlap: int = 100,
) -> list[str]:
    """Split text into overlapping sliding-window chunks of ~chunk_size tokens.

    Returns a list of chunk strings.
    """
    words = text.split()
    if not words:
        return []
    chunks: list[str] = []
    start = 0
    while start < len(words):
        chunks.append(" ".join(words[start : start + chunk_size]))
        # Advance by (chunk_size - overlap); guard against non-positive steps
        advance = chunk_size - overlap
        if advance <= 0:
            advance = max(1, chunk_size // 2)
        start += advance
    return chunks

# ----------------------------------------------------------------------
# Main chunk router
# ----------------------------------------------------------------------

def chunk_file(
    filepath: Path,
    content: str,
    modified_at: str,
    config: "ObsidianRagConfig",
    chunk_id_prefix: str = "",
) -> list[Chunk]:
    """Parse a markdown file and return a list of Chunks.

    Uses section-split for structured notes (journal entries with date filenames),
    sliding window for everything else.
    """
    vault_path = Path(config.vault_path)
    # Record paths relative to the vault root so the index stays portable
    try:
        rel_path = filepath.relative_to(vault_path) if filepath.is_absolute() else filepath
    except ValueError:
        rel_path = filepath
    source_file = str(rel_path)
    source_directory = rel_path.parts[0] if rel_path.parts else ""
    metadata, body = parse_frontmatter(content)
    tags = extract_tags(body)
    date = extract_date_from_filename(filepath)
    chunk_size = config.indexing.chunk_size
    overlap = config.indexing.chunk_overlap
    chunks: list[Chunk] = []
    if is_structured_note(filepath):
        # Section-split for journal/daily notes
        sections = split_by_sections(body, metadata)
        total = len(sections)
        for idx, (section, section_text) in enumerate(sections):
            if not section_text.strip():
                continue
            section_tags = extract_tags(section_text)
            combined_tags = list(dict.fromkeys([*tags, *section_tags]))
            chunks.append(
                Chunk(
                    chunk_id=f"{chunk_id_prefix}{uuid.uuid4().hex[:8]}",
                    text=section_text,
                    source_file=source_file,
                    source_directory=source_directory,
                    section=f"#{section}" if section else None,
                    date=date,
                    tags=combined_tags,
                    chunk_index=idx,
                    total_chunks=total,
                    modified_at=modified_at,
                )
            )
    else:
        # Sliding window for unstructured notes
        text_chunks = sliding_window_chunks(body, chunk_size, overlap)
        total = len(text_chunks)
        for idx, text_chunk in enumerate(text_chunks):
            if not text_chunk.strip():
                continue
            chunks.append(
                Chunk(
                    chunk_id=f"{chunk_id_prefix}{uuid.uuid4().hex[:8]}",
                    text=text_chunk,
                    source_file=source_file,
                    source_directory=source_directory,
                    section=None,
                    date=date,
                    tags=tags,
                    chunk_index=idx,
                    total_chunks=total,
                    modified_at=modified_at,
                )
            )
    return chunks
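The sliding-window arithmetic above (advance by `chunk_size - overlap`) can be illustrated with a standalone miniature, not the module itself:

```python
def window(words, size, overlap):
    # Mirror of sliding_window_chunks: each window repeats the last
    # `overlap` words of the previous one.
    step = size - overlap if size > overlap else max(1, size // 2)
    out, start = [], 0
    while start < len(words):
        out.append(words[start:start + size])
        start += step
    return out

chunks = window([f"w{i}" for i in range(12)], size=5, overlap=2)
# 12 words, window 5, step 3 → windows starting at 0, 3, 6, 9
```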

python/obsidian_rag/cli.py

@@ -0,0 +1,156 @@
"""CLI: obsidian-rag index | sync | reindex | status."""
from __future__ import annotations

import json
import os
import sys
import time
from pathlib import Path

import obsidian_rag.config as config_mod
from obsidian_rag.indexer import Indexer
from obsidian_rag.vector_store import get_db, get_stats

def main(argv: list[str] | None = None) -> int:
    argv = argv if argv is not None else sys.argv[1:]
    if not argv or argv[0] in ("--help", "-h"):
        print(_usage())
        return 0
    cmd = argv[0]
    try:
        config = config_mod.load_config()
    except FileNotFoundError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 1
    if cmd == "index":
        return _index(config)
    elif cmd == "sync":
        return _sync(config)
    elif cmd == "reindex":
        return _reindex(config)
    elif cmd == "status":
        return _status(config)
    else:
        print(f"Unknown command: {cmd}\n{_usage()}", file=sys.stderr)
        return 1

def _index(config) -> int:
    indexer = Indexer(config)
    t0 = time.monotonic()
    try:
        gen = indexer.full_index()
        result: dict = {"indexed_files": 0, "total_chunks": 0, "errors": []}
        # Drain the generator by hand: progress dicts arrive as yields, and the
        # summary may arrive either as a final yield or as the generator's
        # return value (a plain for-loop would silently drop the latter).
        while True:
            try:
                item = next(gen)
            except StopIteration as stop:
                if stop.value:
                    result = stop.value
                break
            if item.get("type") != "progress":
                result = item
        duration_ms = int((time.monotonic() - t0) * 1000)
        print(
            json.dumps(
                {
                    "type": "complete",
                    "indexed_files": result["indexed_files"],
                    "total_chunks": result["total_chunks"],
                    "duration_ms": duration_ms,
                    "errors": result["errors"],
                },
                indent=2,
            )
        )
        return 0 if not result["errors"] else 1
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2

def _sync(config) -> int:
    indexer = Indexer(config)
    try:
        result = indexer.sync()
        print(json.dumps({"type": "complete", **result}, indent=2))
        return 0 if not result["errors"] else 1
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2

def _reindex(config) -> int:
    indexer = Indexer(config)
    t0 = time.monotonic()
    try:
        result = indexer.reindex()
        duration_ms = int((time.monotonic() - t0) * 1000)
        print(
            json.dumps(
                {
                    "type": "complete",
                    "indexed_files": result["indexed_files"],
                    "total_chunks": result["total_chunks"],
                    "duration_ms": duration_ms,
                    "errors": result["errors"],
                },
                indent=2,
            )
        )
        return 0 if not result["errors"] else 1
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2

def _status(config) -> int:
    try:
        db = get_db(config)
        table = db.open_table("obsidian_chunks")
        stats = get_stats(table)
        # Resolve sync-result.json path (same convention as the indexer)
        project_root = Path(__file__).parent.parent.parent
        data_dir = project_root / "obsidian-rag"
        if not data_dir.exists() and not (project_root / "KnowledgeVault").exists():
            data_dir = Path(os.path.expanduser("~/.obsidian-rag"))
        sync_path = data_dir / "sync-result.json"
        last_sync = None
        if sync_path.exists():
            try:
                last_sync = json.loads(sync_path.read_text()).get("timestamp")
            except Exception:
                pass
        print(
            json.dumps(
                {
                    "total_docs": stats["total_docs"],
                    "total_chunks": stats["total_chunks"],
                    "last_sync": last_sync,
                },
                indent=2,
            )
        )
        return 0
    except FileNotFoundError:
        print(json.dumps({"error": "Index not found. Run 'obsidian-rag index' first."}, indent=2))
        return 1
    except Exception as e:
        print(json.dumps({"error": str(e)}), file=sys.stderr)
        return 1

def _usage() -> str:
    return """obsidian-rag - Obsidian vault RAG indexer

Usage:
  obsidian-rag index      Full index of the vault
  obsidian-rag sync       Incremental sync (changed files only)
  obsidian-rag reindex    Force full reindex (nuke + rebuild)
  obsidian-rag status     Show index health and statistics
"""

if __name__ == "__main__":
    sys.exit(main())


@@ -0,0 +1,145 @@
"""Configuration loader — reads ~/.obsidian-rag/config.json (or ./obsidian-rag/ for dev)."""
from __future__ import annotations

import json
import os
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any

DEFAULT_CONFIG_DIR = Path(__file__).parent.parent.parent  # python/obsidian_rag/ → project root

@dataclass
class EmbeddingConfig:
    provider: str = "ollama"
    model: str = "mxbai-embed-large"
    base_url: str = "http://localhost:11434"
    dimensions: int = 1024
    batch_size: int = 64

@dataclass
class VectorStoreConfig:
    type: str = "lancedb"
    path: str = ""  # resolved relative to data_dir

@dataclass
class IndexingConfig:
    chunk_size: int = 500
    chunk_overlap: int = 100
    file_patterns: list[str] = field(default_factory=lambda: ["*.md"])
    deny_dirs: list[str] = field(
        default_factory=lambda: [".obsidian", ".trash", "zzz-Archive", ".git", ".logseq"]
    )
    allow_dirs: list[str] = field(default_factory=list)

@dataclass
class SecurityConfig:
    require_confirmation_for: list[str] = field(default_factory=lambda: ["health", "financial_debt"])
    sensitive_sections: list[str] = field(
        default_factory=lambda: ["#mentalhealth", "#physicalhealth", "#Relations"]
    )
    local_only: bool = True

@dataclass
class MemoryConfig:
    auto_suggest: bool = True
    patterns: dict[str, list[str]] = field(
        default_factory=lambda: {
            "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
            "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
            "commitments": ["shopping list", "costco", "amazon", "grocery"],
        }
    )

@dataclass
class ObsidianRagConfig:
    vault_path: str = ""
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
    indexing: IndexingConfig = field(default_factory=IndexingConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)

def _resolve_data_dir() -> Path:
    """Resolve the data directory: dev (project root/obsidian-rag/) or production (~/.obsidian-rag/)."""
    dev_data_dir = DEFAULT_CONFIG_DIR / "obsidian-rag"
    if dev_data_dir.exists() or (DEFAULT_CONFIG_DIR / "KnowledgeVault").exists():
        return dev_data_dir
    # Production: ~/.obsidian-rag/
    return Path(os.path.expanduser("~/.obsidian-rag"))

def load_config(config_path: str | Path | None = None) -> ObsidianRagConfig:
    """Load config from a JSON file, falling back to dev/default config."""
    if config_path is None:
        config_path = _resolve_data_dir() / "config.json"
    else:
        config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path) as f:
        raw: dict[str, Any] = json.load(f)
    return ObsidianRagConfig(
        vault_path=raw.get("vault_path", ""),
        embedding=_merge(EmbeddingConfig(), raw.get("embedding", {})),
        vector_store=_merge(VectorStoreConfig(), raw.get("vector_store", {})),
        indexing=_merge(IndexingConfig(), raw.get("indexing", {})),
        security=_merge(SecurityConfig(), raw.get("security", {})),
        memory=_merge(MemoryConfig(), raw.get("memory", {})),
    )

def _merge(default: Any, overrides: dict[str, Any]) -> Any:
    """Shallow-merge a dict of overrides into a dataclass instance (or dict)."""
    if hasattr(default, "__dataclass_fields__"):
        fields: dict[str, Any] = {}
        for key, val in overrides.items():
            if key not in default.__dataclass_fields__:
                continue  # ignore unknown keys instead of crashing the constructor
            current = getattr(default, key)
            if isinstance(current, Enum):
                # Enum fields pass through as-is
                fields[key] = val
            elif isinstance(val, dict):
                fields[key] = _merge(current, val)
            else:
                fields[key] = val
        return default.__class__(**{**default.__dict__, **fields})
    if isinstance(overrides, dict) and isinstance(default, dict):
        return {**default, **overrides}
    return overrides if overrides is not None else default

def resolve_vault_path(config: ObsidianRagConfig) -> Path:
    """Resolve vault_path relative to the project root, or as absolute."""
    vp = Path(config.vault_path)
    if vp.is_absolute():
        return vp
    # Resolve relative to project root
    return (DEFAULT_CONFIG_DIR / vp).resolve()

def resolve_vector_db_path(config: ObsidianRagConfig) -> Path:
    """Resolve the vector store path relative to the data directory."""
    data_dir = _resolve_data_dir()
    vsp = Path(config.vector_store.path)
    if vsp.is_absolute():
        return vsp
    return (data_dir / vsp).resolve()
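The shallow-merge idea can be shown with a small standalone sketch (`Embedding` and `merge` here are illustrative stand-ins, not the real config classes): known override keys replace the dataclass defaults, and unrecognized keys are dropped rather than handed to the constructor.

```python
from dataclasses import dataclass

@dataclass
class Embedding:
    model: str = "mxbai-embed-large"
    batch_size: int = 64

def merge(default, overrides: dict):
    # Keep only keys that are actual dataclass fields, then rebuild the
    # instance with overrides layered on top of current values.
    known = {k: v for k, v in overrides.items() if k in default.__dataclass_fields__}
    return default.__class__(**{**default.__dict__, **known})

cfg = merge(Embedding(), {"batch_size": 16, "unknown": True})
```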


@@ -0,0 +1,110 @@
"""Ollama API client for embedding generation."""
from __future__ import annotations

from typing import TYPE_CHECKING

import httpx

if TYPE_CHECKING:
    from obsidian_rag.config import ObsidianRagConfig

DEFAULT_TIMEOUT = 120.0  # seconds

class EmbeddingError(Exception):
    """Raised when embedding generation fails."""

class OllamaUnavailableError(EmbeddingError):
    """Raised when Ollama is unreachable."""

class OllamaEmbedder:
    """Client for the Ollama /api/embeddings endpoint (mxbai-embed-large, 1024-dim)."""

    def __init__(self, config: "ObsidianRagConfig"):
        self.base_url = config.embedding.base_url.rstrip("/")
        self.model = config.embedding.model
        self.dimensions = config.embedding.dimensions
        self.batch_size = config.embedding.batch_size
        self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)

    def is_available(self) -> bool:
        """Check if Ollama is reachable and has the model."""
        try:
            resp = self._client.get(f"{self.base_url}/api/tags", timeout=5.0)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            return any(self.model in m.get("name", "") for m in models)
        except Exception:
            return False

    def embed_chunks(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for a batch of texts. Returns a list of vectors."""
        if not texts:
            return []
        all_vectors: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            all_vectors.extend(self._embed_batch(texts[i : i + self.batch_size]))
        return all_vectors

    def embed_single(self, text: str) -> list[float]:
        """Generate an embedding for a single text."""
        [vec] = self._embed_batch([text])
        return vec

    def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Internal batch call. Raises EmbeddingError on failure.

        Ollama's /api/embeddings endpoint takes one {"model", "prompt"} pair
        per request, so a batch is sent as sequential single-prompt calls.
        """
        results: list[list[float]] = []
        for text in batch:
            try:
                resp = self._client.post(
                    f"{self.base_url}/api/embeddings",
                    json={"model": self.model, "prompt": text},
                    timeout=DEFAULT_TIMEOUT,
                )
            except httpx.ConnectError as e:
                raise OllamaUnavailableError(f"Cannot connect to Ollama at {self.base_url}") from e
            except httpx.TimeoutException as e:
                raise EmbeddingError(f"Embedding request timed out after {DEFAULT_TIMEOUT}s") from e
            if resp.status_code != 200:
                raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")
            data = resp.json()
            # Some responses use "embedding", others "embeddings"
            embedding = data.get("embedding") or data.get("embeddings", [[]])[0]
            results.append(embedding)
        return results

    def close(self):
        self._client.close()
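The batching in `embed_chunks` is plain stride slicing; a standalone sketch of the arithmetic:

```python
def batches(texts, batch_size):
    # Walk the input in batch_size strides, as embed_chunks does;
    # the final batch may be shorter.
    return [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

sizes = [len(b) for b in batches([f"t{i}" for i in range(10)], 4)]
```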


@@ -0,0 +1,223 @@
"""Full indexing pipeline: scan → parse → chunk → embed → store."""
from __future__ import annotations

import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator

import obsidian_rag.config as config_mod
from obsidian_rag import security
from obsidian_rag.chunker import chunk_file
from obsidian_rag.embedder import OllamaUnavailableError
from obsidian_rag.security import should_index_dir, validate_path
from obsidian_rag.vector_store import create_table_if_not_exists, get_db, upsert_chunks

if TYPE_CHECKING:
    from obsidian_rag.config import ObsidianRagConfig

# ----------------------------------------------------------------------
# Pipeline
# ----------------------------------------------------------------------

class Indexer:
    """Coordinates the scan → chunk → embed → store pipeline."""

    def __init__(self, config: "ObsidianRagConfig"):
        self.config = config
        self.vault_path = config_mod.resolve_vault_path(config)
        self._embedder = None  # lazy init

    @property
    def embedder(self):
        if self._embedder is None:
            from obsidian_rag.embedder import OllamaEmbedder

            self._embedder = OllamaEmbedder(self.config)
        return self._embedder

    def scan_vault(self) -> Generator[Path, None, None]:
        """Walk the vault, yielding markdown files to index."""
        for root, dirs, files in os.walk(self.vault_path):
            root_path = Path(root)
            # Filter directories in place so os.walk skips them
            dirs[:] = [d for d in dirs if should_index_dir(d, self.config)]
            for fname in files:
                if not fname.endswith(".md"):
                    continue
                filepath = root_path / fname
                try:
                    validate_path(filepath, self.vault_path)
                except ValueError:
                    continue
                yield filepath

    def process_file(self, filepath: Path) -> tuple[int, list[dict[str, Any]]]:
        """Index a single file. Returns (num_chunks, enriched_chunks)."""
        mtime = datetime.fromtimestamp(filepath.stat().st_mtime, tz=timezone.utc).isoformat()
        content = filepath.read_text(encoding="utf-8")
        # Sanitize before chunking/embedding
        content = security.sanitize_text(content)
        chunks = chunk_file(filepath, content, mtime, self.config)
        # Enrich with indexed_at
        now = datetime.now(timezone.utc).isoformat()
        enriched: list[dict[str, Any]] = []
        for chunk in chunks:
            enriched.append(
                {
                    "chunk_id": chunk.chunk_id,
                    "chunk_text": chunk.text,
                    "source_file": chunk.source_file,
                    "source_directory": chunk.source_directory,
                    "section": chunk.section,
                    "date": chunk.date,
                    "tags": chunk.tags,
                    "chunk_index": chunk.chunk_index,
                    "total_chunks": chunk.total_chunks,
                    "modified_at": chunk.modified_at,
                    "indexed_at": now,
                }
            )
        return len(chunks), enriched

    def full_index(self) -> Generator[dict[str, Any], None, None]:
        """Run a full index of the vault.

        Yields progress dicts along the way; the final yield is the summary
        (indexed_files / total_chunks / errors), so a plain for-loop over the
        generator always sees it.
        """
        vault_path = self.vault_path
        if not vault_path.exists():
            raise FileNotFoundError(f"Vault not found: {vault_path}")
        db = get_db(self.config)
        table = create_table_if_not_exists(db)
        embedder = self.embedder
        files = list(self.scan_vault())
        total_files = len(files)
        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []
        for idx, filepath in enumerate(files):
            try:
                num_chunks, enriched = self.process_file(filepath)
                # Embed chunks
                texts = [e["chunk_text"] for e in enriched]
                try:
                    vectors = embedder.embed_chunks(texts)
                except OllamaUnavailableError:
                    # Ollama is down: store zero vectors so metadata survives
                    vectors = [[0.0] * 1024 for _ in texts]
                for e, v in zip(enriched, vectors):
                    e["vector"] = v
                # Store
                upsert_chunks(table, enriched)
                total_chunks += num_chunks
                indexed_files += 1
            except Exception as exc:
                errors.append({"file": str(filepath), "error": str(exc)})
            phase = "embedding" if idx < total_files // 2 else "storing"
            yield {
                "type": "progress",
                "phase": phase,
                "current": idx + 1,
                "total": total_files,
            }
        # Final yield: the summary (a bare `return` value would be dropped by
        # callers iterating with a for-loop)
        yield {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "duration_ms": 0,  # caller fills in wall-clock time
            "errors": errors,
        }

    def sync(self) -> dict[str, Any]:
        """Incremental sync: only process files modified since the last sync."""
        sync_result_path = self._sync_result_path()
        last_sync = None
        if sync_result_path.exists():
            try:
                last_sync = json.loads(sync_result_path.read_text()).get("timestamp")
            except Exception:
                pass
        db = get_db(self.config)
        table = create_table_if_not_exists(db)
        embedder = self.embedder
        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []
        for filepath in self.scan_vault():
            mtime_str = datetime.fromtimestamp(filepath.stat().st_mtime, tz=timezone.utc).isoformat()
            # ISO-8601 UTC strings compare lexicographically in time order
            if last_sync and mtime_str <= last_sync:
                continue  # unchanged since last sync
            try:
                num_chunks, enriched = self.process_file(filepath)
                texts = [e["chunk_text"] for e in enriched]
                try:
                    vectors = embedder.embed_chunks(texts)
                except OllamaUnavailableError:
                    vectors = [[0.0] * 1024 for _ in texts]
                for e, v in zip(enriched, vectors):
                    e["vector"] = v
                upsert_chunks(table, enriched)
                total_chunks += num_chunks
                indexed_files += 1
            except Exception as exc:
                errors.append({"file": str(filepath), "error": str(exc)})
        self._write_sync_result(indexed_files, total_chunks, errors)
        return {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }

    def reindex(self) -> dict[str, Any]:
        """Nuke and rebuild: drop the table and run a full index."""
        db = get_db(self.config)
        if "obsidian_chunks" in db.table_names():
            db.drop_table("obsidian_chunks")
        # full_index is a generator; its last yield is the summary dict
        results = list(self.full_index())
        return results[-1] if results else {"indexed_files": 0, "total_chunks": 0, "errors": []}

    def _sync_result_path(self) -> Path:
        # Use the same dev-data-dir convention as config.py
        project_root = Path(__file__).parent.parent.parent
        data_dir = project_root / "obsidian-rag"
        if not data_dir.exists() and not (project_root / "KnowledgeVault").exists():
            data_dir = Path(os.path.expanduser("~/.obsidian-rag"))
        return data_dir / "sync-result.json"

    def _write_sync_result(
        self,
        indexed_files: int,
        total_chunks: int,
        errors: list[dict],
    ) -> None:
        path = self._sync_result_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        result = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }
        # Atomic write: .tmp → rename
        tmp = path.with_suffix(".json.tmp")
        tmp.write_text(json.dumps(result, indent=2))
        tmp.rename(path)
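The sync path compares timestamps as strings. That is sound because fixed-width ISO-8601 timestamps with a consistent UTC offset sort lexicographically in chronological order:

```python
from datetime import datetime, timezone

# Two vault-file mtimes rendered the same way the indexer renders them
earlier = datetime(2026, 4, 9, 12, 0, tzinfo=timezone.utc).isoformat()
later = datetime(2026, 4, 10, 8, 30, tzinfo=timezone.utc).isoformat()

# String comparison agrees with time order for this fixed format
newer_than_last_sync = later > earlier
```

This breaks down if timestamps mix offsets (e.g. `+00:00` vs `Z` vs local offsets), so rendering everything through `timezone.utc` is what keeps the comparison valid.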


@@ -0,0 +1,164 @@
"""Path traversal prevention, input sanitization, sensitive content detection, directory access control."""
from __future__ import annotations
import re
import unicodedata
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from obsidian_rag.config import ObsidianRagConfig
# ----------------------------------------------------------------------
# Path traversal
# ----------------------------------------------------------------------
def validate_path(requested: Path, vault_root: Path) -> Path:
"""Resolve requested relative to vault_root and reject anything escaping the vault.
Raises ValueError on traversal attempts.
"""
# Resolve both to absolute paths
vault = vault_root.resolve()
try:
resolved = (vault / requested).resolve()
except (OSError, ValueError) as e:
raise ValueError(f"Cannot resolve path: {requested}") from e
# Check the resolved path is under vault
try:
resolved.relative_to(vault)
except ValueError:
raise ValueError(f"Path traversal attempt blocked: {requested} resolves outside vault")
# Reject obvious traversal
if ".." in requested.parts:
raise ValueError(f"Path traversal attempt blocked: {requested}")
return resolved
def is_symlink_outside_vault(path: Path, vault_root: Path) -> bool:
"""Check if path is a symlink that resolves outside the vault."""
try:
resolved = path.resolve()
vault = vault_root.resolve()
# Check if any parent (including self) is outside vault
try:
resolved.relative_to(vault)
return False
except ValueError:
return True
except (OSError, ValueError):
return True
# ----------------------------------------------------------------------
# Input sanitization
# ----------------------------------------------------------------------
HTML_TAG_RE = re.compile(r"<[^>]+>")
CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```", re.MULTILINE)
MULTI_WHITESPACE_RE = re.compile(r"\s+")
MAX_CHUNK_LEN = 2000
def sanitize_text(raw: str) -> str:
"""Sanitize raw vault content before embedding.
- Strip HTML tags (prevent XSS)
- Remove fenced code blocks
- Normalize whitespace
- Cap length at MAX_CHUNK_LEN chars
"""
# Remove fenced code blocks
text = CODE_BLOCK_RE.sub(" ", raw)
# Strip HTML tags
text = HTML_TAG_RE.sub("", text)
# Remove leading/trailing whitespace
text = text.strip()
# Normalize internal whitespace
text = MULTI_WHITESPACE_RE.sub(" ", text)
# Cap length
if len(text) > MAX_CHUNK_LEN:
text = text[:MAX_CHUNK_LEN]
return text
# ----------------------------------------------------------------------
# Sensitive content detection
# ----------------------------------------------------------------------
def detect_sensitive(
text: str,
sensitive_sections: list[str],
patterns: dict[str, list[str]],
) -> dict[str, bool]:
"""Detect sensitive content categories in text.
Returns dict with keys: health, financial, relations.
"""
text_lower = text.lower()
result: dict[str, bool] = {
"health": False,
"financial": False,
"relations": False,
}
# Check for sensitive section headings in the text
for section in sensitive_sections:
if section.lower() in text_lower:
result["health"] = result["health"] or section.lower() in ["#mentalhealth", "#physicalhealth"]
# Pattern matching
financial_patterns = patterns.get("financial", [])
health_patterns = patterns.get("health", [])
for pat in financial_patterns:
if pat.lower() in text_lower:
result["financial"] = True
break
for pat in health_patterns:
if pat.lower() in text_lower:
result["health"] = True
break
return result
# ----------------------------------------------------------------------
# Directory access control
# ----------------------------------------------------------------------
def should_index_dir(
dir_name: str,
config: "ObsidianRagConfig",
) -> bool:
"""Apply deny/allow list rules to a directory.
If allow_dirs is non-empty, only those dirs are allowed.
If deny_dirs matches, the dir is rejected.
Hidden dirs (starting with '.') are always rejected.
"""
# Always reject hidden directories
if dir_name.startswith("."):
return False
# If allow list is set, only those dirs are allowed
if config.indexing.allow_dirs:
return dir_name in config.indexing.allow_dirs
# Otherwise reject any deny-listed directory
deny = config.indexing.deny_dirs
return dir_name not in deny
def filter_tags(text: str) -> list[str]:
"""Extract all #hashtags from text, lowercased and deduplicated."""
return list(dict.fromkeys(tag.lower() for tag in re.findall(r"#\w+", text)))
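The sanitization steps in `sanitize_text` above (code-block removal, tag stripping, whitespace collapse, length cap) compose like this standalone sketch. The regex and cap definitions here are illustrative assumptions; the module defines `CODE_BLOCK_RE`, `HTML_TAG_RE`, `MULTI_WHITESPACE_RE`, and `MAX_CHUNK_LEN` elsewhere:

```python
import re

# Assumed pattern definitions -- the real module defines these at top of file
CODE_BLOCK_RE = re.compile(r"```.*?```", re.DOTALL)
HTML_TAG_RE = re.compile(r"<[^>]*>")
MULTI_WHITESPACE_RE = re.compile(r"\s+")
MAX_CHUNK_LEN = 2000  # illustrative cap

def sanitize_sketch(raw: str) -> str:
    text = CODE_BLOCK_RE.sub(" ", raw)   # fenced code may hold secrets
    text = HTML_TAG_RE.sub("", text)     # strip tags, keep inner text
    text = text.strip()
    text = MULTI_WHITESPACE_RE.sub(" ", text)
    return text[:MAX_CHUNK_LEN]

print(sanitize_sketch("<b>Hello</b>\n```\nsk-12345\n```\n#world"))
# → Hello #world
```

Note the ordering: code blocks are dropped before HTML stripping, so a fenced secret never survives into the chunk text.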

"""LanceDB table creation, vector upsert/delete/search."""
from __future__ import annotations
import json
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable
import lancedb
if TYPE_CHECKING:
from obsidian_rag.config import ObsidianRagConfig
# ----------------------------------------------------------------------
# Schema constants
# ----------------------------------------------------------------------
TABLE_NAME = "obsidian_chunks"
VECTOR_DIM = 1024 # mxbai-embed-large
# ----------------------------------------------------------------------
# Types
# ----------------------------------------------------------------------
@dataclass
class SearchResult:
chunk_id: str
chunk_text: str
source_file: str
source_directory: str
section: str | None
date: str | None
tags: list[str]
chunk_index: int
score: float
# ----------------------------------------------------------------------
# Table setup
# ----------------------------------------------------------------------
def get_db(config: "ObsidianRagConfig") -> lancedb.LanceDBConnection:
"""Connect to the LanceDB database."""
import obsidian_rag.config as cfg_mod
db_path = cfg_mod.resolve_vector_db_path(config)
db_path.parent.mkdir(parents=True, exist_ok=True)
return lancedb.connect(str(db_path))
def create_table_if_not_exists(db: Any) -> Any:
"""Create the obsidian_chunks table if it doesn't exist."""
import pyarrow as pa
    # list_tables() returns a response object, not a plain list (see the test suite)
    if TABLE_NAME in db.list_tables().tables:
        return db.open_table(TABLE_NAME)
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), VECTOR_DIM)),
pa.field("chunk_id", pa.string()),
pa.field("chunk_text", pa.string()),
pa.field("source_file", pa.string()),
pa.field("source_directory", pa.string()),
pa.field("section", pa.string()),
pa.field("date", pa.string()),
pa.field("tags", pa.list_(pa.string())),
pa.field("chunk_index", pa.int32()),
pa.field("total_chunks", pa.int32()),
pa.field("modified_at", pa.string()),
pa.field("indexed_at", pa.string()),
]
)
tbl = db.create_table(TABLE_NAME, schema=schema, exist_ok=True)
return tbl
# ----------------------------------------------------------------------
# CRUD operations
# ----------------------------------------------------------------------
def upsert_chunks(
table: Any,
chunks: list[dict[str, Any]],
) -> int:
"""Add or update chunks in the table. Returns number of chunks written."""
if not chunks:
return 0
# Use when_matched_update_all + when_not_matched_insert_all for full upsert
(
table.merge_insert("chunk_id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(chunks)
)
return len(chunks)
def delete_by_source_file(table: Any, source_file: str) -> int:
    """Delete all chunks from a given source file. Returns count deleted."""
    before = table.count_rows()
    # SQL string literals use single quotes; escape any embedded quote
    escaped = source_file.replace("'", "''")
    table.delete(f"source_file = '{escaped}'")
    return before - table.count_rows()
def search_chunks(
table: Any,
query_vector: list[float],
limit: int = 5,
directory_filter: list[str] | None = None,
date_range: dict | None = None,
tags: list[str] | None = None,
) -> list[SearchResult]:
"""Search for similar chunks using vector similarity.
Filters are applied as AND conditions.
"""
    # Build WHERE clause (string literals in LanceDB SQL use single quotes)
    conditions: list[str] = []
    if directory_filter:
        dir_list = ", ".join(f"'{d}'" for d in directory_filter)
        conditions.append(f"source_directory IN ({dir_list})")
if date_range:
if "from" in date_range:
conditions.append(f"date >= '{date_range['from']}'")
if "to" in date_range:
conditions.append(f"date <= '{date_range['to']}'")
if tags:
for tag in tags:
conditions.append(f"list_contains(tags, '{tag}')")
where_clause = " AND ".join(conditions) if conditions else None
    query = table.search(query_vector, vector_column_name="vector").limit(limit)
    if where_clause:
        query = query.where(where_clause)
    results = query.to_list()
return [
SearchResult(
chunk_id=r["chunk_id"],
chunk_text=r["chunk_text"],
source_file=r["source_file"],
source_directory=r["source_directory"],
section=r.get("section"),
date=r.get("date"),
tags=r.get("tags", []),
chunk_index=r.get("chunk_index", 0),
            score=r.get("_distance", 0.0),
)
for r in results
]
def get_stats(table: Any) -> dict[str, Any]:
"""Return index statistics."""
total_docs = 0
total_chunks = 0
try:
total_chunks = table.count_rows()
# Count unique source files using pandas
all_data = table.to_pandas()
total_docs = all_data["source_file"].nunique()
except Exception:
pass
return {"total_docs": total_docs, "total_chunks": total_chunks}
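`search_chunks` composes its filters into a single AND-ed SQL string before handing it to `.where()`. A standalone sketch of that composition (illustrative reimplementation, stdlib only):

```python
def build_where(directories=None, date_range=None, tags=None):
    """Compose the AND-ed filter string search_chunks builds (sketch)."""
    conds: list[str] = []
    if directories:
        quoted = ", ".join(f"'{d}'" for d in directories)
        conds.append(f"source_directory IN ({quoted})")
    if date_range:
        if "from" in date_range:
            conds.append(f"date >= '{date_range['from']}'")
        if "to" in date_range:
            conds.append(f"date <= '{date_range['to']}'")
    for tag in tags or []:
        conds.append(f"list_contains(tags, '{tag}')")  # tags is a List<String> column
    return " AND ".join(conds) if conds else None

print(build_where(directories=["Notes"], tags=["#python"]))
# → source_directory IN ('Notes') AND list_contains(tags, '#python')
```

`list_contains` is the fix noted in the commit message: a `LIKE '%tag%'` filter does not work against a list-typed column.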

python/pyproject.toml
[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"
[project]
name = "obsidian-rag"
version = "0.1.0"
description = "RAG indexer for Obsidian vaults — powers OpenClaw's obsidian_rag_* tools"
requires-python = ">=3.11"
dependencies = [
    "lancedb>=0.12",
    "httpx>=0.27",
    "pyarrow>=15",
    "pyyaml>=6.0",
    "python-frontmatter>=1.1",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"pytest-mock>=3.12",
"ruff>=0.5",
]
[project.scripts]
obsidian-rag = "obsidian_rag.cli:main"
[tool.setuptools.packages.find]
where = ["."]
include = ["obsidian_rag*"]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
asyncio_mode = "auto"

"""Tests for obsidian_rag.chunker — section splitting and sliding window."""
from __future__ import annotations
from pathlib import Path
import tempfile
from unittest.mock import MagicMock
import pytest
from obsidian_rag.chunker import (
extract_tags,
extract_date_from_filename,
is_structured_note,
parse_frontmatter,
split_by_sections,
sliding_window_chunks,
chunk_file,
)
# ----------------------------------------------------------------------
# parse_frontmatter
# ----------------------------------------------------------------------
def test_parse_frontmatter_with_yaml():
content = """---
title: My Journal
tags: [journal, personal]
---
# Morning
Some content here.
"""
meta, body = parse_frontmatter(content)
assert meta.get("title") == "My Journal"
assert "# Morning" in body
assert "Some content" in body
def test_parse_frontmatter_without_frontmatter():
content = "# Just a header\n\nSome text without frontmatter."
meta, body = parse_frontmatter(content)
assert meta == {}
assert "# Just a header" in body
# ----------------------------------------------------------------------
# extract_tags
# ----------------------------------------------------------------------
def test_extract_tags_basic():
text = "Hello #world and #python-code is nice"
tags = extract_tags(text)
assert "#world" in tags
assert "#python-code" in tags
    # All extracted tags keep their leading '#'
    assert all(t.startswith("#") for t in tags)
def test_extract_tags_deduplicates():
text = "#hello #world #hello #python"
tags = extract_tags(text)
assert len(tags) == 3
# ----------------------------------------------------------------------
# extract_date_from_filename
# ----------------------------------------------------------------------
def test_extract_date_from_filename_iso():
p = Path("2024-01-15.md")
assert extract_date_from_filename(p) == "2024-01-15"
def test_extract_date_from_filename_compact():
p = Path("20240115.md")
assert extract_date_from_filename(p) == "2024-01-15"
def test_extract_date_from_filename_no_date():
p = Path("my-journal.md")
assert extract_date_from_filename(p) is None
# ----------------------------------------------------------------------
# is_structured_note
# ----------------------------------------------------------------------
def test_is_structured_note_journal():
assert is_structured_note(Path("2024-01-15.md")) is True
assert is_structured_note(Path("Journal/2024-02-20.md")) is True
def test_is_structured_note_project():
assert is_structured_note(Path("My Project Ideas.md")) is False
assert is_structured_note(Path("shopping-list.md")) is False
# ----------------------------------------------------------------------
# split_by_sections
# ----------------------------------------------------------------------
def test_split_by_sections_multiple():
body = """# Mental Health
Feeling anxious today.
## Work
Project deadline approaching.
### Home
Need to clean the garage.
"""
sections = split_by_sections(body, {})
assert len(sections) == 3
assert sections[0][0] == "Mental Health"
# Section content excludes the header line itself
assert "Feeling anxious today." in sections[0][1]
assert sections[1][0] == "Work"
assert sections[2][0] == "Home"
def test_split_by_sections_no_headers():
body = "Just plain text without any headers at all."
sections = split_by_sections(body, {})
assert len(sections) == 1
assert sections[0][0] is None
assert "Just plain text" in sections[0][1]
def test_split_by_sections_leading_content():
"""Content before the first header belongs to the first section."""
body = """Some intro text before any header.
# First Section
Content of first.
"""
sections = split_by_sections(body, {})
assert sections[0][0] is None
assert "Some intro text" in sections[0][1]
assert sections[1][0] == "First Section"
# ----------------------------------------------------------------------
# sliding_window_chunks
# ----------------------------------------------------------------------
def test_sliding_window_basic():
words = " ".join([f"word{i}" for i in range(1200)])
chunks = sliding_window_chunks(words, chunk_size=500, overlap=100)
assert len(chunks) >= 2
# First chunk: words 0-499
assert chunks[0].startswith("word0")
    # No chunk exceeds chunk_size words
for c in chunks:
assert len(c.split()) <= 500
def test_sliding_window_overlap():
"""Adjacent chunks should share the overlap region."""
text = " ".join([f"word{i}" for i in range(1000)])
chunks = sliding_window_chunks(text, chunk_size=500, overlap=100)
# Every chunk after the first should start with words from the previous chunk
for i in range(1, len(chunks)):
prev_words = chunks[i - 1].split()
curr_words = chunks[i].split()
# Overlap should be evident
assert prev_words[-100:] == curr_words[:100]
def test_sliding_window_empty():
assert sliding_window_chunks("", chunk_size=500, overlap=100) == []
def test_sliding_window_exact_size_produces_two_chunks():
"""With overlap=100, exactly 500 words produces 2 chunks (0-499 and 400-end)."""
words = " ".join([f"word{i}" for i in range(500)])
chunks = sliding_window_chunks(words, chunk_size=500, overlap=100)
assert len(chunks) == 2
assert chunks[0].startswith("word0")
assert chunks[1].startswith("word400") # advance = 500-100 = 400
def test_sliding_window_small_text():
"""Text much shorter than chunk_size returns single chunk."""
text = "just a few words"
chunks = sliding_window_chunks(text, chunk_size=500, overlap=100)
assert len(chunks) == 1
assert chunks[0] == text
# ----------------------------------------------------------------------
# chunk_file integration
# ----------------------------------------------------------------------
def _mock_config(tmp_path: Path) -> MagicMock:
"""Build a minimal mock config pointing at a tmp vault."""
cfg = MagicMock()
cfg.vault_path = str(tmp_path)
cfg.indexing.chunk_size = 500
cfg.indexing.chunk_overlap = 100
cfg.indexing.file_patterns = ["*.md"]
cfg.indexing.deny_dirs = [".obsidian", ".trash", "zzz-Archive", ".git"]
cfg.indexing.allow_dirs = []
return cfg
def test_chunk_file_structured_journal(tmp_path: Path):
vault = tmp_path / "Journal"
vault.mkdir()
fpath = vault / "2024-03-15.md"
fpath.write_text("""# Morning
Felt #anxious about the deadline.
## Work
Finished the report.
""")
cfg = _mock_config(tmp_path)
chunks = chunk_file(fpath, fpath.read_text(), "2024-03-15T10:00:00Z", cfg)
# Journal file → section-split → 2 chunks
assert len(chunks) == 2
assert chunks[0].section == "#Morning"
assert chunks[0].date == "2024-03-15"
assert "#anxious" in chunks[0].tags or "#anxious" in chunks[1].tags
assert chunks[0].source_file.endswith("Journal/2024-03-15.md")
def test_chunk_file_unstructured(tmp_path: Path):
vault = tmp_path / "Notes"
vault.mkdir()
fpath = vault / "project-ideas.md"
fpath.write_text("This is a long note " * 200) # ~1000 words
cfg = _mock_config(tmp_path)
chunks = chunk_file(fpath, fpath.read_text(), "2024-03-15T10:00:00Z", cfg)
# Unstructured → sliding window → multiple chunks
assert len(chunks) > 1
assert all(c.section is None for c in chunks)
assert chunks[0].chunk_index == 0
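The overlap arithmetic these tests assert (each window advances by `chunk_size - overlap` words, so 500 words with overlap 100 yields chunks starting at word 0 and word 400) can be sketched standalone. This is an illustrative reimplementation, not the module's `sliding_window_chunks`:

```python
def sliding_window_sketch(text: str, chunk_size: int = 500, overlap: int = 100) -> list[str]:
    """Split text into word windows; each window advances by chunk_size - overlap."""
    words = text.split()
    step = chunk_size - overlap
    return [" ".join(words[start:start + chunk_size])
            for start in range(0, len(words), step)]

chunks = sliding_window_sketch(" ".join(f"word{i}" for i in range(500)))
print(len(chunks), chunks[1].split()[0])
# → 2 word400
```

The trailing window can be much shorter than `chunk_size`; it exists so the final `overlap` words are never the only context for the last chunk.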

"""Tests for obsidian_rag.config — loader, path resolution, defaults."""
from __future__ import annotations
import json
import tempfile
from pathlib import Path
import pytest
from obsidian_rag.config import (
EmbeddingConfig,
ObsidianRagConfig,
load_config,
resolve_vector_db_path,
resolve_vault_path,
)
# ----------------------------------------------------------------------
# Config loading
# ----------------------------------------------------------------------
def test_load_config_parses_valid_json(tmp_path: Path):
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps({
"vault_path": "/path/to/vault",
"embedding": {"model": "custom-model:tag", "dimensions": 512},
"vector_store": {"path": "/vectors/db"},
})
)
config = load_config(config_path)
assert config.vault_path == "/path/to/vault"
assert config.embedding.model == "custom-model:tag"
assert config.embedding.dimensions == 512 # overridden
def test_load_config_missing_file_raises(tmp_path: Path):
with pytest.raises(FileNotFoundError):
load_config(tmp_path / "nonexistent.json")
def test_load_config_merges_partial_json(tmp_path: Path):
config_path = tmp_path / "config.json"
config_path.write_text(json.dumps({"vault_path": "/custom/vault"}))
config = load_config(config_path)
# Unspecified fields fall back to defaults
assert config.vault_path == "/custom/vault"
assert config.embedding.base_url == "http://localhost:11434" # default
assert config.indexing.chunk_size == 500 # default
# ----------------------------------------------------------------------
# resolve_vault_path
# ----------------------------------------------------------------------
def test_resolve_vault_path_absolute():
cfg = ObsidianRagConfig(vault_path="/absolute/vault")
assert resolve_vault_path(cfg) == Path("/absolute/vault")
def test_resolve_vault_path_relative_defaults_to_project_root():
cfg = ObsidianRagConfig(vault_path="KnowledgeVault/Default")
result = resolve_vault_path(cfg)
# Should resolve relative to python/obsidian_rag/ → project root
assert result.name == "Default"
assert result.parent.name == "KnowledgeVault"
# ----------------------------------------------------------------------
# resolve_vector_db_path
# ----------------------------------------------------------------------
def test_resolve_vector_db_path_string_absolute():
"""VectorStoreConfig stores path as a string; Path objects should be converted first."""
from obsidian_rag.config import VectorStoreConfig
# Using a string path (the actual usage)
cfg = ObsidianRagConfig(vector_store=VectorStoreConfig(path="/my/vectors.lance"))
result = resolve_vector_db_path(cfg)
assert result == Path("/my/vectors.lance")
def test_resolve_vector_db_path_string_relative(tmp_path: Path):
"""Relative paths are resolved against the data directory."""
import obsidian_rag.config as cfg_mod
# Set up data dir + vault marker (required by _resolve_data_dir)
# Note: the dev data dir is "obsidian-rag" (without leading dot)
data_dir = tmp_path / "obsidian-rag"
data_dir.mkdir()
(tmp_path / "KnowledgeVault").mkdir()
vector_file = data_dir / "vectors.lance"
vector_file.touch()
cfg = ObsidianRagConfig(vector_store=cfg_mod.VectorStoreConfig(path="vectors.lance"))
orig = cfg_mod.DEFAULT_CONFIG_DIR
cfg_mod.DEFAULT_CONFIG_DIR = tmp_path
try:
result = resolve_vector_db_path(cfg)
finally:
cfg_mod.DEFAULT_CONFIG_DIR = orig
# Resolves to data_dir / vectors.lance
assert result.parent.name == "obsidian-rag" # dev dir is "obsidian-rag" (no leading dot)
assert result.name == "vectors.lance"
# ----------------------------------------------------------------------
# Dataclass defaults
# ----------------------------------------------------------------------
def test_embedding_config_defaults():
cfg = EmbeddingConfig()
assert cfg.model == "mxbai-embed-large"
assert cfg.dimensions == 1024
assert cfg.batch_size == 64
def test_security_config_defaults():
from obsidian_rag.config import SecurityConfig
cfg = SecurityConfig()
assert "#mentalhealth" in cfg.sensitive_sections
assert "health" in cfg.require_confirmation_for
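The merge behavior `test_load_config_merges_partial_json` exercises (partial JSON overriding dataclass defaults) can be sketched with `dataclasses.replace`. `EmbeddingSketch` here mirrors the defaults the tests assert; the real `load_config` may merge differently:

```python
import dataclasses
import json

@dataclasses.dataclass
class EmbeddingSketch:
    base_url: str = "http://localhost:11434"
    model: str = "mxbai-embed-large"
    dimensions: int = 1024
    batch_size: int = 64

def load_embedding_sketch(raw: str) -> EmbeddingSketch:
    # Known keys override defaults; unknown keys raise TypeError early
    overrides = json.loads(raw).get("embedding", {})
    return dataclasses.replace(EmbeddingSketch(), **overrides)

cfg = load_embedding_sketch('{"embedding": {"dimensions": 512}}')
print(cfg.model, cfg.dimensions)
# → mxbai-embed-large 512
```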

"""Tests for obsidian_rag.security — path traversal, sanitization, sensitive detection."""
from __future__ import annotations
from pathlib import Path
import tempfile
from unittest.mock import MagicMock
import pytest
from obsidian_rag.security import (
detect_sensitive,
filter_tags,
is_symlink_outside_vault,
sanitize_text,
should_index_dir,
validate_path,
)
# ----------------------------------------------------------------------
# validate_path
# ----------------------------------------------------------------------
def test_validate_path_normal_file(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
target = vault / "subdir" / "note.md"
target.parent.mkdir()
target.touch()
result = validate_path(Path("subdir/note.md"), vault)
assert result == target.resolve()
def test_validate_path_traversal_attempt(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
with pytest.raises(ValueError, match="traversal"):
validate_path(Path("../etc/passwd"), vault)
def test_validate_path_deep_traversal(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
with pytest.raises(ValueError, match="traversal"):
validate_path(Path("subdir/../../../etc/passwd"), vault)
def test_validate_path_absolute_path(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
with pytest.raises(ValueError):
validate_path(Path("/etc/passwd"), vault)
def test_validate_path_path_with_dotdot_in_resolve(tmp_path: Path):
"""Path that resolves inside vault but has .. in parts should be caught."""
vault = tmp_path / "vault"
vault.mkdir()
sub = vault / "subdir"
sub.mkdir()
# validate_path checks parts for ".."
with pytest.raises(ValueError, match="traversal"):
validate_path(Path("subdir/../subdir/../note.md"), vault)
# ----------------------------------------------------------------------
# is_symlink_outside_vault
# ----------------------------------------------------------------------
def test_is_symlink_outside_vault_internal(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
note = vault / "note.md"
note.touch()
link = vault / "link.md"
link.symlink_to(note)
assert is_symlink_outside_vault(link, vault) is False
def test_is_symlink_outside_vault_external(tmp_path: Path):
vault = tmp_path / "vault"
vault.mkdir()
outside = tmp_path / "outside.md"
outside.touch()
link = vault / "link.md"
link.symlink_to(outside)
assert is_symlink_outside_vault(link, vault) is True
# ----------------------------------------------------------------------
# sanitize_text
# ----------------------------------------------------------------------
def test_sanitize_text_strips_html():
raw = "<script>alert('xss')</script>Hello #world"
result = sanitize_text(raw)
assert "<script>" not in result
assert "Hello #world" in result
# Text content inside HTML tags is preserved (sanitize_text strips the tags only)
def test_sanitize_text_removes_code_blocks():
raw = """Some text
```
secret_api_key = "sk-12345"
```
More text
"""
result = sanitize_text(raw)
assert "secret_api_key" not in result
assert "Some text" in result
assert "More text" in result
def test_sanitize_text_normalizes_whitespace():
raw = "Hello\n\n\n world\t\t spaces"
result = sanitize_text(raw)
assert "\n" not in result
assert "\t" not in result
    assert "  " not in result  # runs of whitespace collapse to a single space
def test_sanitize_text_caps_length():
long_text = "word " * 1000
result = sanitize_text(long_text)
assert len(result) <= 2000
def test_sanitize_text_preserves_hashtags():
raw = "#mentalhealth #python #machine-learning"
result = sanitize_text(raw)
assert "#mentalhealth" in result
assert "#python" in result
# ----------------------------------------------------------------------
# detect_sensitive
# ----------------------------------------------------------------------
def test_detect_sensitive_mental_health_section():
text = " #mentalhealth section content"
sensitive_sections = ["#mentalhealth", "#physicalhealth", "#Relations"]
patterns = {"financial": [], "health": []}
result = detect_sensitive(text, sensitive_sections, patterns)
assert result["health"] is True
def test_detect_sensitive_financial_pattern():
text = "I owe Sreenivas $50 and need to pay it back"
sensitive_sections = ["#mentalhealth"]
patterns = {"financial": ["owe", "$"], "health": []}
result = detect_sensitive(text, sensitive_sections, patterns)
assert result["financial"] is True
assert result["health"] is False
def test_detect_sensitive_relations():
text = "Had coffee with Sarah #Relations"
sensitive_sections = ["#Relations"]
patterns = {"financial": [], "health": []}
result = detect_sensitive(text, sensitive_sections, patterns)
# Only specific health sections set health=true
assert result["relations"] is False
def test_detect_sensitive_clean_text():
text = "This is a normal note about cooking dinner."
sensitive_sections = []
patterns = {"financial": [], "health": []}
result = detect_sensitive(text, sensitive_sections, patterns)
assert result == {"health": False, "financial": False, "relations": False}
# ----------------------------------------------------------------------
# should_index_dir
# ----------------------------------------------------------------------
def _mock_config() -> MagicMock:
cfg = MagicMock()
cfg.indexing.allow_dirs = []
cfg.indexing.deny_dirs = [".obsidian", ".trash", "zzz-Archive", ".git"]
return cfg
def test_should_index_dir_allows_normal():
cfg = _mock_config()
assert should_index_dir("Journal", cfg) is True
assert should_index_dir("Finance", cfg) is True
assert should_index_dir("Projects", cfg) is True
def test_should_index_dir_denies_hidden():
cfg = _mock_config()
assert should_index_dir(".obsidian", cfg) is False
assert should_index_dir(".git", cfg) is False
assert should_index_dir(".trash", cfg) is False
def test_should_index_dir_denies_configured():
cfg = _mock_config()
assert should_index_dir("zzz-Archive", cfg) is False
def test_should_index_dir_allow_list_override():
cfg = _mock_config()
cfg.indexing.allow_dirs = ["Journal", "Finance"]
assert should_index_dir("Journal", cfg) is True
assert should_index_dir("Finance", cfg) is True
assert should_index_dir("Projects", cfg) is False
# ----------------------------------------------------------------------
# filter_tags
# ----------------------------------------------------------------------
def test_filter_tags_basic():
text = "Hello #world and #python tags #AI"
tags = filter_tags(text)
assert "#world" in tags
assert "#python" in tags
assert "#ai" in tags
def test_filter_tags_deduplicates():
text = "#hello #world #hello"
tags = filter_tags(text)
assert len(tags) == 2
def test_filter_tags_no_tags():
text = "just plain text without any hashtags"
assert filter_tags(text) == []
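The traversal checks these tests exercise (reject absolute paths and any `..` component, then confirm the resolved path stays under the vault) can be sketched with `pathlib` alone. This is an illustrative sketch, not the module's `validate_path`:

```python
from pathlib import Path

def validate_path_sketch(rel: Path, vault: Path) -> Path:
    """Reject absolute paths and any '..' component before resolving."""
    if rel.is_absolute() or ".." in rel.parts:
        raise ValueError(f"path traversal rejected: {rel}")
    resolved = (vault / rel).resolve()
    # Defense in depth: the resolved path must still sit under the vault root
    if not resolved.is_relative_to(vault.resolve()):
        raise ValueError(f"path traversal rejected: {rel}")
    return resolved
```

Checking `rel.parts` before resolving is what makes `subdir/../subdir/../note.md` fail even though it would resolve inside the vault, matching `test_validate_path_path_with_dotdot_in_resolve`.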

"""Tests for obsidian_rag.vector_store — LanceDB CRUD operations."""
from __future__ import annotations
import lancedb
import pytest
from pathlib import Path
from obsidian_rag.vector_store import (
SearchResult,
create_table_if_not_exists,
delete_by_source_file,
get_stats,
search_chunks,
upsert_chunks,
)
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def _connect(db_path: Path) -> lancedb.LanceDBConnection:
"""Create a LanceDB connection for testing."""
db_path.parent.mkdir(parents=True, exist_ok=True)
return lancedb.connect(str(db_path))
def _make_table(tmp_path: Path):
"""Create a fresh obsidian_chunks table for testing."""
db = _connect(tmp_path / "test.lance")
tbl = create_table_if_not_exists(db)
return tbl
def _chunk(source_file: str = "test.md", chunk_id: str = "c1", **overrides):
"""Build a minimal valid chunk dict."""
base = {
"vector": [0.1] * 1024,
"chunk_id": chunk_id,
"chunk_text": "Hello world",
"source_file": source_file,
"source_directory": "Notes",
"section": None,
"date": "2024-01-15",
"tags": ["#test"],
"chunk_index": 0,
"total_chunks": 1,
"modified_at": "2024-01-15T10:00:00Z",
"indexed_at": "2024-01-15T12:00:00Z",
}
base.update(overrides)
return base
# ----------------------------------------------------------------------
# Table creation
# ----------------------------------------------------------------------
def test_create_table_if_not_exists_creates_new(tmp_path: Path):
db = _connect(tmp_path / "new.lance")
tbl = create_table_if_not_exists(db)
assert "obsidian_chunks" in db.list_tables().tables
assert tbl.count_rows() == 0
def test_create_table_if_not_exists_idempotent(tmp_path: Path):
db = _connect(tmp_path / "exists.lance")
tbl1 = create_table_if_not_exists(db)
tbl2 = create_table_if_not_exists(db)
assert tbl1.name == tbl2.name # same underlying table
# ----------------------------------------------------------------------
# upsert_chunks
# ----------------------------------------------------------------------
def test_upsert_chunks_inserts_new(tmp_path: Path):
tbl = _make_table(tmp_path)
count = upsert_chunks(tbl, [_chunk()])
assert count == 1
assert tbl.count_rows() == 1
def test_upsert_chunks_empty_list_returns_zero(tmp_path: Path):
tbl = _make_table(tmp_path)
assert upsert_chunks(tbl, []) == 0
def test_upsert_chunks_updates_existing(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(chunk_id="dup-id", chunk_text="Original")])
upsert_chunks(tbl, [_chunk(chunk_id="dup-id", chunk_text="Updated")])
assert tbl.count_rows() == 1
df = tbl.to_pandas()
assert df[df["chunk_id"] == "dup-id"]["chunk_text"].iloc[0] == "Updated"
def test_upsert_chunks_multiple(tmp_path: Path):
tbl = _make_table(tmp_path)
chunks = [_chunk(chunk_id=f"id-{i}", chunk_text=f"Chunk {i}") for i in range(10)]
upsert_chunks(tbl, chunks)
assert tbl.count_rows() == 10
# ----------------------------------------------------------------------
# delete_by_source_file
# ----------------------------------------------------------------------
def test_delete_by_source_file_removes_chunks(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(source_file="test.md", chunk_id="t1")])
upsert_chunks(tbl, [_chunk(source_file="test.md", chunk_id="t2")])
upsert_chunks(tbl, [_chunk(source_file="other.md", chunk_id="o1")])
assert tbl.count_rows() == 3
deleted = delete_by_source_file(tbl, "test.md")
assert deleted == 2
assert tbl.count_rows() == 1
def test_delete_by_source_file_nonexistent_returns_zero(tmp_path: Path):
tbl = _make_table(tmp_path)
deleted = delete_by_source_file(tbl, "does-not-exist.md")
assert deleted == 0
# ----------------------------------------------------------------------
# search_chunks
# ----------------------------------------------------------------------
def test_search_chunks_with_directory_filter(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(source_file="n.md", source_directory="Notes", chunk_id="n1")])
upsert_chunks(tbl, [_chunk(source_file="c.md", source_directory="Code", chunk_id="c1")])
results = search_chunks(
tbl, [0.0] * 1024, limit=10, directory_filter=["Notes"]
)
assert all(r.source_directory == "Notes" for r in results)
def test_search_chunks_with_date_range(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(chunk_id="d1", date="2024-01-01")])
upsert_chunks(tbl, [_chunk(chunk_id="d2", date="2024-03-15")])
upsert_chunks(tbl, [_chunk(chunk_id="d3", date="2024-06-20")])
results = search_chunks(
tbl, [0.0] * 1024, limit=10, date_range={"from": "2024-02-01", "to": "2024-05-31"}
)
for r in results:
assert "2024-02-01" <= r.date <= "2024-05-31"
def test_search_chunks_with_tags_filter(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(chunk_id="t1", tags=["#python", "#testing"])])
upsert_chunks(tbl, [_chunk(chunk_id="t2", tags=["#javascript"])])
results = search_chunks(tbl, [0.0] * 1024, limit=10, tags=["#python"])
    # The list_contains filter should exclude the #javascript-only chunk
    assert all("#python" in r.tags for r in results)
# ----------------------------------------------------------------------
# get_stats
# ----------------------------------------------------------------------
def test_get_stats_empty_table(tmp_path: Path):
tbl = _make_table(tmp_path)
stats = get_stats(tbl)
assert stats["total_docs"] == 0
assert stats["total_chunks"] == 0
def test_get_stats_with_data(tmp_path: Path):
tbl = _make_table(tmp_path)
upsert_chunks(tbl, [_chunk(source_file="a.md", chunk_id="a1")])
upsert_chunks(tbl, [_chunk(source_file="a.md", chunk_id="a2")])
upsert_chunks(tbl, [_chunk(source_file="b.md", chunk_id="b1")])
stats = get_stats(tbl)
assert stats["total_docs"] == 2 # 2 unique files
assert stats["total_chunks"] == 3
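The `merge_insert` upsert these tests depend on (update on matched `chunk_id`, insert on miss) behaves like this dict-keyed sketch — stdlib only, no LanceDB, for illustration of the semantics:

```python
def merge_insert_sketch(rows: dict[str, dict], chunks: list[dict]) -> int:
    """Upsert keyed on chunk_id: matched rows fully replaced, new rows inserted."""
    for chunk in chunks:
        # when_matched_update_all + when_not_matched_insert_all
        rows[chunk["chunk_id"]] = chunk
    return len(chunks)

table: dict[str, dict] = {}
merge_insert_sketch(table, [{"chunk_id": "dup-id", "chunk_text": "Original"}])
merge_insert_sketch(table, [{"chunk_id": "dup-id", "chunk_text": "Updated"}])
print(len(table), table["dup-id"]["chunk_text"])
# → 1 Updated
```

This is why `test_upsert_chunks_updates_existing` asserts a row count of 1 after two upserts with the same `chunk_id`.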