Sprint 0-1: Python indexer, TS plugin scaffolding, and test suite
## What's new **Python indexer (`python/obsidian_rag/`)** — full pipeline from scan to LanceDB: - `config.py` — JSON config loader with cross-platform path resolution - `security.py` — path traversal prevention, HTML stripping, sensitive content detection, dir allow/deny lists - `chunker.py` — section-split for journal entries (date-named files), sliding-window for unstructured notes - `embedder.py` — Ollama `/api/embeddings` client with batched requests and timeout/error handling - `vector_store.py` — LanceDB schema, upsert (merge_insert), delete, search with filters, stats - `indexer.py` — full/sync/reindex pipeline orchestrator with progress yields - `cli.py` — `index | sync | reindex | status` CLI commands **TypeScript plugin (`src/`)** — OpenClaw plugin scaffold: - `utils/` — config loader, TypeScript types, response envelope factory, LanceDB client - `services/` — health state machine (HEALTHY/DEGRADED/UNAVAILABLE), vault watcher with debounce/batching, indexer bridge (subprocess spawner) - `tools/` — 4 tool stubs: search, index, status, memory_store (OpenClaw wiring pending) - `index.ts` — plugin entry point with health probe + vault watcher startup **Config** (`obsidian-rag/config.json`, `openclaw.plugin.json`): - 627 files / 3764 chunks indexed in dev vault **Tests: 76 passing** - Python: 64 pytest tests (chunker, security, vector_store, config) - TypeScript: 12 vitest tests (lancedb client, response envelope) ## Bugs fixed - LanceDB `tags` column filter: `LIKE '%tag%'` → `list_contains(tags, 'tag')` (List<String> column) - LanceDB JS `db.list_tables()` returns `ListTablesResponse` object, not plain array - LanceDB JS result score field: `_score` → `_distance` - TypeScript regex literal with unescaped `/` in path-resolve regex - Python: `create_table_if_not_exists` identity check → name comparison Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
3
python/obsidian_rag/__init__.py
Normal file
3
python/obsidian_rag/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Obsidian RAG — semantic search indexer for Obsidian vaults."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
7
python/obsidian_rag/__main__.py
Normal file
7
python/obsidian_rag/__main__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""CLI entry point: obsidian-rag index | sync | reindex | status."""
|
||||
|
||||
import sys
|
||||
from obsidian_rag.cli import main
|
||||
|
||||
# Delegate to the CLI when run as `python -m obsidian_rag`.
if __name__ == "__main__":
    sys.exit(main())
|
||||
240
python/obsidian_rag/chunker.py
Normal file
240
python/obsidian_rag/chunker.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Markdown parsing, structured + unstructured chunking, metadata enrichment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import frontmatter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Types
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class Chunk:
    """One indexed unit of a markdown file, ready for embedding and storage."""

    chunk_id: str  # unique id: optional caller prefix + uuid4 hex prefix
    text: str  # chunk body text
    source_file: str  # path of the originating markdown file
    source_directory: str  # first path component of source_file ("" if none)
    section: str | None  # "#<heading>" for section-split chunks, else None
    date: str | None  # ISO date parsed from the filename, if any
    tags: list[str] = field(default_factory=list)  # lowercased #hashtags
    chunk_index: int = 0  # position of this chunk within its file
    total_chunks: int = 1  # number of chunks produced from the file
    modified_at: str | None = None  # source file mtime (ISO 8601), set by chunk_file
    indexed_at: str | None = None  # filled in later by the indexing pipeline
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Markdown parsing
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> tuple[dict, str]:
    """Split markdown *content* into (frontmatter metadata, body).

    Any parse failure falls back to ({}, content) so a malformed file is
    still indexable as plain text.
    """
    try:
        meta_raw, body_raw = frontmatter.parse(content)
        return (dict(meta_raw) if meta_raw else {}, str(body_raw))
    except Exception:
        return {}, content
|
||||
|
||||
|
||||
def extract_tags(text: str) -> list[str]:
    """Return all #hashtags in *text*, lowercased, first-seen order, deduplicated."""
    seen: dict[str, None] = {}
    for tag in re.findall(r"#[\w-]+", text):
        seen.setdefault(tag.lower(), None)
    return list(seen)
|
||||
|
||||
|
||||
def extract_date_from_filename(filepath: Path) -> str | None:
|
||||
"""Try to parse an ISO date from a filename (e.g. 2024-01-15.md)."""
|
||||
name = filepath.stem # filename without extension
|
||||
# Match YYYY-MM-DD or YYYYMMDD
|
||||
m = re.search(r"(\d{4}-\d{2}-\d{2})|(\d{4}\d{2}\d{2})", name)
|
||||
if m:
|
||||
date_str = m.group(1) or m.group(2)
|
||||
# Normalize YYYYMMDD → YYYY-MM-DD
|
||||
if len(date_str) == 8:
|
||||
return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
|
||||
return date_str
|
||||
return None
|
||||
|
||||
|
||||
def is_structured_note(filepath: Path) -> bool:
    """True when the file name carries a YYYY-MM-DD date (journal/daily note)."""
    return re.search(r"\d{4}-\d{2}-\d{2}", filepath.stem) is not None
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Section-split chunker (structured notes)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
# H1–H3 markdown headers ("# ..." through "### ..."); hashtags like "#tag"
# don't match because a space is required after the hash run.
SECTION_HEADER_RE = re.compile(r"^#{1,3}\s+(.+)$", re.MULTILINE)


