diff --git a/src/companion/forge/train.py b/src/companion/forge/train.py new file mode 100644 index 0000000..86f604e --- /dev/null +++ b/src/companion/forge/train.py @@ -0,0 +1,238 @@ +"""QLoRA fine-tuning script using Unsloth. + +Trains a Llama 3.1 8B model on extracted reflection data from the vault. +Optimized for RTX 5070 with 12GB VRAM. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +from datasets import Dataset +from transformers import TrainingArguments +from trl import SFTTrainer + + +def load_training_data(data_path: Path) -> list[dict[str, Any]]: + """Load training examples from JSONL file.""" + examples = [] + with open(data_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + examples.append(json.loads(line)) + return examples + + +def prepare_dataset( + examples: list[dict], validation_split: float = 0.1 +) -> tuple[Dataset, Dataset | None]: + """Split examples into train and validation datasets.""" + import random + + random.seed(42) + shuffled = examples.copy() + random.shuffle(shuffled) + + split_idx = int(len(shuffled) * (1 - validation_split)) + train_examples = shuffled[:split_idx] + val_examples = shuffled[split_idx:] if len(shuffled[split_idx:]) > 5 else None + + train_dataset = Dataset.from_list(train_examples) + val_dataset = Dataset.from_list(val_examples) if val_examples else None + + return train_dataset, val_dataset + + +def format_example(example: dict) -> str: + """Format a training example into the chat template format.""" + messages = example.get("messages", []) + if not messages: + return "" + + # Format as conversation + formatted = "" + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + if role == "system": + formatted += f"System: {content}\n\n" + elif role == "user": + formatted += f"User: {content}\n\n" + elif role == "assistant": + formatted += f"Assistant: {content}\n" + + return formatted.strip() + + +def train( + data_path: Path, + output_dir: Path, + base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", + lora_rank: int = 16, + lora_alpha: int = 32, + learning_rate: float = 2e-4, + num_epochs: int = 3, + batch_size: int = 4, + gradient_accumulation_steps: int = 4, + warmup_steps: int = 100, + save_steps: int = 500, + eval_steps: int = 250, + validation_split: float = 0.1, +) -> Path: + """Run QLoRA fine-tuning on the training data. + + Args: + data_path: Path to JSONL file with training examples + output_dir: Directory to save model checkpoints and outputs + base_model: HuggingFace model ID for base model + lora_rank: LoRA rank (higher = more capacity, more memory) + lora_alpha: LoRA alpha (scaling factor) + learning_rate: Learning rate for optimizer + num_epochs: Number of training epochs + batch_size: Per-device batch size + gradient_accumulation_steps: Steps to accumulate before update + warmup_steps: Learning rate warmup steps + save_steps: Save checkpoint every N steps + eval_steps: Run evaluation every N steps + validation_split: Fraction of data for validation + + Returns: + Path to final checkpoint directory + """ + try: + from unsloth import FastLanguageModel + except ImportError: + raise ImportError("unsloth is required. Install with: pip install unsloth") + + print(f"Loading base model: {base_model}") + + # Load model with Unsloth (4-bit quantization) + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=base_model, + max_seq_length=2048, + dtype=None, # Auto-detect + load_in_4bit=True, + ) + + # Add LoRA adapters + print(f"Adding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})") + model = FastLanguageModel.get_peft_model( + model, + r=lora_rank, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + lora_alpha=lora_alpha, + lora_dropout=0, + bias="none", + use_gradient_checkpointing="unsloth", + random_state=3407, + use_rslora=False, + ) + + # Load training data + print(f"Loading training data from: {data_path}") + examples = load_training_data(data_path) + print(f"Loaded {len(examples)} examples") + + if len(examples) < 10: + raise ValueError(f"Need at least 10 examples, got {len(examples)}") + + train_dataset, val_dataset = prepare_dataset(examples, validation_split) + + # Set up training arguments + training_args = TrainingArguments( + output_dir=str(output_dir), + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + warmup_steps=warmup_steps, + learning_rate=learning_rate, + logging_steps=10, + optim="adamw_8bit", + weight_decay=0.01, + lr_scheduler_type="linear", + seed=3407, + save_steps=save_steps, + eval_steps=eval_steps if val_dataset else None, + evaluation_strategy="steps" if val_dataset else "no", + save_strategy="steps", + load_best_model_at_end=True if val_dataset else False, + ) + + # Initialize trainer + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=val_dataset, + dataset_text_field="text", + max_seq_length=2048, + dataset_num_proc=2, + packing=False, + args=training_args, + formatting_func=lambda ex: format_example(ex), + ) + + # Train + print("Starting training...") + trainer_stats = trainer.train() + + # Save final model + final_path = output_dir / "final" + final_path.mkdir(parents=True, exist_ok=True) + + print(f"Saving final model to: {final_path}") + trainer.save_model(str(final_path)) + + print(f"Training complete!") + print(f" - Final loss: {trainer_stats.training_loss:.4f}") + print(f" - Trained for {trainer_stats.global_step} steps") + + return final_path + + +def main(): + """CLI entry point for training.""" + import argparse + + parser = argparse.ArgumentParser(description="Train companion model") + parser.add_argument( + "--data", type=Path, required=True, help="Path to training data JSONL" + ) + parser.add_argument( + "--output", + type=Path, + default=Path("~/.companion/training"), + help="Output directory", + ) + parser.add_argument("--epochs", type=int, default=3, help="Number of epochs") + parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") + + args = parser.parse_args() + + output = args.output.expanduser() + output.mkdir(parents=True, exist_ok=True) + + final_model = train( + data_path=args.data, + output_dir=output, + num_epochs=args.epochs, + learning_rate=args.lr, + ) + + print(f"\nModel saved to: {final_model}") + + +if __name__ == "__main__": + main()