fix: sanitize SQL strings to prevent injection
This commit is contained in:
@@ -93,13 +93,17 @@ class VectorStore:
|
|||||||
"id"
|
"id"
|
||||||
).when_matched_update_all().when_not_matched_insert_all().execute(data)
|
).when_matched_update_all().when_not_matched_insert_all().execute(data)
|
||||||
|
|
||||||
|
def _escape_sql_string(self, value: str) -> str:
|
||||||
|
return value.replace("'", "''")
|
||||||
|
|
||||||
def delete_by_source_file(self, source_file: str) -> None:
|
def delete_by_source_file(self, source_file: str) -> None:
|
||||||
"""Delete all chunks from a source file.
|
"""Delete all chunks from a source file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_file: Path of source file to delete chunks for
|
source_file: Path of source file to delete chunks for
|
||||||
"""
|
"""
|
||||||
self.table.delete(f"source_file = '{source_file}'")
|
escaped = self._escape_sql_string(source_file)
|
||||||
|
self.table.delete(f"source_file = '{escaped}'")
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -128,7 +132,8 @@ class VectorStore:
|
|||||||
filter_parts = []
|
filter_parts = []
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
filter_parts.append(f"{key} = '{value}'")
|
escaped = self._escape_sql_string(value)
|
||||||
|
filter_parts.append(f"{key} = '{escaped}'")
|
||||||
else:
|
else:
|
||||||
filter_parts.append(f"{key} = {value}")
|
filter_parts.append(f"{key} = {value}")
|
||||||
if filter_parts:
|
if filter_parts:
|
||||||
@@ -140,4 +145,4 @@ class VectorStore:
|
|||||||
|
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""Return total number of chunks."""
|
"""Return total number of chunks."""
|
||||||
return len(self.table)
|
return self.table.count_rows()
|
||||||
|
|||||||
Reference in New Issue
Block a user