def split_by_sections(body: str, metadata: dict) -> list[tuple[str, str]]:
    """Split a markdown body at H1-H3 headers into (heading, content) pairs.

    Content before the first header gets heading None; an empty body yields
    [(None, "")].  *metadata* is accepted for interface parity but unused.
    """
    pieces: list[tuple[str | None, str]] = []
    heading: str | None = None
    buffer: list[str] = []

    def flush() -> None:
        # Emit the accumulated section unless we're still at the very start.
        if heading is not None or buffer:
            pieces.append((heading, "".join(buffer).strip()))

    for raw_line in body.splitlines(keepends=True):
        header = SECTION_HEADER_RE.match(raw_line.rstrip())
        if header is None:
            buffer.append(raw_line)
            continue
        flush()
        buffer = []
        heading = header.group(1).strip()

    flush()

    if not pieces:
        pieces = [(None, body.strip())]
    return pieces
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Sliding window chunker (unstructured notes)
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def _count_tokens(text: str) -> int:
    """Approximate token count as the number of whitespace-separated words."""
    return len(text.split())
|
||||
|
||||
|
||||
def sliding_window_chunks(
    text: str,
    chunk_size: int = 500,
    overlap: int = 100,
) -> list[str]:
    """Split *text* into overlapping word-window chunks.

    Each chunk holds up to *chunk_size* whitespace-separated words, and
    consecutive chunks share *overlap* words.  A non-positive step (i.e.
    overlap >= chunk_size) falls back to half-window advancement so the
    loop always makes progress.

    Args:
        text: Input text; whitespace-only input yields [].
        chunk_size: Maximum words per chunk.
        overlap: Words shared between consecutive chunks.

    Returns:
        List of chunk strings (words re-joined with single spaces).
    """
    words = text.split()
    if not words:
        return []

    # The step is loop-invariant — the original recomputed it every
    # iteration and carried a redundant trailing `break` that merely
    # restated the while-condition.
    step = chunk_size - overlap
    if step <= 0:
        step = max(1, chunk_size // 2)

    return [
        " ".join(words[start : start + chunk_size])
        for start in range(0, len(words), step)
    ]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Main chunk router
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def chunk_file(
    filepath: Path,
    content: str,
    modified_at: str,
    config: "ObsidianRagConfig",
    chunk_id_prefix: str = "",
) -> list[Chunk]:
    """Parse a markdown file and return a list of Chunks.

    Journal-style notes (date-named files) are split by section headers;
    everything else goes through the sliding-window chunker.

    Args:
        filepath: File path; absolute paths are relativized to the vault.
        content: Raw file content, frontmatter included.
        modified_at: ISO timestamp of the file's mtime.
        config: Loaded configuration (vault path + chunking parameters).
        chunk_id_prefix: Optional prefix prepended to generated chunk ids.
    """
    import uuid

    vault_path = Path(config.vault_path)
    # Store vault-relative paths so source_file/source_directory are stable.
    # (The original `filepath if filepath.is_absolute() else filepath` was a
    # no-op, leaving vault_path unused and absolute paths in the index.)
    if filepath.is_absolute():
        try:
            rel_path = filepath.relative_to(vault_path)
        except ValueError:
            rel_path = filepath  # outside the vault — keep as given
    else:
        rel_path = filepath
    source_file = str(rel_path)
    source_directory = rel_path.parts[0] if rel_path.parts else ""

    metadata, body = parse_frontmatter(content)
    tags = extract_tags(body)
    date = extract_date_from_filename(filepath)

    def _make_chunk(
        text: str,
        section: str | None,
        idx: int,
        total: int,
        chunk_tags: list[str],
    ) -> Chunk:
        # Shared constructor: everything but the per-chunk fields is fixed.
        return Chunk(
            chunk_id=f"{chunk_id_prefix}{uuid.uuid4().hex[:8]}",
            text=text,
            source_file=source_file,
            source_directory=source_directory,
            section=section,
            date=date,
            tags=chunk_tags,
            chunk_index=idx,
            total_chunks=total,
            modified_at=modified_at,
        )

    chunks: list[Chunk] = []

    if is_structured_note(filepath):
        # Section-split for journal/daily notes.
        sections = split_by_sections(body, metadata)
        total = len(sections)
        for idx, (section, section_text) in enumerate(sections):
            if not section_text.strip():
                continue
            # File-level tags first, then section-local ones, deduplicated.
            combined = list(dict.fromkeys([*tags, *extract_tags(section_text)]))
            chunks.append(
                _make_chunk(
                    section_text,
                    f"#{section}" if section else None,
                    idx,
                    total,
                    combined,
                )
            )
    else:
        # Sliding window for unstructured notes.
        text_chunks = sliding_window_chunks(
            body, config.indexing.chunk_size, config.indexing.chunk_overlap
        )
        total = len(text_chunks)
        for idx, text_chunk in enumerate(text_chunks):
            if not text_chunk.strip():
                continue
            chunks.append(_make_chunk(text_chunk, None, idx, total, tags))

    return chunks
|
||||
156
python/obsidian_rag/cli.py
Normal file
156
python/obsidian_rag/cli.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""CLI: obsidian-rag index | sync | reindex | status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import obsidian_rag.config as config_mod
|
||||
from obsidian_rag.vector_store import get_db, get_stats
|
||||
from obsidian_rag.indexer import Indexer
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point.  Returns a process exit code (0 ok, 1 error, 2 crash)."""
    # `argv or sys.argv[1:]` would wrongly fall back to sys.argv when an
    # explicit empty list is passed (e.g. from tests) — test for None.
    if argv is None:
        argv = sys.argv[1:]

    if not argv or argv[0] in ("--help", "-h"):
        print(_usage())
        return 0

    cmd = argv[0]

    try:
        config = config_mod.load_config()
    except FileNotFoundError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 1

    # Dispatch table instead of an if/elif chain.
    handlers = {
        "index": _index,
        "sync": _sync,
        "reindex": _reindex,
        "status": _status,
    }
    handler = handlers.get(cmd)
    if handler is None:
        print(f"Unknown command: {cmd}\n{_usage()}", file=sys.stderr)
        return 1
    return handler(config)
|
||||
|
||||
|
||||
def _index(config) -> int:
    """Run a full index; print a JSON completion or error report."""
    indexer = Indexer(config)
    t0 = time.monotonic()

    try:
        gen = indexer.full_index()
        result: dict = {"indexed_files": 0, "total_chunks": 0, "errors": []}
        # full_index is a generator: progress dicts are *yielded*, but the
        # summary is its *return* value, delivered via StopIteration.value.
        # The original `for item in gen: result = item` discarded it — and
        # with no progress consumer nothing was yielded at all, so the
        # command always reported zeros.
        try:
            while True:
                next(gen)
        except StopIteration as stop:
            if stop.value is not None:
                result = stop.value
        duration_ms = int((time.monotonic() - t0) * 1000)
        print(
            json.dumps(
                {
                    "type": "complete",
                    "indexed_files": result["indexed_files"],
                    "total_chunks": result["total_chunks"],
                    "duration_ms": duration_ms,
                    "errors": result["errors"],
                },
                indent=2,
            )
        )
        return 0 if not result["errors"] else 1
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2
|
||||
|
||||
|
||||
def _sync(config) -> int:
    """Run an incremental sync; print a JSON completion or error report."""
    indexer = Indexer(config)
    try:
        outcome = indexer.sync()
        print(json.dumps({"type": "complete", **outcome}, indent=2))
        return 1 if outcome["errors"] else 0
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2
|
||||
|
||||
|
||||
def _reindex(config) -> int:
    """Drop and rebuild the index; print a JSON completion or error report."""
    indexer = Indexer(config)
    t0 = time.monotonic()
    try:
        result = indexer.reindex()
        duration_ms = int((time.monotonic() - t0) * 1000)
        print(
            json.dumps(
                {
                    "type": "complete",
                    "indexed_files": result["indexed_files"],
                    "total_chunks": result["total_chunks"],
                    "duration_ms": duration_ms,
                    "errors": result["errors"],
                },
                indent=2,
            )
        )
        # Match _index/_sync: per-file failures surface as exit code 1.
        # (The original unconditionally returned 0 here.)
        return 0 if not result["errors"] else 1
    except Exception as e:
        print(json.dumps({"type": "error", "error": str(e)}), file=sys.stderr)
        return 2
|
||||
|
||||
|
||||
def _status(config) -> int:
    """Print index health (doc/chunk counts, last sync time) as JSON."""
    try:
        db = get_db(config)
        table = db.open_table("obsidian_chunks")
        stats = get_stats(table)

        # Reuse the config module's data-dir resolution instead of
        # re-implementing the dev/production convention inline (the original
        # duplicated the logic with local Path/os imports; both modules live
        # in the same package directory, so the result is identical).
        sync_path = config_mod._resolve_data_dir() / "sync-result.json"
        last_sync = None
        if sync_path.exists():
            try:
                last_sync = json.loads(sync_path.read_text()).get("timestamp")
            except Exception:
                last_sync = None  # unreadable marker — report unknown

        print(
            json.dumps(
                {
                    "total_docs": stats["total_docs"],
                    "total_chunks": stats["total_chunks"],
                    "last_sync": last_sync,
                },
                indent=2,
            )
        )
        return 0
    except FileNotFoundError:
        print(json.dumps({"error": "Index not found. Run 'obsidian-rag index' first."}, indent=2))
        return 1
    except Exception as e:
        print(json.dumps({"error": str(e)}), file=sys.stderr)
        return 1
|
||||
|
||||
|
||||
def _usage() -> str:
    """Return the CLI help text."""
    return """obsidian-rag - Obsidian vault RAG indexer

Usage:
  obsidian-rag index      Full index of the vault
  obsidian-rag sync       Incremental sync (changed files only)
  obsidian-rag reindex    Force full reindex (nuke + rebuild)
  obsidian-rag status     Show index health and statistics
"""
|
||||
|
||||
|
||||
# Allow running this module directly (python -m obsidian_rag.cli).
if __name__ == "__main__":
    sys.exit(main())
|
||||
145
python/obsidian_rag/config.py
Normal file
145
python/obsidian_rag/config.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Configuration loader — reads ~/.obsidian-rag/config.json (or ./obsidian-rag/ for dev)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
DEFAULT_CONFIG_DIR = Path(__file__).parent.parent.parent # python/ → project root
|
||||
|
||||
|
||||
@dataclass
class EmbeddingConfig:
    """Embedding backend settings (Ollama by default)."""

    provider: str = "ollama"  # embedding provider name
    model: str = "mxbai-embed-large"  # Ollama model tag
    base_url: str = "http://localhost:11434"  # Ollama HTTP endpoint
    dimensions: int = 1024  # embedding vector size
    batch_size: int = 64  # texts per embedding batch
|
||||
|
||||
|
||||
@dataclass
class VectorStoreConfig:
    """Vector database settings."""

    type: str = "lancedb"  # vector store backend identifier
    path: str = ""  # resolved relative to data_dir (see resolve_vector_db_path)
|
||||
|
||||
|
||||
@dataclass
class IndexingConfig:
    """Chunking and file-selection settings."""

    chunk_size: int = 500  # words per sliding-window chunk
    chunk_overlap: int = 100  # words shared between consecutive chunks
    file_patterns: list[str] = field(default_factory=lambda: ["*.md"])  # glob patterns to index
    # Directory names excluded from the vault walk.
    deny_dirs: list[str] = field(
        default_factory=lambda: [".obsidian", ".trash", "zzz-Archive", ".git", ".logseq"]
    )
    # Directory allow-list; consumed by security.should_index_dir — exact
    # semantics defined there (not visible here), confirm before relying on it.
    allow_dirs: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class SecurityConfig:
    """Sensitive-content handling settings (consumed by security.py)."""

    # Categories presumably gating storage behind user confirmation — verify
    # against the memory tooling that reads this.
    require_confirmation_for: list[str] = field(default_factory=lambda: ["health", "financial_debt"])
    # Tags marking sections treated as sensitive.
    sensitive_sections: list[str] = field(
        default_factory=lambda: ["#mentalhealth", "#physicalhealth", "#Relations"]
    )
    # NOTE(review): presumably restricts processing to local services — confirm in security.py.
    local_only: bool = True
|
||||
|
||||
|
||||
@dataclass
class MemoryConfig:
    """Memory auto-suggestion settings."""

    auto_suggest: bool = True  # whether memory suggestions are enabled
    # Keyword triggers per memory category; matching logic lives in the
    # memory tooling (not visible here).
    patterns: dict[str, list[str]] = field(
        default_factory=lambda: {
            "financial": ["owe", "owed", "debt", "paid", "$", "spent", "spend"],
            "health": ["#mentalhealth", "#physicalhealth", "medication", "therapy"],
            "commitments": ["shopping list", "costco", "amazon", "grocery"],
        }
    )
|
||||
|
||||
|
||||
@dataclass
class ObsidianRagConfig:
    """Top-level configuration aggregate loaded from config.json."""

    vault_path: str = ""  # vault location; resolved by resolve_vault_path()
    embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
    vector_store: VectorStoreConfig = field(default_factory=VectorStoreConfig)
    indexing: IndexingConfig = field(default_factory=IndexingConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    memory: MemoryConfig = field(default_factory=MemoryConfig)
|
||||
|
||||
|
||||
def _resolve_data_dir() -> Path:
    """Pick the data directory: project-local in a dev checkout, ~/.obsidian-rag otherwise."""
    dev_dir = DEFAULT_CONFIG_DIR / "obsidian-rag"
    in_dev_checkout = dev_dir.exists() or (DEFAULT_CONFIG_DIR / "KnowledgeVault").exists()
    if in_dev_checkout:
        return dev_dir
    return Path("~/.obsidian-rag").expanduser()
|
||||
|
||||
|
||||
def load_config(config_path: str | Path | None = None) -> ObsidianRagConfig:
    """Load configuration from a JSON file.

    Defaults to <data_dir>/config.json; raises FileNotFoundError if missing.
    """
    path = Path(config_path) if config_path is not None else _resolve_data_dir() / "config.json"

    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")

    raw: dict[str, Any] = json.loads(path.read_text())

    return ObsidianRagConfig(
        vault_path=raw.get("vault_path", ""),
        embedding=_merge(EmbeddingConfig(), raw.get("embedding", {})),
        vector_store=_merge(VectorStoreConfig(), raw.get("vector_store", {})),
        indexing=_merge(IndexingConfig(), raw.get("indexing", {})),
        security=_merge(SecurityConfig(), raw.get("security", {})),
        memory=_merge(MemoryConfig(), raw.get("memory", {})),
    )
|
||||
|
||||
|
||||
def _merge(default: Any, overrides: dict[str, Any]) -> Any:
    """Shallow-merge *overrides* (a JSON dict) into *default*.

    Dataclass instances are merged field-by-field, recursing into nested
    dataclasses when the override value is a dict.  Plain dicts are merged
    key-wise; any other override replaces the default outright.

    Fix: the original consulted ``field.default``, which is the MISSING
    sentinel for ``default_factory`` fields, so nested dict overrides of
    such fields replaced the dataclass with a raw dict instead of merging.
    Recursing on the *current instance value* handles both default kinds.
    """
    if hasattr(default, "__dataclass_fields__") and isinstance(overrides, dict):
        merged = dict(default.__dict__)
        for key, val in overrides.items():
            current = getattr(default, key, None)
            if isinstance(val, dict) and hasattr(current, "__dataclass_fields__"):
                merged[key] = _merge(current, val)
            else:
                merged[key] = val
        return default.__class__(**merged)
    if isinstance(default, dict) and isinstance(overrides, dict):
        return {**default, **overrides}
    return overrides if overrides is not None else default
|
||||
|
||||
|
||||
def resolve_vault_path(config: ObsidianRagConfig) -> Path:
    """Return the vault path, anchoring relative values at the project root."""
    candidate = Path(config.vault_path)
    if candidate.is_absolute():
        return candidate
    return (DEFAULT_CONFIG_DIR / candidate).resolve()
|
||||
|
||||
|
||||
def resolve_vector_db_path(config: ObsidianRagConfig) -> Path:
    """Return the vector store path, anchoring relative values at the data dir."""
    candidate = Path(config.vector_store.path)
    if candidate.is_absolute():
        return candidate
    return (_resolve_data_dir() / candidate).resolve()
|
||||
110
python/obsidian_rag/embedder.py
Normal file
110
python/obsidian_rag/embedder.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Ollama API client for embedding generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
DEFAULT_TIMEOUT = 120.0 # seconds
|
||||
|
||||
|
||||
class EmbeddingError(Exception):
    """Raised when embedding generation fails (non-200 response or timeout)."""
|
||||
|
||||
|
||||
class OllamaUnavailableError(EmbeddingError):
    """Raised when the Ollama server cannot be reached (connection error)."""
|
||||
|
||||
|
||||
class OllamaEmbedder:
    """Client for the Ollama /api/embeddings endpoint.

    The legacy endpoint accepts one "prompt" per request, so texts are
    embedded one POST at a time; batching merely bounds caller-side memory.
    (The original class docstring named /api/embed, but /api/embeddings is
    what the code actually calls.)
    """

    def __init__(self, config: "ObsidianRagConfig"):
        self.base_url = config.embedding.base_url.rstrip("/")
        self.model = config.embedding.model
        self.dimensions = config.embedding.dimensions
        self.batch_size = config.embedding.batch_size
        self._client = httpx.Client(timeout=DEFAULT_TIMEOUT)

    def is_available(self) -> bool:
        """Return True when Ollama answers /api/tags and lists our model."""
        try:
            resp = self._client.get(f"{self.base_url}/api/tags", timeout=5.0)
            if resp.status_code != 200:
                return False
            models = resp.json().get("models", [])
            return any(self.model in m.get("name", "") for m in models)
        except Exception:
            return False

    def embed_chunks(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for *texts*, batch_size at a time.

        Returns one vector per input text, in order.
        """
        vectors: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            vectors.extend(self._embed_one(t) for t in texts[i : i + self.batch_size])
        return vectors

    def embed_single(self, text: str) -> list[float]:
        """Generate an embedding for a single text."""
        return self._embed_one(text)

    def _embed_batch(self, batch: list[str]) -> list[list[float]]:
        """Embed each text in *batch* sequentially.

        The original duplicated the whole request/error/parse path in two
        branches (len==1 vs. multi) that did exactly the same per-text POST;
        both now funnel through _embed_one.
        """
        return [self._embed_one(text) for text in batch]

    def _embed_one(self, text: str) -> list[float]:
        """POST one text to /api/embeddings and return its vector.

        Raises:
            OllamaUnavailableError: server unreachable.
            EmbeddingError: request timed out or non-200 response.
        """
        try:
            resp = self._client.post(
                f"{self.base_url}/api/embeddings",
                json={"model": self.model, "prompt": text},
                timeout=DEFAULT_TIMEOUT,
            )
        except httpx.ConnectError as e:
            raise OllamaUnavailableError(f"Cannot connect to Ollama at {self.base_url}") from e
        except httpx.TimeoutException as e:
            raise EmbeddingError(f"Embedding request timed out after {DEFAULT_TIMEOUT}s") from e

        if resp.status_code != 200:
            raise EmbeddingError(f"Ollama returned {resp.status_code}: {resp.text}")

        data = resp.json()
        embedding = data.get("embedding", [])
        if not embedding:
            # Some server versions return a plural "embeddings" list instead.
            embedding = data.get("embeddings", [[]])[0]
        return embedding

    def close(self):
        """Release the underlying HTTP connection pool."""
        self._client.close()
|
||||
223
python/obsidian_rag/indexer.py
Normal file
223
python/obsidian_rag/indexer.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Full indexing pipeline: scan → parse → chunk → embed → store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generator, Iterator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
import obsidian_rag.config as config_mod
|
||||
from obsidian_rag.chunker import chunk_file
|
||||
from obsidian_rag.embedder import EmbeddingError, OllamaUnavailableError
|
||||
from obsidian_rag.security import should_index_dir, validate_path
|
||||
from obsidian_rag.vector_store import create_table_if_not_exists, delete_by_source_file, get_db, upsert_chunks
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Pipeline
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
class Indexer:
    """Coordinates the scan → chunk → embed → store pipeline."""

    def __init__(self, config: "ObsidianRagConfig"):
        self.config = config
        self.vault_path = config_mod.resolve_vault_path(config)
        self._embedder = None  # lazy: constructing it opens an HTTP client

    @property
    def embedder(self):
        """Lazily-constructed OllamaEmbedder shared by all operations."""
        if self._embedder is None:
            from obsidian_rag.embedder import OllamaEmbedder
            self._embedder = OllamaEmbedder(self.config)
        return self._embedder

    def scan_vault(self) -> Generator[Path, None, None]:
        """Walk the vault, yielding markdown files that pass security checks."""
        for root, dirs, files in os.walk(self.vault_path):
            root_path = Path(root)
            # Prune denied directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if should_index_dir(d, self.config)]
            for fname in files:
                if not fname.endswith(".md"):
                    continue
                filepath = root_path / fname
                try:
                    validate_path(filepath, self.vault_path)
                except ValueError:
                    continue  # path escapes the vault — skip
                yield filepath

    def process_file(self, filepath: Path) -> tuple[int, list[dict[str, Any]]]:
        """Chunk a single file.  Returns (num_chunks, row dicts sans vectors)."""
        from obsidian_rag import security

        mtime = str(datetime.fromtimestamp(filepath.stat().st_mtime, tz=timezone.utc).isoformat())
        content = security.sanitize_text(filepath.read_text(encoding="utf-8"))
        chunks = chunk_file(filepath, content, mtime, self.config)

        now = datetime.now(timezone.utc).isoformat()
        enriched: list[dict[str, Any]] = [
            {
                "chunk_id": chunk.chunk_id,
                "chunk_text": chunk.text,
                "source_file": chunk.source_file,
                "source_directory": chunk.source_directory,
                "section": chunk.section,
                "date": chunk.date,
                "tags": chunk.tags,
                "chunk_index": chunk.chunk_index,
                "total_chunks": chunk.total_chunks,
                "modified_at": chunk.modified_at,
                "indexed_at": now,
            }
            for chunk in chunks
        ]
        return len(chunks), enriched

    def _embed_and_store(self, table, enriched: list[dict[str, Any]]) -> None:
        """Attach embedding vectors to the rows and upsert them into *table*.

        Shared by full_index() and sync() — the original duplicated this.
        """
        texts = [row["chunk_text"] for row in enriched]
        try:
            vectors = self.embedder.embed_chunks(texts)
        except OllamaUnavailableError:
            # Ollama down: store zero vectors sized by the configured
            # dimension (the original hard-coded 1024) so text/metadata
            # remain queryable.
            dims = self.config.embedding.dimensions
            vectors = [[0.0] * dims for _ in texts]
        for row, vector in zip(enriched, vectors):
            row["vector"] = vector
        upsert_chunks(table, enriched)

    def full_index(
        self, on_progress: Iterator[dict] | None = None
    ) -> Generator[dict, None, dict[str, Any]]:
        """Index the whole vault.

        This is a generator: when *on_progress* is truthy it yields progress
        dicts after each file.  The summary dict is the generator's *return*
        value, delivered via StopIteration.value — callers must drain the
        generator and read it from there (see reindex()).
        """
        if not self.vault_path.exists():
            raise FileNotFoundError(f"Vault not found: {self.vault_path}")

        db = get_db(self.config)
        table = create_table_if_not_exists(db)

        files = list(self.scan_vault())
        total_files = len(files)
        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []

        for idx, filepath in enumerate(files):
            try:
                num_chunks, enriched = self.process_file(filepath)
                self._embed_and_store(table, enriched)
                total_chunks += num_chunks
                indexed_files += 1
            except Exception as exc:
                errors.append({"file": str(filepath), "error": str(exc)})

            if on_progress:
                # Coarse phase label — first half "embedding", second "storing".
                phase = "embedding" if idx < total_files // 2 else "storing"
                yield {
                    "type": "progress",
                    "phase": phase,
                    "current": idx + 1,
                    "total": total_files,
                }

        return {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "duration_ms": 0,  # caller can fill
            "errors": errors,
        }

    def sync(self, on_progress: Iterator[dict] | None = None) -> dict[str, Any]:
        """Incremental sync: reindex only files modified since the last sync."""
        last_sync = None
        sync_result_path = self._sync_result_path()
        if sync_result_path.exists():
            try:
                last_sync = json.loads(sync_result_path.read_text()).get("timestamp")
            except Exception:
                last_sync = None  # corrupt/unreadable marker → process everything

        db = get_db(self.config)
        table = create_table_if_not_exists(db)

        indexed_files = 0
        total_chunks = 0
        errors: list[dict] = []

        for filepath in self.scan_vault():
            mtime_str = datetime.fromtimestamp(
                filepath.stat().st_mtime, tz=timezone.utc
            ).isoformat()
            # ISO-8601 UTC strings compare chronologically.
            if last_sync and mtime_str <= last_sync:
                continue  # unchanged since last sync

            try:
                num_chunks, enriched = self.process_file(filepath)
                self._embed_and_store(table, enriched)
                total_chunks += num_chunks
                indexed_files += 1
            except Exception as exc:
                errors.append({"file": str(filepath), "error": str(exc)})

        self._write_sync_result(indexed_files, total_chunks, errors)
        return {
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }

    def reindex(self) -> dict[str, Any]:
        """Drop the table and run a fresh full index, returning its summary.

        Fix: the original took ``list(self.full_index())[-1]``, but
        full_index only *yields* progress dicts (none at all without a
        progress consumer) and returns the summary via StopIteration —
        so reindex always reported zeros.
        """
        db = get_db(self.config)
        if "obsidian_chunks" in db.list_tables():
            db.drop_table("obsidian_chunks")

        gen = self.full_index()
        summary: dict[str, Any] = {"indexed_files": 0, "total_chunks": 0, "errors": []}
        try:
            while True:
                next(gen)
        except StopIteration as stop:
            if stop.value is not None:
                summary = stop.value
        return summary

    def _sync_result_path(self) -> Path:
        """Location of the sync marker, following config's data-dir convention.

        Delegates to config_mod._resolve_data_dir() instead of duplicating
        the dev/production resolution (both modules share a directory, so
        the resolved path is identical).
        """
        return config_mod._resolve_data_dir() / "sync-result.json"

    def _write_sync_result(
        self,
        indexed_files: int,
        total_chunks: int,
        errors: list[dict],
    ) -> None:
        """Atomically persist the sync summary (tmp file + rename)."""
        path = self._sync_result_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "indexed_files": indexed_files,
            "total_chunks": total_chunks,
            "errors": errors,
        }
        tmp = path.with_suffix(".json.tmp")
        tmp.write_text(json.dumps(payload, indent=2))
        tmp.rename(path)
