From ec69aa6f8dcc378ab0a66a9523e960d6b7e0f9a5 Mon Sep 17 00:00:00 2001 From: Santhosh Janardhanan Date: Mon, 13 Apr 2026 14:21:03 -0400 Subject: [PATCH] fix: sanitize SQL strings to prevent injection --- src/companion/rag/vector_store.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/companion/rag/vector_store.py b/src/companion/rag/vector_store.py index cc732f9..ec3e0a6 100644 --- a/src/companion/rag/vector_store.py +++ b/src/companion/rag/vector_store.py @@ -93,13 +93,17 @@ class VectorStore: "id" ).when_matched_update_all().when_not_matched_insert_all().execute(data) + def _escape_sql_string(self, value: str) -> str: + return value.replace("'", "''") + def delete_by_source_file(self, source_file: str) -> None: """Delete all chunks from a source file. Args: source_file: Path of source file to delete chunks for """ - self.table.delete(f"source_file = '{source_file}'") + escaped = self._escape_sql_string(source_file) + self.table.delete(f"source_file = '{escaped}'") def search( self, @@ -128,7 +132,8 @@ class VectorStore: filter_parts = [] for key, value in filters.items(): if isinstance(value, str): - filter_parts.append(f"{key} = '{value}'") + escaped = self._escape_sql_string(value) + filter_parts.append(f"{key} = '{escaped}'") else: filter_parts.append(f"{key} = {value}") if filter_parts: @@ -140,4 +145,4 @@ class VectorStore: def count(self) -> int: """Return total number of chunks.""" - return len(self.table) + return self.table.count_rows()