feat: add QLoRA training script with Unsloth
This commit is contained in:
238
src/companion/forge/train.py
Normal file
238
src/companion/forge/train.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""QLoRA fine-tuning script using Unsloth.
|
||||
|
||||
Trains a Llama 3.1 8B model on extracted reflection data from the vault.
|
||||
Optimized for RTX 5070 with 12GB VRAM.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from datasets import Dataset
|
||||
from transformers import TrainingArguments
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
def load_training_data(data_path: Path) -> list[dict[str, Any]]:
    """Load training examples from a JSONL file.

    Blank lines are ignored; every other line must be a valid JSON object.

    Args:
        data_path: Path to the JSONL file of training examples.

    Returns:
        Decoded example dicts, in file order.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]
|
||||
|
||||
|
||||
def prepare_dataset(
    examples: list[dict], validation_split: float = 0.1
) -> tuple[Dataset, Dataset | None]:
    """Split examples into train and validation datasets.

    The shuffle uses a fixed seed so the split is reproducible across runs.
    The validation set is dropped entirely (returned as None) when it would
    contain 5 or fewer examples, since metrics on so few samples are noise.

    Args:
        examples: Raw training-example dicts (the input list is not mutated).
        validation_split: Fraction of the data reserved for validation.

    Returns:
        (train_dataset, val_dataset), where val_dataset may be None.
    """
    import random

    # Use a private RNG instance instead of random.seed(42): seeding the
    # module-level RNG clobbers global random state for everything else in
    # the process. Random(42).shuffle produces the same ordering.
    rng = random.Random(42)
    shuffled = examples.copy()
    rng.shuffle(shuffled)

    split_idx = int(len(shuffled) * (1 - validation_split))
    train_examples = shuffled[:split_idx]
    # Slice the tail once (the original sliced it twice: once for the
    # length check and once for the value).
    tail = shuffled[split_idx:]
    val_examples = tail if len(tail) > 5 else None

    train_dataset = Dataset.from_list(train_examples)
    val_dataset = Dataset.from_list(val_examples) if val_examples else None

    return train_dataset, val_dataset
|
||||
|
||||
|
||||
def format_example(example: dict) -> str:
    """Format a training example into the chat template format.

    Renders the example's "messages" list as plain text with
    "System:"/"User:"/"Assistant:" prefixes. Messages with any other role
    are skipped. Returns "" when there are no messages.
    """
    messages = example.get("messages", [])
    if not messages:
        return ""

    # Per-role (prefix, separator) pairs; unknown roles have no entry.
    rendering = {
        "system": ("System: ", "\n\n"),
        "user": ("User: ", "\n\n"),
        "assistant": ("Assistant: ", "\n"),
    }
    parts: list[str] = []
    for msg in messages:
        entry = rendering.get(msg.get("role", ""))
        if entry is None:
            continue
        prefix, separator = entry
        parts.append(f"{prefix}{msg.get('content', '')}{separator}")

    return "".join(parts).strip()
|
||||
|
||||
|
||||
def train(
    data_path: Path,
    output_dir: Path,
    base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    lora_rank: int = 16,
    lora_alpha: int = 32,
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    save_steps: int = 500,
    eval_steps: int = 250,
    validation_split: float = 0.1,
) -> Path:
    """Run QLoRA fine-tuning on the training data.

    Args:
        data_path: Path to JSONL file with training examples
        output_dir: Directory to save model checkpoints and outputs
        base_model: HuggingFace model ID for base model
        lora_rank: LoRA rank (higher = more capacity, more memory)
        lora_alpha: LoRA alpha (scaling factor)
        learning_rate: Learning rate for optimizer
        num_epochs: Number of training epochs
        batch_size: Per-device batch size
        gradient_accumulation_steps: Steps to accumulate before update
        warmup_steps: Learning rate warmup steps
        save_steps: Save checkpoint every N steps
        eval_steps: Run evaluation every N steps
        validation_split: Fraction of data for validation

    Raises:
        ImportError: If unsloth is not installed.
        ValueError: If the data file has fewer than 10 examples.

    Returns:
        Path to final checkpoint directory
    """
    try:
        from unsloth import FastLanguageModel
    except ImportError as err:
        # Chain the original exception so the real failure (e.g. a missing
        # CUDA dependency inside unsloth) is visible in the traceback
        # rather than being swallowed by this friendlier message.
        raise ImportError(
            "unsloth is required. Install with: pip install unsloth"
        ) from err

    print(f"Loading base model: {base_model}")

    # Load model with Unsloth (4-bit quantization)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=2048,
        dtype=None,  # Auto-detect
        load_in_4bit=True,
    )

    # Add LoRA adapters to all attention and MLP projection matrices.
    print(f"Adding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})")
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_alpha,
        lora_dropout=0,
        bias="none",
        # Unsloth's gradient checkpointing variant trades compute for VRAM.
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
    )

    # Load training data
    print(f"Loading training data from: {data_path}")
    examples = load_training_data(data_path)
    print(f"Loaded {len(examples)} examples")

    if len(examples) < 10:
        raise ValueError(f"Need at least 10 examples, got {len(examples)}")

    train_dataset, val_dataset = prepare_dataset(examples, validation_split)

    # Set up training arguments. Evaluation-related options are only
    # activated when a validation set exists.
    # NOTE(review): newer transformers releases renamed
    # `evaluation_strategy` to `eval_strategy` — confirm against the
    # pinned transformers version before upgrading.
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        learning_rate=learning_rate,
        logging_steps=10,
        # 8-bit AdamW keeps optimizer state small enough for 12GB VRAM.
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        save_steps=save_steps,
        eval_steps=eval_steps if val_dataset else None,
        evaluation_strategy="steps" if val_dataset else "no",
        save_strategy="steps",
        load_best_model_at_end=True if val_dataset else False,
    )

    # Initialize trainer.
    # NOTE(review): both `dataset_text_field` and `formatting_func` are
    # passed; trl prioritizes the formatting function, and the examples
    # carry "messages" rather than "text" — verify `dataset_text_field`
    # is actually inert here before removing it.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        dataset_num_proc=2,
        packing=False,
        args=training_args,
        # Pass the function directly; the original wrapped it in a
        # redundant `lambda ex: format_example(ex)`.
        formatting_func=format_example,
    )

    # Train
    print("Starting training...")
    trainer_stats = trainer.train()

    # Save final model (adapter weights) under a stable "final" directory.
    final_path = output_dir / "final"
    final_path.mkdir(parents=True, exist_ok=True)

    print(f"Saving final model to: {final_path}")
    trainer.save_model(str(final_path))

    # Plain string: the original used an f-string with no placeholders.
    print("Training complete!")
    print(f"  - Final loss: {trainer_stats.training_loss:.4f}")
    print(f"  - Trained for {trainer_stats.global_step} steps")

    return final_path
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and kick off training."""
    import argparse

    parser = argparse.ArgumentParser(description="Train companion model")
    parser.add_argument(
        "--data", type=Path, required=True, help="Path to training data JSONL"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("~/.companion/training"),
        help="Output directory",
    )
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    opts = parser.parse_args()

    # Resolve "~" and make sure the checkpoint directory exists up front.
    destination = opts.output.expanduser()
    destination.mkdir(parents=True, exist_ok=True)

    final_model = train(
        data_path=opts.data,
        output_dir=destination,
        num_epochs=opts.epochs,
        learning_rate=opts.lr,
    )

    print(f"\nModel saved to: {final_model}")
|
||||
|
||||
|
||||
# Allow running this module directly as a training script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user