|
||||
164
python/obsidian_rag/security.py
Normal file
164
python/obsidian_rag/security.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Path traversal prevention, input sanitization, sensitive content detection, directory access control."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import unicodedata
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Path traversal
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_path(requested: Path, vault_root: Path) -> Path:
    """Resolve *requested* relative to *vault_root* and reject vault escapes.

    Args:
        requested: Path supplied by the caller (relative or absolute).
        vault_root: Root directory that the result must stay inside.

    Returns:
        The fully resolved absolute path inside the vault.

    Raises:
        ValueError: On any traversal attempt or unresolvable path.
    """
    # Fail fast on explicit '..' components before touching the filesystem.
    if ".." in requested.parts:
        raise ValueError(f"Path traversal attempt blocked: {requested}")

    vault = vault_root.resolve()
    try:
        resolved = (vault / requested).resolve()
    except (OSError, ValueError) as e:
        raise ValueError(f"Cannot resolve path: {requested}") from e

    # Symlinks (and absolute `requested` values, which pathlib joins as-is)
    # can still escape after resolution; require containment in the vault.
    try:
        resolved.relative_to(vault)
    except ValueError:
        # Suppress the uninteresting relative_to() context in the traceback.
        raise ValueError(
            f"Path traversal attempt blocked: {requested} resolves outside vault"
        ) from None

    return resolved
