diff --git a/src/companion/config.py b/src/companion/config.py
new file mode 100644
--- /dev/null
+++ b/src/companion/config.py
@@ -0,0 +1,284 @@
+"""Typed configuration schema and JSON loader for the ``companion`` package."""
+
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Union
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class PersonaConfig(BaseModel):
+    """Schema for the ``companion.persona`` section."""
+
+    role: str
+    tone: str
+    style: str
+    boundaries: List[str] = []
+
+
+class MemoryConfig(BaseModel):
+    """Schema for the ``companion.memory`` section."""
+
+    session_turns: int
+    persistent_store: str
+    summarize_after: int
+
+
+class ChatConfig(BaseModel):
+    """Schema for the ``companion.chat`` section."""
+
+    streaming: bool
+    max_response_tokens: int
+    default_temperature: float
+    allow_temperature_override: bool
+
+
+class CompanionConfig(BaseModel):
+    """Schema for the top-level ``companion`` section."""
+
+    name: str
+    persona: PersonaConfig
+    memory: MemoryConfig
+    chat: ChatConfig
+
+
+class IndexingConfig(BaseModel):
+    """Schema for the ``vault.indexing`` section."""
+
+    auto_sync: bool
+    auto_sync_interval_minutes: int
+    watch_fs_events: bool
+    file_patterns: List[str]
+    deny_dirs: List[str]
+    deny_patterns: List[str]
+
+
+class ChunkingRule(BaseModel):
+    """Schema for one entry of the ``vault.chunking_rules`` mapping."""
+
+    strategy: str
+    chunk_size: int
+    chunk_overlap: int
+    section_tags: List[str] = []
+
+
+class VaultConfig(BaseModel):
+    """Schema for the top-level ``vault`` section."""
+
+    path: str
+    indexing: IndexingConfig
+    chunking_rules: Dict[str, ChunkingRule] = {}
+
+
+class EmbeddingConfig(BaseModel):
+    """Schema for the ``rag.embedding`` section."""
+
+    provider: str
+    model: str
+    base_url: str
+    dimensions: int
+    batch_size: int
+
+
+class VectorStoreConfig(BaseModel):
+    """Schema for the ``rag.vector_store`` section."""
+
+    type: str
+    path: str
+
+
+class HybridSearchConfig(BaseModel):
+    """Schema for the ``rag.search.hybrid_search`` section."""
+
+    enabled: bool
+    keyword_weight: float
+    semantic_weight: float
+
+
+class FiltersConfig(BaseModel):
+    """Schema for the ``rag.search.filters`` section."""
+
+    date_range_enabled: bool
+    tag_filter_enabled: bool
+    directory_filter_enabled: bool
+
+
+class SearchConfig(BaseModel):
+    """Schema for the ``rag.search`` section."""
+
+    default_top_k: int
+    max_top_k: int
+    similarity_threshold: float
+    hybrid_search: HybridSearchConfig
+    filters: FiltersConfig
+
+
+class RagConfig(BaseModel):
+    """Schema for the top-level ``rag`` section."""
+
+    embedding: EmbeddingConfig
+    vector_store: VectorStoreConfig
+    search: SearchConfig
+
+
+class InferenceConfig(BaseModel):
+    """Schema for the ``model.inference`` section."""
+
+    # ``model_path`` falls in pydantic v2's protected ``model_`` namespace, which
+    # some releases report as a UserWarning when the class is created; opt out.
+    model_config = ConfigDict(protected_namespaces=())
+
+    backend: str
+    model_path: str
+    context_length: int
+    gpu_layers: int
+    batch_size: int
+    threads: int
+
+
+class FineTuningConfig(BaseModel):
+    """Schema for the ``model.fine_tuning`` section."""
+
+    base_model: str
+    output_dir: str
+    lora_rank: int
+    lora_alpha: int
+    learning_rate: float
+    batch_size: int
+    gradient_accumulation_steps: int
+    num_epochs: int
+    warmup_steps: int
+    save_steps: int
+    eval_steps: int
+    training_data_path: str
+    validation_split: float
+
+
+class RetrainScheduleConfig(BaseModel):
+    """Schema for the ``model.retrain_schedule`` section."""
+
+    auto_reminder: bool
+    default_interval_days: int
+    reminder_channels: List[str] = []
+
+
+class ModelConfig(BaseModel):
+    """Schema for the top-level ``model`` section."""
+
+    inference: InferenceConfig
+    fine_tuning: FineTuningConfig
+    retrain_schedule: RetrainScheduleConfig
+
+
+class AuthConfig(BaseModel):
+    """Schema for the ``api.auth`` section."""
+
+    enabled: bool
+
+
+class ApiConfig(BaseModel):
+    """Schema for the top-level ``api`` section."""
+
+    host: str
+    port: int
+    cors_origins: List[str] = []
+    auth: AuthConfig
+
+
+class WebFeaturesConfig(BaseModel):
+    """Schema for the ``ui.web.features`` section."""
+
+    streaming: bool
+    citations: bool
+    source_preview: bool
+
+
+class WebConfig(BaseModel):
+    """Schema for the ``ui.web`` section."""
+
+    enabled: bool
+    theme: str
+    features: WebFeaturesConfig
+
+
+class CliConfig(BaseModel):
+    """Schema for the ``ui.cli`` section."""
+
+    enabled: bool
+    rich_output: bool
+
+
+class UiConfig(BaseModel):
+    """Schema for the top-level ``ui`` section."""
+
+    web: WebConfig
+    cli: CliConfig
+
+
+class LoggingConfig(BaseModel):
+    """Schema for the top-level ``logging`` section."""
+
+    level: str
+    file: str
+    max_size_mb: int
+    backup_count: int
+
+
+class SecurityConfig(BaseModel):
+    """Schema for the top-level ``security`` section."""
+
+    local_only: bool
+    vault_path_traversal_check: bool
+    sensitive_content_detection: bool
+    sensitive_patterns: List[str] = []
+    require_confirmation_for_external_apis: bool
+
+
+class Config(BaseModel):
+    """Root schema: one field per top-level section of the config file."""
+
+    companion: CompanionConfig
+    vault: VaultConfig
+    rag: RagConfig
+    model: ModelConfig
+    api: ApiConfig
+    ui: UiConfig
+    logging: LoggingConfig
+    security: SecurityConfig
+
+
+def expand_tilde_recursive(obj: Any) -> Any:
+    """Return *obj* with a leading ``~/`` expanded in every string value.
+
+    Dicts and lists are rebuilt recursively (dict keys are left as they are).
+    Only strings starting with ``~/`` go through :func:`os.path.expanduser`,
+    so values such as ``"*~"`` or ``"~user/x"`` come back unchanged, as does
+    any value that is not a str, dict or list.
+    """
+    if isinstance(obj, str):
+        return os.path.expanduser(obj) if obj.startswith("~/") else obj
+    if isinstance(obj, dict):
+        return {k: expand_tilde_recursive(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [expand_tilde_recursive(item) for item in obj]
+    return obj
+
+
+def load_config(path: Union[str, "os.PathLike[str]"]) -> Config:
+    """Load and validate the configuration stored in the JSON file at *path*.
+
+    A leading ``~`` in *path* itself is expanded, and so is a leading ``~/``
+    in every string value read from the file (see
+    :func:`expand_tilde_recursive`).
+
+    Raises:
+        OSError: If the file cannot be opened.
+        json.JSONDecodeError: If the file does not contain valid JSON.
+        pydantic.ValidationError: If the data does not match :class:`Config`.
+    """
+    with open(os.path.expanduser(path), "r", encoding="utf-8") as f:
+        raw_data = json.load(f)
+
+    # Expand before validating so the models only ever hold resolved paths.
+    return Config.model_validate(expand_tilde_recursive(raw_data))
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,172 @@
+"""Tests for ``companion.config.load_config``."""
+
+import json
+import os
+import tempfile
+from unittest import mock
+
+from companion.config import load_config
+
+
+def _sample_config():
+    """Return a complete, valid raw config with several ``~/`` path values."""
+    return {
+        "companion": {
+            "name": "SAN",
+            "persona": {
+                "role": "companion",
+                "tone": "reflective",
+                "style": "questioning",
+                "boundaries": [],
+            },
+            "memory": {
+                "session_turns": 20,
+                "persistent_store": "~/mem.db",
+                "summarize_after": 10,
+            },
+            "chat": {
+                "streaming": True,
+                "max_response_tokens": 2048,
+                "default_temperature": 0.7,
+                "allow_temperature_override": True,
+            },
+        },
+        "vault": {
+            "path": "~/test-vault",
+            "indexing": {
+                "auto_sync": False,
+                "auto_sync_interval_minutes": 1440,
+                "watch_fs_events": False,
+                "file_patterns": ["*.md"],
+                "deny_dirs": [".git"],
+                "deny_patterns": [".*"],
+            },
+            "chunking_rules": {},
+        },
+        "rag": {
+            "embedding": {
+                "provider": "ollama",
+                "model": "dummy",
+                "base_url": "http://localhost:11434",
+                "dimensions": 4,
+                "batch_size": 2,
+            },
+            "vector_store": {
+                "type": "lancedb",
+                "path": "~/.companion/vectors.lance",
+            },
+            "search": {
+                "default_top_k": 8,
+                "max_top_k": 20,
+                "similarity_threshold": 0.75,
+                "hybrid_search": {
+                    "enabled": False,
+                    "keyword_weight": 0.3,
+                    "semantic_weight": 0.7,
+                },
+                "filters": {
+                    "date_range_enabled": True,
+                    "tag_filter_enabled": True,
+                    "directory_filter_enabled": True,
+                },
+            },
+        },
+        "model": {
+            "inference": {
+                "backend": "llama.cpp",
+                "model_path": "",
+                "context_length": 8192,
+                "gpu_layers": 35,
+                "batch_size": 512,
+                "threads": 8,
+            },
+            "fine_tuning": {
+                "base_model": "",
+                "output_dir": "",
+                "lora_rank": 16,
+                "lora_alpha": 32,
+                "learning_rate": 0.0002,
+                "batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "num_epochs": 3,
+                "warmup_steps": 100,
+                "save_steps": 500,
+                "eval_steps": 250,
+                "training_data_path": "",
+                "validation_split": 0.1,
+            },
+            "retrain_schedule": {
+                "auto_reminder": True,
+                "default_interval_days": 90,
+                "reminder_channels": [],
+            },
+        },
+        "api": {
+            "host": "127.0.0.1",
+            "port": 7373,
+            "cors_origins": [],
+            "auth": {"enabled": False},
+        },
+        "ui": {
+            "web": {
+                "enabled": True,
+                "theme": "obsidian",
+                "features": {
+                    "streaming": True,
+                    "citations": True,
+                    "source_preview": True,
+                },
+            },
+            "cli": {"enabled": True, "rich_output": True},
+        },
+        "logging": {
+            "level": "INFO",
+            "file": "",
+            "max_size_mb": 100,
+            "backup_count": 5,
+        },
+        "security": {
+            "local_only": True,
+            "vault_path_traversal_check": True,
+            "sensitive_content_detection": True,
+            "sensitive_patterns": [],
+            "require_confirmation_for_external_apis": True,
+        },
+    }
+
+
+def _write_config(directory, data):
+    """Write *data* as ``config.json`` inside *directory*; return the path."""
+    path = os.path.join(directory, "config.json")
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(data, f)
+    return path
+
+
+def test_load_config_reads_json_and_expands_tilde():
+    """Every ``~/`` value in the file is expanded; other strings are kept."""
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        config = load_config(_write_config(tmp_dir, _sample_config()))
+
+    assert config.vault.path == os.path.expanduser("~/test-vault")
+    assert config.rag.vector_store.path == os.path.expanduser(
+        "~/.companion/vectors.lance"
+    )
+    assert config.companion.memory.persistent_store == os.path.expanduser("~/mem.db")
+    # Strings that do not start with ``~/`` must come back untouched.
+    assert config.rag.embedding.base_url == "http://localhost:11434"
+    assert config.vault.indexing.deny_patterns == [".*"]
+
+
+def test_load_config_expands_tilde_in_config_path():
+    """A ``~``-prefixed *path* argument is resolved against the home directory."""
+    with tempfile.TemporaryDirectory() as fake_home:
+        _write_config(fake_home, _sample_config())
+        # ``~`` is resolved from HOME on POSIX and USERPROFILE on Windows.
+        env = {"HOME": fake_home, "USERPROFILE": fake_home}
+        with mock.patch.dict(os.environ, env):
+            config = load_config("~/config.json")
+            expected_vault = os.path.expanduser("~/test-vault")
+
+    assert expected_vault.startswith(fake_home)
+    assert config.vault.path == expected_vault