feat: add LanceDB vector store with upsert, delete, and search
This commit is contained in:
0
src/companion/__init__.py
Normal file
0
src/companion/__init__.py
Normal file
143
src/companion/rag/vector_store.py
Normal file
143
src/companion/rag/vector_store.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
# Name of the single LanceDB table that holds every document chunk.
TABLE_NAME = "chunks"
|
||||
|
||||
|
||||
class VectorStore:
    """Vector store wrapper around LanceDB for RAG.

    Wraps a single LanceDB table (``TABLE_NAME``) holding text chunks, their
    embeddings, and source metadata. Supports upsert keyed on chunk id,
    deletion by source file, and vector similarity search with optional
    metadata filters.
    """

    def __init__(self, uri: str, dimensions: int):
        """Connect to LanceDB, create table if not exists.

        Args:
            uri: Path to LanceDB database
            dimensions: Dimensionality of embeddings
        """
        self.uri = uri
        self.dimensions = dimensions
        self.db = lancedb.connect(uri)
        self.table = self._get_or_create_table()

    def _get_or_create_table(self):
        """Open the existing chunks table, or create it with the schema."""
        try:
            return self.db.open_table(TABLE_NAME)
        except (FileNotFoundError, ValueError):
            # Table does not exist yet. Create it with an explicit schema so
            # the "vector" column is a fixed-size float32 list, which LanceDB
            # requires for vector search.
            schema = pa.schema(
                [
                    pa.field("id", pa.string()),
                    pa.field("text", pa.string()),
                    pa.field("vector", pa.list_(pa.float32(), self.dimensions)),
                    pa.field("source_file", pa.string()),
                    pa.field("source_directory", pa.string()),
                    pa.field("section", pa.string(), nullable=True),
                    pa.field("date", pa.string(), nullable=True),
                    pa.field("tags", pa.list_(pa.string()), nullable=True),
                    pa.field("chunk_index", pa.int32()),
                    pa.field("total_chunks", pa.int32()),
                    pa.field("modified_at", pa.float64(), nullable=True),
                    pa.field("rule_applied", pa.string()),
                ]
            )
            return self.db.create_table(TABLE_NAME, schema=schema)

    @staticmethod
    def _escape_sql_string(value: str) -> str:
        """Escape a string for safe embedding in a LanceDB SQL filter.

        LanceDB filter expressions are SQL; a single quote inside a value
        (e.g. a path like ``o'brien.md``) would otherwise terminate the
        string literal early, breaking the expression or allowing filter
        injection. SQL escapes an embedded quote by doubling it.
        """
        return value.replace("'", "''")

    def upsert(
        self,
        ids: List[str],
        texts: List[str],
        embeddings: List[List[float]],
        metadatas: List[Dict[str, Any]],
    ) -> None:
        """Insert or update chunks using merge_insert keyed on ``id``.

        Args:
            ids: List of unique chunk IDs
            texts: List of text content (parallel to ``ids``)
            embeddings: List of embedding vectors (parallel to ``ids``)
            metadatas: List of metadata dicts with keys like source_file, etc.
                (parallel to ``ids``)

        Raises:
            ValueError: If the four input lists have different lengths
                (previously ``zip`` silently truncated to the shortest).
        """
        if not (len(ids) == len(texts) == len(embeddings) == len(metadatas)):
            raise ValueError(
                "ids, texts, embeddings, and metadatas must have equal lengths"
            )
        if not ids:
            # Nothing to write; skip issuing an empty merge_insert.
            return

        # LanceDB expects float32 vectors; convert once up front.
        vectors = np.array(embeddings, dtype=np.float32)

        # Build one record per chunk, filling metadata defaults.
        records = []
        for id_, text, vector, metadata in zip(ids, texts, vectors, metadatas):
            records.append(
                {
                    "id": id_,
                    "text": text,
                    "vector": vector,
                    "source_file": metadata.get("source_file", ""),
                    "source_directory": metadata.get("source_directory", ""),
                    "section": metadata.get("section"),
                    "date": metadata.get("date"),
                    "tags": metadata.get("tags"),
                    "chunk_index": metadata.get("chunk_index", 0),
                    "total_chunks": metadata.get("total_chunks", 1),
                    "modified_at": metadata.get("modified_at"),
                    "rule_applied": metadata.get("rule_applied", "default"),
                }
            )

        # Convert records to a pyarrow Table with the table's own schema.
        # This ensures the vector field is correctly typed as fixed_size_list.
        data = pa.Table.from_pylist(records, schema=self.table.schema)

        # merge_insert on "id" gives upsert semantics: update matched rows,
        # insert the rest.
        self.table.merge_insert(
            "id"
        ).when_matched_update_all().when_not_matched_insert_all().execute(data)

    def delete_by_source_file(self, source_file: str) -> None:
        """Delete all chunks from a source file.

        Args:
            source_file: Path of source file to delete chunks for
        """
        # Escape quotes so paths containing "'" cannot break (or inject into)
        # the SQL filter expression.
        escaped = self._escape_sql_string(source_file)
        self.table.delete(f"source_file = '{escaped}'")

    def search(
        self,
        query_vector: List[float],
        top_k: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Vector similarity search with optional filters.

        Args:
            query_vector: Query embedding vector
            top_k: Number of results to return
            filters: Optional dict of filters (e.g., {"source_directory": "docs"})

        Returns:
            List of result dicts with metadata
        """
        # LanceDB expects a float32 query vector.
        query = np.array(query_vector, dtype=np.float32)

        search = self.table.search(query).limit(top_k)

        # Combine all filters into a single AND-ed SQL predicate.
        if filters:
            filter_parts = []
            for key, value in filters.items():
                if isinstance(value, str):
                    # Quote and escape string values to keep the SQL valid.
                    filter_parts.append(
                        f"{key} = '{self._escape_sql_string(value)}'"
                    )
                else:
                    filter_parts.append(f"{key} = {value}")
            if filter_parts:
                search = search.where(" AND ".join(filter_parts))

        # Execute and return results as a list of dicts.
        return search.to_list()

    def count(self) -> int:
        """Return total number of chunks."""
        return len(self.table)
|
||||
Reference in New Issue
Block a user