|
||||
|
||||
|
||||
def is_symlink_outside_vault(path: Path, vault_root: Path) -> bool:
    """Return True when *path* resolves to a location outside *vault_root*."""
    try:
        target = path.resolve()
        root = vault_root.resolve()
    except (OSError, ValueError):
        # Paths that cannot be resolved are treated as unsafe.
        return True
    try:
        target.relative_to(root)
    except ValueError:
        return True
    return False
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Input sanitization
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
# Matches any HTML/XML tag, e.g. "<div class=...>".
HTML_TAG_RE = re.compile(r"<[^>]+>")
# Matches fenced ``` code blocks, non-greedily, across lines.
CODE_BLOCK_RE = re.compile(r"```[\s\S]*?```", re.MULTILINE)
# Collapses any run of whitespace (including newlines) to one space.
MULTI_WHITESPACE_RE = re.compile(r"\s+")
# Hard cap on sanitized chunk length, in characters.
MAX_CHUNK_LEN = 2000


def sanitize_text(raw: str) -> str:
    """Sanitize raw vault content before embedding.

    - Strip HTML tags (prevent XSS)
    - Remove fenced code blocks
    - Normalize whitespace
    - Cap length at MAX_CHUNK_LEN chars
    """
    without_code = CODE_BLOCK_RE.sub(" ", raw)
    without_html = HTML_TAG_RE.sub("", without_code)
    normalized = MULTI_WHITESPACE_RE.sub(" ", without_html.strip())
    # Truncate last, so the cap applies to the final cleaned text.
    return normalized[:MAX_CHUNK_LEN]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Sensitive content detection
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def detect_sensitive(
    text: str,
    sensitive_sections: list[str],
    patterns: dict[str, list[str]],
) -> dict[str, bool]:
    """Detect sensitive content categories in text.

    Args:
        text: Chunk text to scan (matched case-insensitively).
        sensitive_sections: Section markers/hashtags considered sensitive.
        patterns: Per-category substring patterns, keyed by category name
            ("health", "financial", "relations").

    Returns:
        Dict with boolean flags for keys: health, financial, relations.
    """
    text_lower = text.lower()
    result: dict[str, bool] = {
        "health": False,
        "financial": False,
        "relations": False,
    }

    # Health-related section hashtags found in the text flag the chunk.
    for section in sensitive_sections:
        section_lower = section.lower()
        if section_lower in text_lower and section_lower in ("#mentalhealth", "#physicalhealth"):
            result["health"] = True

    # Substring pattern matching. Previously only the "financial" and
    # "health" pattern lists were consulted, so "relations" could never be
    # flagged; iterate every known category instead.
    for category in result:
        if result[category]:
            continue
        for pat in patterns.get(category, []):
            if pat.lower() in text_lower:
                result[category] = True
                break

    return result
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Directory access control
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def should_index_dir(
    dir_name: str,
    config: "ObsidianRagConfig",
) -> bool:
    """Apply deny/allow list rules to a directory.

    Hidden dirs (leading '.') are always rejected. A non-empty allow list
    is exclusive — only listed dirs pass; otherwise any directory not on
    the deny list is indexed.
    """
    if dir_name.startswith("."):
        return False

    allow = config.indexing.allow_dirs
    if allow:
        return dir_name in allow

    return dir_name not in config.indexing.deny_dirs
