"""QLoRA fine-tuning script using Unsloth. Trains a Llama 3.1 8B model on extracted reflection data from the vault. Optimized for RTX 5070 with 12GB VRAM. """ from __future__ import annotations import json import os from pathlib import Path from typing import Any from datasets import Dataset from transformers import TrainingArguments from trl import SFTTrainer def load_training_data(data_path: Path) -> list[dict[str, Any]]: """Load training examples from JSONL file.""" examples = [] with open(data_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: examples.append(json.loads(line)) return examples def prepare_dataset( examples: list[dict], validation_split: float = 0.1 ) -> tuple[Dataset, Dataset | None]: """Split examples into train and validation datasets.""" import random random.seed(42) shuffled = examples.copy() random.shuffle(shuffled) split_idx = int(len(shuffled) * (1 - validation_split)) train_examples = shuffled[:split_idx] val_examples = shuffled[split_idx:] if len(shuffled[split_idx:]) > 5 else None train_dataset = Dataset.from_list(train_examples) val_dataset = Dataset.from_list(val_examples) if val_examples else None return train_dataset, val_dataset def format_example(example: dict) -> str: """Format a training example into the chat template format.""" messages = example.get("messages", []) if not messages: return "" # Format as conversation formatted = "" for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role == "system": formatted += f"System: {content}\n\n" elif role == "user": formatted += f"User: {content}\n\n" elif role == "assistant": formatted += f"Assistant: {content}\n" return formatted.strip() def train( data_path: Path, output_dir: Path, base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", lora_rank: int = 16, lora_alpha: int = 32, learning_rate: float = 2e-4, num_epochs: int = 3, batch_size: int = 4, gradient_accumulation_steps: int = 4, warmup_steps: int = 100, save_steps: int = 500, eval_steps: int = 250, validation_split: float = 0.1, ) -> Path: """Run QLoRA fine-tuning on the training data. Args: data_path: Path to JSONL file with training examples output_dir: Directory to save model checkpoints and outputs base_model: HuggingFace model ID for base model lora_rank: LoRA rank (higher = more capacity, more memory) lora_alpha: LoRA alpha (scaling factor) learning_rate: Learning rate for optimizer num_epochs: Number of training epochs batch_size: Per-device batch size gradient_accumulation_steps: Steps to accumulate before update warmup_steps: Learning rate warmup steps save_steps: Save checkpoint every N steps eval_steps: Run evaluation every N steps validation_split: Fraction of data for validation Returns: Path to final checkpoint directory """ try: from unsloth import FastLanguageModel except ImportError: raise ImportError("unsloth is required. Install with: pip install unsloth") print(f"Loading base model: {base_model}") # Load model with Unsloth (4-bit quantization) model, tokenizer = FastLanguageModel.from_pretrained( model_name=base_model, max_seq_length=2048, dtype=None, # Auto-detect load_in_4bit=True, ) # Add LoRA adapters print(f"Adding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})") model = FastLanguageModel.get_peft_model( model, r=lora_rank, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=lora_alpha, lora_dropout=0, bias="none", use_gradient_checkpointing="unsloth", random_state=3407, use_rslora=False, ) # Load training data print(f"Loading training data from: {data_path}") examples = load_training_data(data_path) print(f"Loaded {len(examples)} examples") if len(examples) < 10: raise ValueError(f"Need at least 10 examples, got {len(examples)}") train_dataset, val_dataset = prepare_dataset(examples, validation_split) # Set up training arguments training_args = TrainingArguments( output_dir=str(output_dir), num_train_epochs=num_epochs, per_device_train_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, warmup_steps=warmup_steps, learning_rate=learning_rate, logging_steps=10, optim="adamw_8bit", weight_decay=0.01, lr_scheduler_type="linear", seed=3407, save_steps=save_steps, eval_steps=eval_steps if val_dataset else None, evaluation_strategy="steps" if val_dataset else "no", save_strategy="steps", load_best_model_at_end=True if val_dataset else False, ) # Initialize trainer trainer = SFTTrainer( model=model, tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=val_dataset, dataset_text_field="text", max_seq_length=2048, dataset_num_proc=2, packing=False, args=training_args, formatting_func=lambda ex: format_example(ex), ) # Train print("Starting training...") trainer_stats = trainer.train() # Save final model final_path = output_dir / "final" final_path.mkdir(parents=True, exist_ok=True) print(f"Saving final model to: {final_path}") trainer.save_model(str(final_path)) print(f"Training complete!") print(f" - Final loss: {trainer_stats.training_loss:.4f}") print(f" - Trained for {trainer_stats.global_step} steps") return final_path def main(): """CLI entry point for training.""" import argparse parser = argparse.ArgumentParser(description="Train companion model") parser.add_argument( "--data", type=Path, required=True, help="Path to training data JSONL" ) parser.add_argument( "--output-dir", "--output", dest="output", type=Path, default=Path("~/.companion/training"), help="Output directory", ) parser.add_argument("--epochs", type=int, default=3, help="Number of epochs") parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") args = parser.parse_args() output = args.output.expanduser() output.mkdir(parents=True, exist_ok=True) final_model = train( data_path=args.data, output_dir=output, num_epochs=args.epochs, learning_rate=args.lr, ) print(f"\nModel saved to: {final_model}") if __name__ == "__main__": main()