From 47ac2f36e057f7a3cc15b1790e2358c5dbd6a09b Mon Sep 17 00:00:00 2001 From: Santhosh Janardhanan Date: Mon, 13 Apr 2026 15:21:06 -0400 Subject: [PATCH] feat: add model reload endpoint and forge CLI --- pyproject.toml | 10 +++ src/companion/api.py | 29 +++++++ src/companion/forge/cli.py | 133 ++++++++++++++++++++++++++++++ src/companion/forge/export.py | 151 ++++++++++++++++++++++++++++++++++ src/companion/forge/reload.py | 89 ++++++++++++++++++++ 5 files changed, 412 insertions(+) create mode 100644 src/companion/forge/cli.py create mode 100644 src/companion/forge/export.py create mode 100644 src/companion/forge/reload.py diff --git a/pyproject.toml b/pyproject.toml index 190fbc6..8002fe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,16 @@ dev = [ "httpx>=0.27.0", "respx>=0.21.0", ] +train = [ + "unsloth>=2024.1.0", + "torch>=2.1.0", + "transformers>=4.36.0", + "datasets>=2.14.0", + "peft>=0.7.0", + "accelerate>=0.25.0", + "bitsandbytes>=0.41.0", + "trl>=0.7.0", +] [tool.hatchling] packages = ["src/companion"] diff --git a/src/companion/api.py b/src/companion/api.py index 18705e7..0dff87e 100644 --- a/src/companion/api.py +++ b/src/companion/api.py @@ -209,6 +209,35 @@ async def get_session_history(session_id: str) -> dict: } +class ReloadModelRequest(BaseModel): + """Model reload request.""" + + model_path: str  # filesystem path to the new model (directory or GGUF file) + + +@app.post("/admin/reload-model") +async def reload_model_endpoint(request: ReloadModelRequest) -> dict: + """Reload the model with a new fine-tuned version (admin only).""" + from pathlib import Path + + from companion.forge.reload import reload_model + + new_path = Path(request.model_path).expanduser() + + if not new_path.exists(): + raise HTTPException(status_code=404, detail=f"Model not found: {new_path}") + + try: + active_path = reload_model(config, new_path, backup=True)  # backs up the current model before swapping + return { + "status": "success", + "message": "Model reloaded successfully", + "active_model": str(active_path), + } + except Exception as e: + raise 
HTTPException(status_code=500, detail=f"Failed to reload model: {e}") from e + + + if __name__ == "__main__": import uvicorn diff --git a/src/companion/forge/cli.py b/src/companion/forge/cli.py new file mode 100644 index 0000000..3893a6a --- /dev/null +++ b/src/companion/forge/cli.py @@ -0,0 +1,133 @@ +"""CLI for model forge operations.""" + +from __future__ import annotations + +from pathlib import Path + +import typer + +from companion.config import load_config +from companion.forge.extract import TrainingDataExtractor +from companion.forge.reload import get_model_status, reload_model +from companion.forge.train import train as train_model + +app = typer.Typer(help="Companion model forge - training pipeline") + + +@app.command() +def extract( + output: Path = typer.Option( + Path("~/.companion/training_data/extracted.jsonl"), + help="Output JSONL file path", + ), +) -> None: + """Extract training examples from vault.""" + config = load_config() + + typer.echo("Scanning vault for reflection examples...") + + extractor = TrainingDataExtractor(config) + examples = extractor.extract() + + if not examples: + typer.echo("No reflection examples found in vault.") + typer.echo( + "Try adding tags like #reflection, #insight, or #learning to your notes." 
+ ) + raise typer.Exit(1) + + # Save to JSONL + output = output.expanduser() + output.parent.mkdir(parents=True, exist_ok=True) + count = extractor.save_to_jsonl(output) + + stats = extractor.get_stats() + + typer.echo(f"\nExtracted {count} training examples:") + typer.echo(f" - Average length: {stats.get('avg_length', 0)} chars") + if stats.get("top_tags"): + typer.echo( + f" - Top tags: {', '.join(f'{tag}({cnt})' for tag, cnt in stats['top_tags'][:5])}" + ) + typer.echo(f"\nSaved to: {output}") + + +@app.command() +def status() -> None: + """Check model status.""" + config = load_config() + + model_status = get_model_status(config) + + typer.echo(f"Model Status:") + typer.echo(f" Path: {model_status['path']}") + typer.echo(f" Exists: {'Yes' if model_status['exists'] else 'No'}") + if model_status["exists"]: + typer.echo(f" Type: {model_status['type']}") + typer.echo(f" Size: {model_status['size_mb']} MB") + + +@app.command() +def reload( + model_path: Path = typer.Argument( + ..., + help="Path to new model directory or GGUF file", + ), + no_backup: bool = typer.Option( + False, + "--no-backup", + help="Skip backing up current model", + ), +) -> None: + """Reload model with a new fine-tuned version.""" + config = load_config() + + model_path = model_path.expanduser() + + try: + active_path = reload_model(config, model_path, backup=not no_backup) + typer.echo(f"Model reloaded successfully: {active_path}") + except FileNotFoundError as e: + typer.echo(f"Error: {e}") + raise typer.Exit(1) + + +@app.command() +def train( + data: Path = typer.Option( + Path("~/.companion/training_data/extracted.jsonl"), + help="Path to training data JSONL", + ), + output: Path = typer.Option( + Path("~/.companion/training"), + help="Output directory for checkpoints", + ), + epochs: int = typer.Option(3, help="Number of training epochs"), + lr: float = typer.Option(2e-4, help="Learning rate"), +) -> None: + """Train model using QLoRA fine-tuning.""" + data = data.expanduser() + output = 
output.expanduser() + + if not data.exists(): + typer.echo(f"Training data not found: {data}") + typer.echo("Run 'forge extract' first to generate training data.") + raise typer.Exit(1) + + try: + final_path = train_model( + data_path=data, + output_dir=output, + num_epochs=epochs, + learning_rate=lr, + ) + typer.echo(f"\nTraining complete! Model saved to: {final_path}") + typer.echo("\nTo use this model:") + typer.echo(f" forge reload {final_path}") + except Exception as e: + typer.echo(f"Training failed: {e}") + raise typer.Exit(1) + + +if __name__ == "__main__": + app() diff --git a/src/companion/forge/export.py b/src/companion/forge/export.py new file mode 100644 index 0000000..60e62f4 --- /dev/null +++ b/src/companion/forge/export.py @@ -0,0 +1,151 @@ +"""Merge LoRA weights and export to GGUF for llama.cpp inference.""" + +from __future__ import annotations + +import argparse +import shutil  # NOTE(review): unused in this module — candidate for removal +from pathlib import Path + +from peft import PeftModel  # NOTE(review): unused — merging goes through unsloth's FastLanguageModel instead +from transformers import AutoModelForCausalLM, AutoTokenizer  # NOTE(review): unused — candidate for removal + + +def export_to_gguf( + checkpoint_path: Path, + output_path: Path, + quantization: str = "Q4_K_M", +) -> Path: + """Export fine-tuned model to GGUF format. 
+ + Args: + checkpoint_path: Path to checkpoint directory with LoRA weights + output_path: Path to save GGUF file + quantization: Quantization method (Q4_K_M, Q5_K_M, Q8_0) + + Returns: + Path to exported GGUF file + """ + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load the base model + # Note: This assumes the checkpoint was saved with save_pretrained + # which includes the adapter_config.json + + from unsloth import FastLanguageModel + + # Load model with adapters + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=str(checkpoint_path), + max_seq_length=2048, + dtype=None, + load_in_4bit=False, # Load full precision for export + ) + + # Merge LoRA weights into base + print("Merging LoRA weights...") + model = model.merge_and_unload() + + # Save merged model temporarily + temp_path = checkpoint_path.parent / "merged" + temp_path.mkdir(exist_ok=True) + + print(f"Saving merged model to: {temp_path}") + model.save_pretrained(temp_path) + tokenizer.save_pretrained(temp_path) + + # Convert to GGUF using llama.cpp + # Note: This requires llama.cpp's convert script + output_path.parent.mkdir(parents=True, exist_ok=True) + + print(f"Exporting to GGUF format...") + print(f" Quantization: {quantization}") + print(f" Output: {output_path}") + + # For now, we'll save in HuggingFace format + # Full GGUF conversion would require llama.cpp tools + # which may not be installed in the environment + + # Alternative: Save as merged HF model + hf_output = output_path.parent / "merged_hf" + hf_output.mkdir(parents=True, exist_ok=True) + + model.save_pretrained(hf_output) + tokenizer.save_pretrained(hf_output) + + print(f"\nModel exported to HuggingFace format: {hf_output}") + print(f"\nTo convert to GGUF, install llama.cpp and run:") + print( + f" python convert_hf_to_gguf.py {hf_output} --outfile {output_path} --outtype {quantization}" + ) + + # Create a marker file + marker = output_path.parent / "EXPORTED" + marker.write_text(f"Merged model saved 
to: {hf_output}\n") + + return hf_output + + +def merge_only( + checkpoint_path: Path, + output_path: Path, +) -> Path: + """Just merge LoRA weights, save as HF model. + + This is useful if you want to serve via vLLM or HuggingFace directly + instead of converting to GGUF. + """ + print(f"Loading checkpoint from: {checkpoint_path}") + + from unsloth import FastLanguageModel + + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=str(checkpoint_path), + max_seq_length=2048, + dtype=None, + load_in_4bit=False,  # full precision so the merged weights are not quantized + ) + + print("Merging LoRA weights...") + model = model.merge_and_unload() + + output_path.mkdir(parents=True, exist_ok=True) + + print(f"Saving merged model to: {output_path}") + model.save_pretrained(output_path) + tokenizer.save_pretrained(output_path) + + print(f"Done! Model saved to: {output_path}") + + return output_path + + +def main():  # argparse entry point: --checkpoint required; --gguf selects export_to_gguf, otherwise merge_only + """CLI entry point.""" + parser = argparse.ArgumentParser(description="Export fine-tuned model") + parser.add_argument( + "--checkpoint", type=Path, required=True, help="Checkpoint directory" + ) + parser.add_argument( + "--output", + type=Path, + default=Path("~/.companion/models/exported"), + help="Output path", + ) + parser.add_argument("--gguf", action="store_true", help="Export to GGUF format") + parser.add_argument( + "--quant", type=str, default="Q4_K_M", help="GGUF quantization type" + ) + + args = parser.parse_args() + + checkpoint = args.checkpoint.expanduser() + output = args.output.expanduser() + + if args.gguf: + export_to_gguf(checkpoint, output, args.quant) + else: + merge_only(checkpoint, output) + + +if __name__ == "__main__": + main() diff --git a/src/companion/forge/reload.py b/src/companion/forge/reload.py new file mode 100644 index 0000000..df03023 --- /dev/null +++ b/src/companion/forge/reload.py @@ -0,0 +1,89 @@ +"""Model reloader for hot-swapping fine-tuned models.""" + +from __future__ import annotations + +import shutil +from pathlib import Path +from typing import TYPE_CHECKING  # Config is imported under the TYPE_CHECKING guard only, avoiding a runtime import of companion.config 
 + +if TYPE_CHECKING: + from companion.config import Config + + +def reload_model( + config: Config, + new_model_path: Path, + backup: bool = True, +) -> Path: + """Reload the model with a new fine-tuned version. + + Args: + config: Current configuration + new_model_path: Path to new model directory or GGUF file + backup: Whether to backup the old model + + Returns: + Path to the active model + """ + current_model = Path(config.model.inference.model_path).expanduser() + + # Validate new model exists + if not new_model_path.exists(): + raise FileNotFoundError(f"New model not found: {new_model_path}") + + # Backup current model if it exists + if backup and current_model.exists(): + backup_path = current_model.parent / f"{current_model.name}.backup" + if backup_path.exists(): + shutil.rmtree(backup_path, ignore_errors=True)  # NOTE(review): if the stale backup is a file (GGUF), rmtree is a no-op and copy2 below overwrites it + + if current_model.is_dir(): + shutil.copytree(current_model, backup_path) + else: + shutil.copy2(current_model, backup_path) + + print(f"Backed up current model to: {backup_path}") + + # Copy new model to active location + if current_model.exists(): + if current_model.is_dir(): + shutil.rmtree(current_model, ignore_errors=True)  # NOTE(review): non-atomic swap — active copy is deleted before the new one lands, and there is no guard against new_model_path == current_model; verify callers never pass the active path + else: + current_model.unlink() + + current_model.parent.mkdir(parents=True, exist_ok=True) + + if new_model_path.is_dir(): + shutil.copytree(new_model_path, current_model) + else: + shutil.copy2(new_model_path, current_model) + + print(f"Model reloaded: {new_model_path} -> {current_model}") + + return current_model + + +def get_model_status(config: Config) -> dict: + """Get status of current model.""" + model_path = Path(config.model.inference.model_path).expanduser() + + status = { + "path": str(model_path), + "exists": model_path.exists(), + "type": None, + "size_mb": 0, + } + + if model_path.exists(): + if model_path.is_dir(): + status["type"] = "directory" + # Calculate directory size + total_size = sum( + f.stat().st_size for f in model_path.rglob("*") if f.is_file() + ) + status["size_mb"] = round(total_size / 
(1024 * 1024), 2) + else: + status["type"] = "file" + status["size_mb"] = round(model_path.stat().st_size / (1024 * 1024), 2)  # bytes -> MB, rounded to 2 decimal places + + return status