feat: add QLoRA training script with Unsloth

This commit is contained in:
2026-04-13 15:16:17 -04:00
parent f944bdc573
commit e919d2a8e2

View File

@@ -0,0 +1,238 @@
"""QLoRA fine-tuning script using Unsloth.
Trains a Llama 3.1 8B model on extracted reflection data from the vault.
Optimized for RTX 5070 with 12GB VRAM.
"""
from __future__ import annotations
import json
import os
from pathlib import Path
from typing import Any
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
def load_training_data(data_path: Path) -> list[dict[str, Any]]:
    """Read training examples from a JSONL file, one JSON object per line.

    Blank lines are skipped; every other line must be valid JSON
    (json.JSONDecodeError propagates otherwise).
    """
    with open(data_path, "r", encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, so stripped-line
        # filtering is only needed to drop empty lines.
        return [json.loads(row) for row in f if row.strip()]
def prepare_dataset(
    examples: list[dict], validation_split: float = 0.1
) -> tuple[Dataset, Dataset | None]:
    """Deterministically shuffle examples and split into train/validation.

    The shuffle is seeded (42) so repeated runs produce the same split.
    The validation set is kept only when it has more than 5 examples;
    otherwise the second tuple element is None.
    """
    import random

    random.seed(42)
    pool = list(examples)
    random.shuffle(pool)

    cut = int(len(pool) * (1 - validation_split))
    train_part = pool[:cut]
    val_part = pool[cut:]
    if len(val_part) <= 5:
        val_part = None

    train_ds = Dataset.from_list(train_part)
    val_ds = Dataset.from_list(val_part) if val_part else None
    return train_ds, val_ds
def format_example(example: dict) -> str:
    """Render a training example's messages as a plain-text conversation.

    Recognized roles get a "Role: " prefix; system and user turns are
    followed by a blank line, assistant turns by a single newline.
    Unknown roles are silently dropped. Returns "" when there are no
    messages. Leading/trailing whitespace is stripped from the result.
    """
    msgs = example.get("messages", [])
    if not msgs:
        return ""
    # (prefix, trailing separator) per recognized role.
    layout = {
        "system": ("System: ", "\n\n"),
        "user": ("User: ", "\n\n"),
        "assistant": ("Assistant: ", "\n"),
    }
    parts = []
    for msg in msgs:
        spec = layout.get(msg.get("role", ""))
        if spec is not None:
            prefix, sep = spec
            parts.append(prefix + msg.get("content", "") + sep)
    return "".join(parts).strip()
def train(
    data_path: Path,
    output_dir: Path,
    base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    lora_rank: int = 16,
    lora_alpha: int = 32,
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    save_steps: int = 500,
    eval_steps: int = 250,
    validation_split: float = 0.1,
) -> Path:
    """Run QLoRA fine-tuning on the training data.

    Args:
        data_path: Path to JSONL file with training examples
        output_dir: Directory to save model checkpoints and outputs
        base_model: HuggingFace model ID for base model
        lora_rank: LoRA rank (higher = more capacity, more memory)
        lora_alpha: LoRA alpha (scaling factor)
        learning_rate: Learning rate for optimizer
        num_epochs: Number of training epochs
        batch_size: Per-device batch size
        gradient_accumulation_steps: Steps to accumulate before update
        warmup_steps: Learning rate warmup steps
        save_steps: Save checkpoint every N steps
        eval_steps: Run evaluation every N steps
        validation_split: Fraction of data for validation

    Returns:
        Path to final checkpoint directory

    Raises:
        ImportError: If unsloth is not installed.
        ValueError: If fewer than 10 training examples are found.
    """
    try:
        from unsloth import FastLanguageModel
    except ImportError as err:
        # Chain the original import failure so the root cause stays visible.
        raise ImportError(
            "unsloth is required. Install with: pip install unsloth"
        ) from err

    print(f"Loading base model: {base_model}")
    # Load model with Unsloth (4-bit quantization)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=2048,
        dtype=None,  # Auto-detect
        load_in_4bit=True,
    )

    # Add LoRA adapters
    print(f"Adding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})")
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
    )

    # Load training data
    print(f"Loading training data from: {data_path}")
    examples = load_training_data(data_path)
    print(f"Loaded {len(examples)} examples")
    if len(examples) < 10:
        raise ValueError(f"Need at least 10 examples, got {len(examples)}")
    train_dataset, val_dataset = prepare_dataset(examples, validation_split)

    # Set up training arguments. Evaluation/checkpoint settings are only
    # enabled when a validation split was actually produced.
    # NOTE(review): newer transformers (>=4.46) renamed `evaluation_strategy`
    # to `eval_strategy` — confirm against the pinned transformers version.
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        learning_rate=learning_rate,
        logging_steps=10,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        save_steps=save_steps,
        eval_steps=eval_steps if val_dataset else None,
        evaluation_strategy="steps" if val_dataset else "no",
        save_strategy="steps",
        load_best_model_at_end=bool(val_dataset),
    )

    # Initialize trainer; formatting_func turns each raw example into the
    # plain-text conversation string the model is trained on.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        dataset_num_proc=2,
        packing=False,
        args=training_args,
        formatting_func=format_example,
    )

    # Train
    print("Starting training...")
    trainer_stats = trainer.train()

    # Save final model
    final_path = output_dir / "final"
    final_path.mkdir(parents=True, exist_ok=True)
    print(f"Saving final model to: {final_path}")
    trainer.save_model(str(final_path))

    print("Training complete!")
    print(f"  - Final loss: {trainer_stats.training_loss:.4f}")
    print(f"  - Trained for {trainer_stats.global_step} steps")
    return final_path
def main():
    """CLI entry point for training."""
    import argparse

    parser = argparse.ArgumentParser(description="Train companion model")
    parser.add_argument(
        "--data", type=Path, required=True, help="Path to training data JSONL"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("~/.companion/training"),
        help="Output directory",
    )
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    args = parser.parse_args()

    # Resolve "~" before creating the output tree.
    out_dir = args.output.expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)

    final_model = train(
        data_path=args.data,
        output_dir=out_dir,
        num_epochs=args.epochs,
        learning_rate=args.lr,
    )
    print(f"\nModel saved to: {final_model}")


if __name__ == "__main__":
    main()