Files
obsidian-rag/python/obsidian_rag/vector_store.py
Santhosh Janardhanan 208531d28d Sprint 0-2: TS plugin scaffolding, LanceDB utils, tooling updates
- Add index-tool.ts command implementation
- Wire lancedb.ts vector search into plugin
- Update src/tools/index.ts exports
- Bump package deps (ts-jest, jest, typescript, lancedb)
- Add .claude/settings.local.json

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-11 13:24:26 -04:00

181 lines
5.5 KiB
Python

"""LanceDB table creation, vector upsert/delete/search."""
from __future__ import annotations
import json
import os
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterable
import lancedb
if TYPE_CHECKING:
from obsidian_rag.config import ObsidianRagConfig
# ----------------------------------------------------------------------
# Schema constants
# ----------------------------------------------------------------------
# Name of the LanceDB table that stores every indexed note chunk.
TABLE_NAME: str = "obsidian_chunks"
# Length of the fixed-size "vector" column; embeddings written to the table
# must have exactly this many floats (1024 matches mxbai-embed-large).
VECTOR_DIM: int = 1024
# ----------------------------------------------------------------------
# Types
# ----------------------------------------------------------------------
@dataclass
class SearchResult:
    """A single chunk hit produced by ``search_chunks``.

    ``score`` is copied from LanceDB's ``_distance`` column, so it is a
    distance rather than a similarity: lower values mean closer matches.
    """

    chunk_id: str  # unique key; also the merge key used when upserting
    chunk_text: str  # raw text of the chunk
    source_file: str  # note the chunk was cut from
    source_directory: str  # directory used by the directory filter
    section: str | None  # heading/section, None when absent
    date: str | None  # date string, None when absent
    tags: list[str]  # note tags ([] when none)
    chunk_index: int  # position of the chunk within its note
    score: float  # vector distance (lower is closer)
# ----------------------------------------------------------------------
# Table setup
# ----------------------------------------------------------------------
def get_db(config: "ObsidianRagConfig") -> lancedb.LanceDBConnection:
    """Return a LanceDB connection for the configured vector database path.

    The path is resolved from *config* via
    ``obsidian_rag.config.resolve_vector_db_path``; its parent directory is
    created up front (no error if it already exists) before connecting.
    """
    # Imported lazily, mirroring the TYPE_CHECKING-only import at module level.
    from obsidian_rag.config import resolve_vector_db_path

    location = resolve_vector_db_path(config)
    location.parent.mkdir(parents=True, exist_ok=True)
    return lancedb.connect(str(location))
def create_table_if_not_exists(db: Any) -> Any:
    """Return the ``obsidian_chunks`` table, creating it on first use.

    If the table is already listed by ``db.list_tables()`` it is opened and
    returned as-is; no index work happens on that path.  Otherwise the table
    is created with the fixed chunk schema below and a full-text-search index
    is built on ``chunk_text``.
    """
    import pyarrow as pa

    if TABLE_NAME in db.list_tables():
        return db.open_table(TABLE_NAME)

    # (column name, Arrow type) in on-disk column order.
    column_specs = [
        ("vector", pa.list_(pa.float32(), VECTOR_DIM)),
        ("chunk_id", pa.string()),
        ("chunk_text", pa.string()),
        ("source_file", pa.string()),
        ("source_directory", pa.string()),
        ("section", pa.string()),
        ("date", pa.string()),
        ("tags", pa.list_(pa.string())),
        ("chunk_index", pa.int32()),
        ("total_chunks", pa.int32()),
        ("modified_at", pa.string()),
        ("indexed_at", pa.string()),
    ]
    schema = pa.schema([pa.field(col, arrow_type) for col, arrow_type in column_specs])

    # exist_ok=True tolerates the table appearing between the check above
    # and this call.
    new_table = db.create_table(TABLE_NAME, schema=schema, exist_ok=True)
    # The FTS index on chunk_text backs the keyword-only DEGRADED-mode search
    # used when Ollama is down. replace=True means an existing index is
    # replaced rather than causing an error. This only runs for newly
    # created tables, not for tables returned by the early exit above.
    new_table.create_fts_index("chunk_text", replace=True)
    return new_table
# ----------------------------------------------------------------------
# CRUD operations
# ----------------------------------------------------------------------
def upsert_chunks(
    table: Any,
    chunks: list[dict[str, Any]],
) -> int:
    """Merge *chunks* into *table*, keyed on ``chunk_id``.

    Rows whose ``chunk_id`` already exists have all columns updated; rows
    with a new ``chunk_id`` are inserted.  An empty *chunks* list is a no-op
    that never touches the table.

    Returns:
        ``len(chunks)`` — the number of chunk dicts handed to the merge.
    """
    if not chunks:
        return 0

    merge = table.merge_insert("chunk_id")
    merge = merge.when_matched_update_all()  # overwrite existing rows
    merge = merge.when_not_matched_insert_all()  # add brand-new rows
    merge.execute(chunks)

    return len(chunks)