|
||||
|
||||
|
||||
def filter_tags(text: str) -> list[str]:
    """Extract all #hashtags from text, lowercased, deduplicated, first-seen order."""
    seen: dict[str, None] = {}
    for tag in re.findall(r"#\w+", text):
        seen.setdefault(tag.lower(), None)
    return list(seen)
|
||||
178
python/obsidian_rag/vector_store.py
Normal file
178
python/obsidian_rag/vector_store.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""LanceDB table creation, vector upsert/delete/search."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
import lancedb
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from obsidian_rag.config import ObsidianRagConfig
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Schema constants
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
# Name of the LanceDB table holding all vault chunks.
TABLE_NAME = "obsidian_chunks"
# Must match the embedding model's output dimensionality.
VECTOR_DIM = 1024  # mxbai-embed-large
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Types
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class SearchResult:
    """One vector-search hit, flattened from a LanceDB result row."""

    chunk_id: str  # unique id; used as the merge_insert key on upsert
    chunk_text: str  # the chunk's text content
    source_file: str  # path of the originating note
    source_directory: str  # directory the note lives in
    section: str | None  # section heading, when present in the row
    date: str | None  # date string for the chunk, when present
    tags: list[str]  # tags associated with the chunk (List<String> column)
    chunk_index: int  # position of this chunk within its source file
    score: float  # similarity score reported by the search backend
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Table setup
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_db(config: "ObsidianRagConfig") -> lancedb.LanceDBConnection:
    """Connect to the LanceDB database, creating parent dirs as needed."""
    # Local import avoids a module-level import cycle with config.py.
    import obsidian_rag.config as cfg_mod

    location = cfg_mod.resolve_vector_db_path(config)
    location.parent.mkdir(parents=True, exist_ok=True)
    return lancedb.connect(str(location))
