feat: add typed configuration loader with tilde expansion
This commit is contained in:
212
src/companion/config.py
Normal file
212
src/companion/config.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
role: str
|
||||
tone: str
|
||||
style: str
|
||||
boundaries: List[str] = []
|
||||
|
||||
|
||||
class MemoryConfig(BaseModel):
|
||||
session_turns: int
|
||||
persistent_store: str
|
||||
summarize_after: int
|
||||
|
||||
|
||||
class ChatConfig(BaseModel):
|
||||
streaming: bool
|
||||
max_response_tokens: int
|
||||
default_temperature: float
|
||||
allow_temperature_override: bool
|
||||
|
||||
|
||||
class CompanionConfig(BaseModel):
|
||||
name: str
|
||||
persona: PersonaConfig
|
||||
memory: MemoryConfig
|
||||
chat: ChatConfig
|
||||
|
||||
|
||||
class IndexingConfig(BaseModel):
|
||||
auto_sync: bool
|
||||
auto_sync_interval_minutes: int
|
||||
watch_fs_events: bool
|
||||
file_patterns: List[str]
|
||||
deny_dirs: List[str]
|
||||
deny_patterns: List[str]
|
||||
|
||||
|
||||
class ChunkingRule(BaseModel):
|
||||
strategy: str
|
||||
chunk_size: int
|
||||
chunk_overlap: int
|
||||
section_tags: List[str] = []
|
||||
|
||||
|
||||
class VaultConfig(BaseModel):
|
||||
path: str
|
||||
indexing: IndexingConfig
|
||||
chunking_rules: Dict[str, ChunkingRule] = {}
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
base_url: str
|
||||
dimensions: int
|
||||
batch_size: int
|
||||
|
||||
|
||||
class VectorStoreConfig(BaseModel):
|
||||
type: str
|
||||
path: str
|
||||
|
||||
|
||||
class HybridSearchConfig(BaseModel):
|
||||
enabled: bool
|
||||
keyword_weight: float
|
||||
semantic_weight: float
|
||||
|
||||
|
||||
class FiltersConfig(BaseModel):
|
||||
date_range_enabled: bool
|
||||
tag_filter_enabled: bool
|
||||
directory_filter_enabled: bool
|
||||
|
||||
|
||||
class SearchConfig(BaseModel):
|
||||
default_top_k: int
|
||||
max_top_k: int
|
||||
similarity_threshold: float
|
||||
hybrid_search: HybridSearchConfig
|
||||
filters: FiltersConfig
|
||||
|
||||
|
||||
class RagConfig(BaseModel):
|
||||
embedding: EmbeddingConfig
|
||||
vector_store: VectorStoreConfig
|
||||
search: SearchConfig
|
||||
|
||||
|
||||
class InferenceConfig(BaseModel):
|
||||
backend: str
|
||||
model_path: str
|
||||
context_length: int
|
||||
gpu_layers: int
|
||||
batch_size: int
|
||||
threads: int
|
||||
|
||||
|
||||
class FineTuningConfig(BaseModel):
|
||||
base_model: str
|
||||
output_dir: str
|
||||
lora_rank: int
|
||||
lora_alpha: int
|
||||
learning_rate: float
|
||||
batch_size: int
|
||||
gradient_accumulation_steps: int
|
||||
num_epochs: int
|
||||
warmup_steps: int
|
||||
save_steps: int
|
||||
eval_steps: int
|
||||
training_data_path: str
|
||||
validation_split: float
|
||||
|
||||
|
||||
class RetrainScheduleConfig(BaseModel):
|
||||
auto_reminder: bool
|
||||
default_interval_days: int
|
||||
reminder_channels: List[str] = []
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
inference: InferenceConfig
|
||||
fine_tuning: FineTuningConfig
|
||||
retrain_schedule: RetrainScheduleConfig
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class ApiConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
cors_origins: List[str] = []
|
||||
auth: AuthConfig
|
||||
|
||||
|
||||
class WebFeaturesConfig(BaseModel):
|
||||
streaming: bool
|
||||
citations: bool
|
||||
source_preview: bool
|
||||
|
||||
|
||||
class WebConfig(BaseModel):
|
||||
enabled: bool
|
||||
theme: str
|
||||
features: WebFeaturesConfig
|
||||
|
||||
|
||||
class CliConfig(BaseModel):
|
||||
enabled: bool
|
||||
rich_output: bool
|
||||
|
||||
|
||||
class UiConfig(BaseModel):
|
||||
web: WebConfig
|
||||
cli: CliConfig
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
level: str
|
||||
file: str
|
||||
max_size_mb: int
|
||||
backup_count: int
|
||||
|
||||
|
||||
class SecurityConfig(BaseModel):
|
||||
local_only: bool
|
||||
vault_path_traversal_check: bool
|
||||
sensitive_content_detection: bool
|
||||
sensitive_patterns: List[str] = []
|
||||
require_confirmation_for_external_apis: bool
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
companion: CompanionConfig
|
||||
vault: VaultConfig
|
||||
rag: RagConfig
|
||||
model: ModelConfig
|
||||
api: ApiConfig
|
||||
ui: UiConfig
|
||||
logging: LoggingConfig
|
||||
security: SecurityConfig
|
||||
|
||||
|
||||
def expand_tilde_recursive(obj: Any) -> Any:
|
||||
"""Recursively expand ~/ in string values."""
|
||||
if isinstance(obj, str) and obj.startswith("~/"):
|
||||
return os.path.expanduser(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: expand_tilde_recursive(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [expand_tilde_recursive(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def load_config(path: str) -> Config:
|
||||
"""Load configuration from a JSON file with tilde expansion."""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
raw_data = json.load(f)
|
||||
|
||||
# Recursively expand tilde paths in strings
|
||||
expanded_data = expand_tilde_recursive(raw_data)
|
||||
|
||||
return Config.model_validate(expanded_data)
|
||||
Reference in New Issue
Block a user