"""LanceDB table creation, vector upsert/delete/search.""" from __future__ import annotations import json import os import time import uuid from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Iterable import lancedb if TYPE_CHECKING: from obsidian_rag.config import ObsidianRagConfig # ---------------------------------------------------------------------- # Schema constants # ---------------------------------------------------------------------- TABLE_NAME = "obsidian_chunks" VECTOR_DIM = 1024 # mxbai-embed-large # ---------------------------------------------------------------------- # Types # ---------------------------------------------------------------------- @dataclass class SearchResult: chunk_id: str chunk_text: str source_file: str source_directory: str section: str | None date: str | None tags: list[str] chunk_index: int score: float # ---------------------------------------------------------------------- # Table setup # ---------------------------------------------------------------------- def get_db(config: "ObsidianRagConfig") -> lancedb.LanceDBConnection: """Connect to the LanceDB database.""" import obsidian_rag.config as cfg_mod db_path = cfg_mod.resolve_vector_db_path(config) db_path.parent.mkdir(parents=True, exist_ok=True) return lancedb.connect(str(db_path)) def create_table_if_not_exists(db: Any) -> Any: """Create the obsidian_chunks table if it doesn't exist.""" import pyarrow as pa if TABLE_NAME in db.list_tables(): return db.open_table(TABLE_NAME) schema = pa.schema( [ pa.field("vector", pa.list_(pa.float32(), VECTOR_DIM)), pa.field("chunk_id", pa.string()), pa.field("chunk_text", pa.string()), pa.field("source_file", pa.string()), pa.field("source_directory", pa.string()), pa.field("section", pa.string()), pa.field("date", pa.string()), pa.field("tags", pa.list_(pa.string())), pa.field("chunk_index", pa.int32()), pa.field("total_chunks", pa.int32()), pa.field("modified_at", pa.string()), pa.field("indexed_at", pa.string()), ] ) tbl = db.create_table(TABLE_NAME, schema=schema, exist_ok=True) # Create FTS index on chunk_text for DEGRADED mode fallback (Ollama down) # replace=True makes this idempotent — safe to call on existing tables tbl.create_fts_index("chunk_text", replace=True) return tbl # ---------------------------------------------------------------------- # CRUD operations # ---------------------------------------------------------------------- def upsert_chunks( table: Any, chunks: list[dict[str, Any]], ) -> int: """Add or update chunks in the table. Returns number of chunks written.""" if not chunks: return 0 # Use when_matched_update_all + when_not_matched_insert_all for full upsert ( table.merge_insert("chunk_id") .when_matched_update_all() .when_not_matched_insert_all() .execute(chunks) ) return len(chunks) def delete_by_source_file(table: Any, source_file: str) -> int: """Delete all chunks from a given source file. Returns count deleted.""" before = table.count_rows() table.delete(f'source_file = "{source_file}"') return before - table.count_rows() def search_chunks( table: Any, query_vector: list[float], limit: int | None = None, directory_filter: list[str] | None = None, date_range: dict | None = None, tags: list[str] | None = None, ) -> list[SearchResult]: """Search for similar chunks using vector similarity. Filters are applied as AND conditions. """ import pyarrow as pa # Build WHERE clause conditions: list[str] = [] if directory_filter: dir_list = ", ".join(f'"{d}"' for d in directory_filter) conditions.append(f"source_directory IN ({dir_list})") if date_range: if "from" in date_range: conditions.append(f"date >= '{date_range['from']}'") if "to" in date_range: conditions.append(f"date <= '{date_range['to']}'") if tags: for tag in tags: conditions.append(f"list_contains(tags, '{tag}')") where_clause = " AND ".join(conditions) if conditions else None search_query = table.search(query_vector, vector_column_name="vector") if limit is not None: search_query = search_query.limit(limit) if where_clause: search_query = search_query.where(where_clause) results = search_query.to_list() return [ SearchResult( chunk_id=r["chunk_id"], chunk_text=r["chunk_text"], source_file=r["source_file"], source_directory=r["source_directory"], section=r.get("section") if r.get("section") not in (None, "None") else None, date=r.get("date") if r.get("date") not in (None, "None") else None, tags=r.get("tags") or [], chunk_index=r.get("chunk_index") or 0, score=r.get("_distance") or 0.0, ) for r in results ] def get_stats(table: Any) -> dict[str, Any]: """Return index statistics.""" total_docs = 0 total_chunks = 0 try: total_chunks = table.count_rows() # Count non-null, non-empty source files all_data = table.to_pandas() total_docs = ( all_data["source_file"] .dropna() .astype(str) .str.strip() .loc[lambda s: s.str.len() > 0] .nunique() ) except Exception: pass return {"total_docs": total_docs, "total_chunks": total_chunks}