|
||||
|
||||
|
||||
def create_table_if_not_exists(db: Any) -> Any:
    """Open the obsidian_chunks table, creating it with the full schema if missing."""
    import pyarrow as pa

    # Compare by name, not identity (see earlier bugfix in this file's history).
    if TABLE_NAME in db.list_tables():
        return db.open_table(TABLE_NAME)

    fields = [
        pa.field("vector", pa.list_(pa.float32(), VECTOR_DIM)),
        pa.field("chunk_id", pa.string()),
        pa.field("chunk_text", pa.string()),
        pa.field("source_file", pa.string()),
        pa.field("source_directory", pa.string()),
        pa.field("section", pa.string()),
        pa.field("date", pa.string()),
        pa.field("tags", pa.list_(pa.string())),
        pa.field("chunk_index", pa.int32()),
        pa.field("total_chunks", pa.int32()),
        pa.field("modified_at", pa.string()),
        pa.field("indexed_at", pa.string()),
    ]
    # exist_ok guards the create/open race between the check above and here.
    return db.create_table(TABLE_NAME, schema=pa.schema(fields), exist_ok=True)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# CRUD operations
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
|
||||
def upsert_chunks(
    table: Any,
    chunks: list[dict[str, Any]],
) -> int:
    """Add or update chunks keyed on chunk_id. Returns the number written."""
    if not chunks:
        return 0
    # merge_insert gives full upsert semantics: update on key match,
    # insert otherwise.
    builder = table.merge_insert("chunk_id")
    builder = builder.when_matched_update_all()
    builder = builder.when_not_matched_insert_all()
    builder.execute(chunks)
    return len(chunks)