def delete_by_source_file(table: Any, source_file: str) -> int:
    """Delete all chunks whose ``source_file`` equals *source_file*.

    Args:
        table: The chunks table.
        source_file: Exact ``source_file`` value to match; ``str()`` is applied,
            so a ``pathlib.Path`` works too.

    Returns:
        Number of rows removed, computed as row count before minus row count
        after the delete.  The two counts are separate calls, so another writer
        acting in between would skew the figure.
    """
    # Emit a single-quoted SQL string literal and double any embedded single
    # quotes. This is standard SQL escaping, so paths such as "Bob's notes.md"
    # or names containing '"' cannot end the literal early and alter the
    # delete predicate.
    literal = "'" + str(source_file).replace("'", "''") + "'"
    before = table.count_rows()
    table.delete(f"source_file = {literal}")
    return before - table.count_rows()
def _sql_string_literal(value: Any) -> str:
    """Render ``str(value)`` as a single-quoted SQL string literal.

    Embedded single quotes are doubled (standard SQL escaping), so values like
    ``Bob's notes`` cannot terminate the literal early or inject filter syntax.
    """
    return "'" + str(value).replace("'", "''") + "'"


def _none_if_placeholder(value: Any) -> Any:
    """Map ``None`` and the literal string ``"None"`` to ``None``.

    Presumably some rows were indexed with ``str(None)`` stored as text.
    """
    return None if value in (None, "None") else value


def _build_where_clause(
    directory_filter: list[str] | None,
    date_range: dict | None,
    tags: list[str] | None,
) -> str | None:
    """AND together the active search filters; return None if there are none."""
    conditions: list[str] = []
    if directory_filter:
        dirs = ", ".join(_sql_string_literal(d) for d in directory_filter)
        conditions.append(f"source_directory IN ({dirs})")
    if date_range:
        # ``date`` is a string column, so these compare lexicographically;
        # assumes ISO-8601 (YYYY-MM-DD) values. A None bound is skipped
        # instead of being compared against the text 'None'.
        date_from = date_range.get("from")
        if date_from is not None:
            conditions.append(f"date >= {_sql_string_literal(date_from)}")
        date_to = date_range.get("to")
        if date_to is not None:
            conditions.append(f"date <= {_sql_string_literal(date_to)}")
    for tag in tags or []:
        # Each tag adds its own condition: a chunk must carry ALL listed tags.
        conditions.append(f"list_contains(tags, {_sql_string_literal(tag)})")
    return " AND ".join(conditions) if conditions else None


def search_chunks(
    table: Any,
    query_vector: list[float],
    limit: int = 5,
    directory_filter: list[str] | None = None,
    date_range: dict | None = None,
    tags: list[str] | None = None,
) -> list[SearchResult]:
    """Search for chunks whose vectors are nearest to *query_vector*.

    Args:
        table: The chunks table.
        query_vector: Query embedding, searched against the ``vector`` column.
        limit: Maximum number of hits to return.
        directory_filter: Keep only chunks whose ``source_directory`` is one of these.
        date_range: Optional ``{"from": ..., "to": ...}`` inclusive bounds on
            ``date``; either key may be absent or None.
        tags: Keep only chunks whose ``tags`` contain every listed tag.

    All filters are combined with AND.  Filter values are emitted as escaped
    single-quoted SQL literals.

    Returns:
        One :class:`SearchResult` per hit; ``score`` is LanceDB's ``_distance``.
    """
    where_clause = _build_where_clause(directory_filter, date_range, tags)

    query = table.search(query_vector, vector_column_name="vector").limit(limit)
    if where_clause:
        query = query.where(where_clause)
    rows = query.to_list()

    return [
        SearchResult(
            chunk_id=row["chunk_id"],
            chunk_text=row["chunk_text"],
            source_file=row["source_file"],
            source_directory=row["source_directory"],
            section=_none_if_placeholder(row.get("section")),
            date=_none_if_placeholder(row.get("date")),
            tags=row.get("tags") or [],
            chunk_index=row.get("chunk_index") or 0,
            score=row.get("_distance") or 0.0,
        )
        for row in rows
    ]
def get_stats(table: Any) -> dict[str, Any]:
    """Return index statistics for the chunks table.

    Best-effort: if reading the table fails, the failure is logged and the
    counts gathered so far (0 otherwise) are returned instead of raising.

    Returns:
        ``{"total_docs": int, "total_chunks": int}``, meaning the number of
        distinct ``source_file`` values and the total row count.  Both are
        plain Python ints, so the dict is JSON-serializable.
    """
    import logging

    total_docs = 0
    total_chunks = 0
    try:
        total_chunks = int(table.count_rows())
        # NOTE(review): to_pandas() materializes every column, including the
        # embedding vectors, just to count distinct files. That is acceptable
        # for small vaults but worth revisiting for large ones.
        all_data = table.to_pandas()
        # nunique() yields a numpy integer, which json.dumps rejects; cast it.
        total_docs = int(all_data["source_file"].nunique())
    except Exception:
        # Keep the best-effort contract, but don't hide the failure.
        logging.getLogger(__name__).warning(
            "Failed to compute index statistics", exc_info=True
        )
    return {"total_docs": total_docs, "total_chunks": total_chunks}