Previously, total_chunks was counted from the process_file return value (num_chunks), which could differ from the actual stored count if upsert silently failed. It now uses the stored count returned by upsert_chunks. Also fixes cli._index to skip progress yields when building the result.
372 lines
14 KiB
Python
372 lines
14 KiB
Python
"""Full indexing pipeline: scan → parse → chunk → embed → store."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Generator, Iterator
|
|
|
|
if TYPE_CHECKING:
|
|
from obsidian_rag.config import ObsidianRagConfig
|
|
|
|
import obsidian_rag.config as config_mod
|
|
from obsidian_rag.config import _resolve_data_dir
|
|
from obsidian_rag.chunker import chunk_file
|
|
from obsidian_rag.embedder import OllamaUnavailableError
|
|
from obsidian_rag.security import should_index_dir, validate_path
|
|
from obsidian_rag.vector_store import (
|
|
create_table_if_not_exists,
|
|
get_db,
|
|
upsert_chunks,
|
|
)
|
|
|
|
# ----------------------------------------------------------------------
|
|
# Pipeline
|
|
# ----------------------------------------------------------------------
|
|
|
|
|
|
class SensitiveContentError(Exception):
    """Signals that sensitive content was detected but has not been approved for indexing."""
|
|
|
|
|
|
class Indexer:
    """Coordinates the scan → chunk → embed → store pipeline."""

    def __init__(self, config: "ObsidianRagConfig"):
        self.config = config
        self.vault_path = config_mod.resolve_vault_path(config)
        self._embedder = None  # lazy init — avoid loading the embedder until needed
        self._audit_logger = None  # lazy init

    @property
    def embedder(self):
        """Lazily construct and cache the Ollama embedder."""
        if self._embedder is None:
            from obsidian_rag.embedder import OllamaEmbedder

            self._embedder = OllamaEmbedder(self.config)
        return self._embedder

    @property
    def audit_logger(self):
        """Lazily construct and cache the audit logger under the data dir."""
        if self._audit_logger is None:
            from obsidian_rag.audit_logger import AuditLogger

            log_dir = _resolve_data_dir() / "audit"
            self._audit_logger = AuditLogger(log_dir / "audit.log")
        return self._audit_logger

    def _check_sensitive_content_approval(self, chunks: list[dict[str, Any]]) -> None:
        """Enforce user approval for sensitive content before indexing.

        Raises:
            SensitiveContentError: when a chunk matches a category listed in
                ``security.require_confirmation_for`` and
                ``security.auto_approve_sensitive`` is not enabled.
        """
        sensitive_categories = self.config.security.require_confirmation_for
        if not sensitive_categories:
            return

        # Import after the guard so the cheap no-op path avoids it.
        from obsidian_rag import security

        for chunk in chunks:
            sensitivity = security.detect_sensitive(
                chunk["chunk_text"],
                self.config.security.sensitive_sections,
                self.config.memory.patterns,
            )
            for category in sensitive_categories:
                if sensitivity.get(category, False):
                    if not self.config.security.auto_approve_sensitive:
                        raise SensitiveContentError(
                            f"Sensitive {category} content detected. "
                            f"Requires explicit approval before indexing. "
                            f"File: {chunk['source_file']}"
                        )

    def scan_vault(self) -> Generator[Path, None, None]:
        """Walk vault, yielding markdown files to index."""
        for root, dirs, files in os.walk(self.vault_path):
            root_path = Path(root)
            # Prune excluded directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if should_index_dir(d, self.config)]

            for fname in files:
                if not fname.endswith(".md"):
                    continue
                filepath = root_path / fname
                try:
                    validate_path(filepath, self.vault_path)
                except ValueError:
                    # Path escapes the vault (e.g. symlink traversal) — skip.
                    continue
                yield filepath

    def process_file(self, filepath: Path) -> tuple[int, list[dict[str, Any]]]:
        """Chunk a single file. Returns (num_chunks, enriched_chunks).

        Chunks are sanitized and enriched with a shared ``indexed_at``
        timestamp; no embedding or storage happens here.
        """
        from obsidian_rag import security

        # isoformat() already returns str; no extra str() wrapper needed.
        mtime = datetime.fromtimestamp(
            filepath.stat().st_mtime, tz=timezone.utc
        ).isoformat()
        content = filepath.read_text(encoding="utf-8")
        content = security.sanitize_text(content)
        chunks = chunk_file(filepath, content, mtime, self.config)

        now = datetime.now(timezone.utc).isoformat()
        enriched: list[dict[str, Any]] = [
            {
                "chunk_id": chunk.chunk_id,
                "chunk_text": chunk.text,
                "source_file": chunk.source_file,
                "source_directory": chunk.source_directory,
                "section": chunk.section,
                "date": chunk.date,
                "tags": chunk.tags,
                "chunk_index": chunk.chunk_index,
                "total_chunks": chunk.total_chunks,
                "modified_at": chunk.modified_at,
                "indexed_at": now,
            }
            for chunk in chunks
        ]
        return len(chunks), enriched

    def _log_sensitive_chunks(self, enriched: list[dict[str, Any]]) -> None:
        """Audit-log every chunk that matches a tracked sensitive category."""
        # Hoisted out of the per-chunk loop (was re-imported each iteration).
        from obsidian_rag import security

        for chunk in enriched:
            sensitivity = security.detect_sensitive(
                chunk["chunk_text"],
                self.config.security.sensitive_sections,
                self.config.memory.patterns,
            )
            for category in ("health", "financial", "relations"):
                if sensitivity.get(category, False):
                    self.audit_logger.log_sensitive_access(
                        str(chunk["source_file"]),
                        category,
                        "index",
                        {"chunk_id": chunk["chunk_id"]},
                    )

    def _embed_and_store(self, table, enriched: list[dict[str, Any]]) -> int:
        """Embed chunk texts, attach vectors, and upsert.

        Returns the number of chunks actually stored (the upsert_chunks
        return value), which may differ from ``len(enriched)`` if the
        upsert silently drops rows.
        """
        texts = [e["chunk_text"] for e in enriched]
        try:
            vectors = self.embedder.embed_chunks(texts)
        except OllamaUnavailableError:
            # Ollama down — store zero vectors so metadata/text remain queryable.
            vectors = [[0.0] * 1024 for _ in texts]
        for e, v in zip(enriched, vectors):
            e["vector"] = v
        return upsert_chunks(table, enriched)

    def full_index(
        self, on_progress: Iterator[dict] | None = None
    ) -> Generator[dict[str, Any], None, None]:
        """Run a full index of the vault.

        Yields progress dicts while working (only when *on_progress* is
        truthy), then a final summary dict as the last item.

        Fixes two defects in the previous version: a pasted copy of the
        sync body after the final yield (which re-ran a sync pass and wrote
        sync-result.json whenever the generator was fully consumed), and a
        wrong ``dict`` return annotation on a generator function.

        Raises:
            FileNotFoundError: if the configured vault path does not exist.
        """
        vault_path = self.vault_path
        if not vault_path.exists():
            raise FileNotFoundError(f"Vault not found: {vault_path}")

        db = get_db(self.config)
        table = create_table_if_not_exists(db)

        files = list(self.scan_vault())
        total_files = len(files)
        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []

        for idx, filepath in enumerate(files):
            try:
                _, enriched = self.process_file(filepath)
                self._check_sensitive_content_approval(enriched)
                self._log_sensitive_chunks(enriched)
                # Count what was actually stored, not what was chunked, so a
                # silently failing upsert doesn't inflate the total.
                total_chunks += self._embed_and_store(table, enriched)
                indexed_files += 1
            except Exception as exc:
                # Per-file isolation: one bad file must not abort the run.
                errors.append({"file": str(filepath), "error": str(exc)})

            if on_progress:
                # NOTE(review): phase is estimated from position, not actual
                # pipeline stage — cosmetic progress reporting only.
                phase = "embedding" if idx < total_files // 2 else "storing"
                yield {
                    "type": "progress",
                    "phase": phase,
                    "current": idx + 1,
                    "total": total_files,
                }

        yield {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "duration_ms": 0,  # caller can fill
            "errors": errors,
        }

    def reindex(self) -> dict[str, Any]:
        """Nuke and rebuild: drop the table and run a full index."""
        db = get_db(self.config)
        if "obsidian_chunks" in db.list_tables():
            db.drop_table("obsidian_chunks")
        results = list(self.full_index())
        # With on_progress unset, full_index yields only the summary dict.
        final = (
            results[-1]
            if results
            else {"indexed_files": 0, "total_chunks": 0, "errors": []}
        )
        self._write_sync_result(
            final["indexed_files"], final["total_chunks"], final["errors"]
        )
        return final

    def sync(self, on_progress: Iterator[dict] | None = None) -> dict[str, Any]:
        """Incremental sync: only process files modified since last sync.

        Previously total_chunks was counted from the process_file return
        value, which could differ from the stored count if the upsert
        silently failed; it now uses the count returned by upsert_chunks
        (matching full_index). An unreachable pasted copy of
        _sync_result_path's body after the return was also removed.
        """
        sync_result_path = self._sync_result_path()
        last_sync = None
        if sync_result_path.exists():
            try:
                last_sync = json.loads(sync_result_path.read_text()).get("timestamp")
            except Exception:
                # Corrupt/unreadable sync state — fall back to a full pass.
                pass

        db = get_db(self.config)
        table = create_table_if_not_exists(db)

        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []

        for filepath in self.scan_vault():
            mtime = datetime.fromtimestamp(filepath.stat().st_mtime, tz=timezone.utc)
            # Both sides are UTC ISO-8601, so string comparison orders correctly.
            if last_sync and mtime.isoformat() <= last_sync:
                continue  # unchanged since last sync

            try:
                _, enriched = self.process_file(filepath)
                self._check_sensitive_content_approval(enriched)
                self._log_sensitive_chunks(enriched)
                # Use the stored count returned by upsert, as in full_index.
                total_chunks += self._embed_and_store(table, enriched)
                indexed_files += 1
            except Exception as exc:
                errors.append({"file": str(filepath), "error": str(exc)})

        self._write_sync_result(indexed_files, total_chunks, errors)
        return {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }

    def _sync_result_path(self) -> Path:
        """Return the path of the persisted sync-state file."""
        # Use the same dev-data-dir convention as config.py.
        project_root = Path(__file__).parent.parent.parent
        data_dir = project_root / "obsidian-rag"
        if not data_dir.exists() and not (project_root / "KnowledgeVault").exists():
            data_dir = Path(os.path.expanduser("~/.obsidian-rag"))
        return data_dir / "sync-result.json"

    def _write_sync_result(
        self,
        indexed_files: int,
        total_chunks: int,
        errors: list[dict],
    ) -> None:
        """Persist sync stats (with a UTC timestamp) atomically to disk."""
        path = self._sync_result_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        result = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }
        # Atomic write: .tmp → rename, so readers never see a partial file.
        tmp = path.with_suffix(".json.tmp")
        tmp.write_text(json.dumps(result, indent=2))
        tmp.rename(path)
|