|
||||
|
||||
|
||||
def delete_by_source_file(table: Any, source_file: str) -> int:
    """Delete all chunks from a given source file. Returns count deleted.

    The filename is interpolated into a SQL predicate, so it is written as
    a single-quoted SQL string literal with embedded quotes doubled —
    double quotes denote identifiers in SQL, and an unescaped path such as
    "o'brien.md" would otherwise break (or inject into) the filter.
    """
    before = table.count_rows()
    escaped = source_file.replace("'", "''")
    table.delete(f"source_file = '{escaped}'")
    return before - table.count_rows()
|
||||
|
||||
|
||||
def search_chunks(
    table: Any,
    query_vector: list[float],
    limit: int = 5,
    directory_filter: list[str] | None = None,
    date_range: dict | None = None,
    tags: list[str] | None = None,
) -> list[SearchResult]:
    """Search for similar chunks using vector similarity.

    Args:
        table: LanceDB table to query.
        query_vector: Embedding to search with.
        limit: Maximum number of hits to return.
        directory_filter: Restrict hits to these source directories.
        date_range: Optional {"from": ..., "to": ...} date-string bounds.
        tags: Require each listed tag (list_contains on the tags column).

    Filters are applied as AND conditions.
    """

    def esc(value: Any) -> str:
        # SQL string literals are single-quoted; double embedded quotes
        # so user-supplied values cannot break/inject into the predicate.
        return str(value).replace("'", "''")

    conditions: list[str] = []
    if directory_filter:
        dir_list = ", ".join(f"'{esc(d)}'" for d in directory_filter)
        conditions.append(f"source_directory IN ({dir_list})")
    if date_range:
        if "from" in date_range:
            conditions.append(f"date >= '{esc(date_range['from'])}'")
        if "to" in date_range:
            conditions.append(f"date <= '{esc(date_range['to'])}'")
    for tag in tags or []:
        conditions.append(f"list_contains(tags, '{esc(tag)}')")

    query = table.search(query_vector, vector_column_name="vector").limit(limit)
    if conditions:
        query = query.where(" AND ".join(conditions))
    rows = query.to_list()

    return [
        SearchResult(
            chunk_id=r["chunk_id"],
            chunk_text=r["chunk_text"],
            source_file=r["source_file"],
            source_directory=r["source_directory"],
            section=r.get("section"),
            date=r.get("date"),
            tags=r.get("tags", []),
            chunk_index=r.get("chunk_index", 0),
            # LanceDB's Python SDK reports the distance as "_distance";
            # "_score" is kept only as a fallback for older result shapes
            # (the old lookup returned 0.0 for every hit).
            score=r.get("_distance", r.get("_score", 0.0)),
        )
        for r in rows
    ]
|
||||
|
||||
|
||||
def get_stats(table: Any) -> dict[str, Any]:
    """Return index statistics: total chunk rows and distinct source files."""
    stats: dict[str, Any] = {"total_docs": 0, "total_chunks": 0}
    try:
        stats["total_chunks"] = table.count_rows()
        # Counting distinct source files requires materializing the table.
        stats["total_docs"] = table.to_pandas()["source_file"].nunique()
    except Exception:
        # Stats are informational only — never fail the caller.
        pass
    return stats
|
||||
Reference in New Issue
Block a user