feat: add QLoRA training script with Unsloth
This commit is contained in:
238
src/companion/forge/train.py
Normal file
238
src/companion/forge/train.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""QLoRA fine-tuning script using Unsloth.
|
||||||
|
|
||||||
|
Trains a Llama 3.1 8B model on extracted reflection data from the vault.
|
||||||
|
Optimized for RTX 5070 with 12GB VRAM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers import TrainingArguments
|
||||||
|
from trl import SFTTrainer
|
||||||
|
|
||||||
|
|
||||||
|
def load_training_data(data_path: Path) -> list[dict[str, Any]]:
    """Read training examples from a JSONL file.

    Blank (or whitespace-only) lines are skipped; every other line must be
    a valid JSON object.
    """
    with open(data_path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataset(
    examples: list[dict], validation_split: float = 0.1
) -> tuple[Dataset, Dataset | None]:
    """Shuffle *examples* deterministically and split into train/validation.

    The validation dataset is dropped (returned as ``None``) when it would
    hold five or fewer examples, since evaluation on that few samples is
    not meaningful.
    """
    import random

    pool = list(examples)
    random.seed(42)  # fixed seed keeps the split reproducible across runs
    random.shuffle(pool)

    cut = int(len(pool) * (1 - validation_split))
    train_part = pool[:cut]
    val_part = pool[cut:]
    if len(val_part) <= 5:
        val_part = None

    train_ds = Dataset.from_list(train_part)
    val_ds = Dataset.from_list(val_part) if val_part else None
    return train_ds, val_ds
|
||||||
|
|
||||||
|
|
||||||
|
def format_example(example: dict) -> str:
    """Render a chat-style training example as plain ``Role: content`` text.

    Messages with unrecognized roles are silently skipped; an example with
    no messages renders as the empty string.
    """
    messages = example.get("messages", [])
    if not messages:
        return ""

    # Dispatch table: each known role maps to its line template. Note the
    # assistant turn ends with a single newline, the others with two.
    templates = {
        "system": "System: {}\n\n",
        "user": "User: {}\n\n",
        "assistant": "Assistant: {}\n",
    }

    pieces = []
    for message in messages:
        template = templates.get(message.get("role", ""))
        if template is not None:
            pieces.append(template.format(message.get("content", "")))

    return "".join(pieces).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def train(
    data_path: Path,
    output_dir: Path,
    base_model: str = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    lora_rank: int = 16,
    lora_alpha: int = 32,
    learning_rate: float = 2e-4,
    num_epochs: int = 3,
    batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    save_steps: int = 500,
    eval_steps: int = 250,
    validation_split: float = 0.1,
) -> Path:
    """Run QLoRA fine-tuning on the training data.

    Args:
        data_path: Path to JSONL file with training examples
        output_dir: Directory to save model checkpoints and outputs
        base_model: HuggingFace model ID for base model
        lora_rank: LoRA rank (higher = more capacity, more memory)
        lora_alpha: LoRA alpha (scaling factor)
        learning_rate: Learning rate for optimizer
        num_epochs: Number of training epochs
        batch_size: Per-device batch size
        gradient_accumulation_steps: Steps to accumulate before update
        warmup_steps: Learning rate warmup steps
        save_steps: Save checkpoint every N steps
        eval_steps: Run evaluation every N steps
        validation_split: Fraction of data for validation

    Returns:
        Path to final checkpoint directory

    Raises:
        ImportError: If unsloth is not installed.
        ValueError: If the data file holds fewer than 10 examples.
    """
    try:
        from unsloth import FastLanguageModel
    except ImportError as exc:
        # Chain the original failure so the real import error stays visible.
        raise ImportError("unsloth is required. Install with: pip install unsloth") from exc

    print(f"Loading base model: {base_model}")

    # Load model with Unsloth (4-bit quantization)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model,
        max_seq_length=2048,
        dtype=None,  # Auto-detect
        load_in_4bit=True,
    )

    # Add LoRA adapters
    print(f"Adding LoRA adapters (rank={lora_rank}, alpha={lora_alpha})")
    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,
        # Adapt all attention and MLP projection matrices.
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_alpha,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
    )

    # Load training data
    print(f"Loading training data from: {data_path}")
    examples = load_training_data(data_path)
    print(f"Loaded {len(examples)} examples")

    if len(examples) < 10:
        raise ValueError(f"Need at least 10 examples, got {len(examples)}")

    train_dataset, val_dataset = prepare_dataset(examples, validation_split)

    # Set up training arguments.
    # NOTE(review): `evaluation_strategy` was renamed to `eval_strategy` in
    # newer transformers releases — confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        learning_rate=learning_rate,
        logging_steps=10,
        optim="adamw_8bit",  # 8-bit optimizer keeps VRAM use low
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        save_steps=save_steps,
        eval_steps=eval_steps if val_dataset else None,
        evaluation_strategy="steps" if val_dataset else "no",
        save_strategy="steps",
        load_best_model_at_end=True if val_dataset else False,
    )

    # Initialize trainer.
    # NOTE(review): when `formatting_func` is supplied, TRL ignores
    # `dataset_text_field` — kept for compatibility, but verify against the
    # pinned trl version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        dataset_text_field="text",
        max_seq_length=2048,
        dataset_num_proc=2,
        packing=False,
        args=training_args,
        # Pass the function directly; the lambda wrapper was redundant.
        formatting_func=format_example,
    )

    # Train
    print("Starting training...")
    trainer_stats = trainer.train()

    # Save final model
    final_path = output_dir / "final"
    final_path.mkdir(parents=True, exist_ok=True)

    print(f"Saving final model to: {final_path}")
    trainer.save_model(str(final_path))

    print("Training complete!")
    print(f"  - Final loss: {trainer_stats.training_loss:.4f}")
    print(f"  - Trained for {trainer_stats.global_step} steps")

    return final_path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point for training."""
    import argparse

    cli = argparse.ArgumentParser(description="Train companion model")
    cli.add_argument(
        "--data", type=Path, required=True, help="Path to training data JSONL"
    )
    cli.add_argument(
        "--output",
        type=Path,
        default=Path("~/.companion/training"),
        help="Output directory",
    )
    cli.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    cli.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
    opts = cli.parse_args()

    # Resolve "~" in the output path and make sure the directory exists.
    out_dir = opts.output.expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = train(
        data_path=opts.data,
        output_dir=out_dir,
        num_epochs=opts.epochs,
        learning_rate=opts.lr,
    )

    print(f"\nModel saved to: {checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
# Allow direct execution as a script (e.g. `python train.py --data ...`).
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user