The Trial - Initial commit
scripts/train_qlora.py · 355 lines (new file)

@@ -0,0 +1,355 @@
#!/usr/bin/env python3
"""
QLoRA Training Script for The Trial Literary Analysis SLM

Uses parameter-efficient fine-tuning to adapt a base model.
"""

import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List

import torch
import wandb
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class TrainingConfig:
    """Configuration for QLoRA training"""

    # Model configuration
    base_model: str = "meta-llama/Llama-3.2-3B-Instruct"
    adapter_name: str = "the-trial-adapter"

    # Data configuration
    dataset_path: str = "data/training/the_trial_combined.json"
    max_seq_length: int = 2048

    # QLoRA configuration
    use_4bit: bool = True
    use_nested_quant: bool = False
    bnb_4bit_compute_dtype: str = "bfloat16"
    bnb_4bit_quant_type: str = "nf4"

    # LoRA configuration
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
    )

    # Training arguments
    output_dir: str = "models/the_trial_qlora"
    num_train_epochs: int = 3
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-5
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03
    max_grad_norm: float = 1.0

    # Optimization
    optim: str = "paged_adamw_32bit"
    lr_scheduler_type: str = "cosine"
    logging_steps: int = 10
    save_steps: int = 100
    eval_steps: int = 100
    save_total_limit: int = 3

    # Hardware
    fp16: bool = False
    bf16: bool = True
    gradient_checkpointing: bool = True
    dataloader_pin_memory: bool = False

    # Miscellaneous
    report_to: str = "wandb"
    run_name: str = "the-trial-qlora"
    seed: int = 42

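
# A quick sanity check on the defaults above (an illustrative helper, not
# called anywhere in the training flow): the effective batch size is the
# per-device batch size times the gradient accumulation steps, and LoRA
# scales its update by alpha / r.
def describe_config(config: TrainingConfig) -> str:
    effective_batch = (
        config.per_device_train_batch_size * config.gradient_accumulation_steps
    )  # 1 * 8 = 8 sequences per optimizer step
    lora_scaling = config.lora_alpha / config.lora_r  # 32 / 16 = 2.0
    return f"effective batch size: {effective_batch}, LoRA scaling: {lora_scaling}"

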
class TheTrialTrainer:
    """Trainer class for The Trial SLM"""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.setup_directories()
        self.setup_logging()

    def setup_directories(self):
        """Create necessary directories"""
        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

    def setup_logging(self):
        """Set up logging and wandb"""
        if self.config.report_to == "wandb":
            wandb.init(
                project="the-trial-slm",
                name=self.config.run_name,
                config=self.config.__dict__,
            )

    def load_tokenizer(self) -> AutoTokenizer:
        """Load and configure tokenizer"""
        logger.info(f"Loading tokenizer for {self.config.base_model}")

        tokenizer = AutoTokenizer.from_pretrained(
            self.config.base_model, trust_remote_code=True, padding_side="right"
        )

        # Set pad token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")
        return tokenizer

    def load_model(self) -> AutoModelForCausalLM:
        """Load model with QLoRA configuration"""
        logger.info(f"Loading model {self.config.base_model}")

        # Configure 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=self.config.use_4bit,
            bnb_4bit_quant_type=self.config.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=getattr(torch, self.config.bnb_4bit_compute_dtype),
            bnb_4bit_use_double_quant=self.config.use_nested_quant,
        )

        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )

        # Prepare model for k-bit training
        model = prepare_model_for_kbit_training(model)

        logger.info(f"Model loaded on device: {next(model.parameters()).device}")
        return model

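    # Back-of-the-envelope context for the 4-bit load above (a rough sketch:
    # NF4 stores weights at about 0.5 bytes per parameter, ignoring the
    # quantization constants, KV cache, and activation memory).
    @staticmethod
    def estimate_4bit_weight_gb(num_params: float) -> float:
        bytes_per_param = 0.5  # 4 bits per weight
        return num_params * bytes_per_param / 1e9  # e.g. 3e9 params -> ~1.5 GB
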
    def setup_lora(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
        """Set up LoRA adapter"""
        logger.info("Setting up LoRA adapter")

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=self.config.target_modules,
            bias="none",
        )

        # Get PEFT model
        model = get_peft_model(model, lora_config)

        # Print trainable parameters
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        all_params = sum(p.numel() for p in model.parameters())

        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"All parameters: {all_params:,}")
        logger.info(f"Trainable%: {100 * trainable_params / all_params:.2f}%")

        return model

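    # For context (an illustrative helper, not used by the trainer): each
    # LoRA-adapted weight matrix of shape (d, k) adds r * (d + k) trainable
    # parameters, which is why the trainable fraction logged above stays small.
    @staticmethod
    def lora_params_per_matrix(d: int, k: int, r: int) -> int:
        return r * (d + k)  # A is (r, k), B is (d, r)
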
    def load_and_preprocess_data(self, tokenizer: AutoTokenizer) -> Dataset:
        """Load and preprocess training data"""
        logger.info(f"Loading dataset from {self.config.dataset_path}")

        # Load dataset
        with open(self.config.dataset_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        logger.info(f"Loaded {len(data)} training examples")

        # Convert to HuggingFace Dataset
        dataset = Dataset.from_list(data)

        # Tokenization function
        def tokenize_function(examples):
            # Format the prompt
            prompts = []
            for i in range(len(examples["instruction"])):
                instruction = examples["instruction"][i]
                input_text = examples["input"][i] if examples["input"][i] else ""
                output = examples["output"][i]

                # Create prompt in instruction format
                prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{instruction}

{input_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{output}<|eot_id|>"""

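                # Note: this hand-rolled template mirrors the Llama 3 instruct
                # chat format; tokenizer.apply_chat_template would be an
                # alternative way to build the same structure.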
                prompts.append(prompt)

            # Tokenize
            tokenized = tokenizer(
                prompts,
                truncation=True,
                padding=False,
                max_length=self.config.max_seq_length,
                return_tensors=None,
            )

            # Labels are created by DataCollatorForLanguageModeling (mlm=False),
            # which copies input_ids and masks padding, so they are not set here
            return tokenized

        # Apply tokenization
        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,
            desc="Tokenizing dataset",
        )

        logger.info(f"Tokenized dataset: {len(tokenized_dataset)} examples")
        return tokenized_dataset

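    # The file at dataset_path is expected to be a JSON list of records with
    # the "instruction"/"input"/"output" keys read in tokenize_function above.
    # A hypothetical record, for illustration only:
    EXAMPLE_RECORD = {
        "instruction": "Summarize the opening of The Trial.",
        "input": "",
        "output": "Josef K. is arrested one morning without being told his crime.",
    }
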
    def create_trainer(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        train_dataset: Dataset,
    ) -> Trainer:
        """Create Trainer instance"""

        # Training arguments
        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            num_train_epochs=self.config.num_train_epochs,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            warmup_ratio=self.config.warmup_ratio,
            max_grad_norm=self.config.max_grad_norm,
            optim=self.config.optim,
            lr_scheduler_type=self.config.lr_scheduler_type,
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            save_total_limit=self.config.save_total_limit,
            fp16=self.config.fp16,
            bf16=self.config.bf16,
            gradient_checkpointing=self.config.gradient_checkpointing,
            dataloader_pin_memory=self.config.dataloader_pin_memory,
            report_to=self.config.report_to,
            run_name=self.config.run_name,
            seed=self.config.seed,
            # Performance optimizations
            dataloader_num_workers=0,
            remove_unused_columns=False,
        )

        # Data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
            pad_to_multiple_of=8,
        )

        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=None,  # No evaluation dataset for now
            data_collator=data_collator,
            tokenizer=tokenizer,
        )

        return trainer

    def train(self):
        """Execute training"""
        logger.info("Starting The Trial SLM training")

        # Load components
        tokenizer = self.load_tokenizer()
        model = self.load_model()
        model = self.setup_lora(model)
        train_dataset = self.load_and_preprocess_data(tokenizer)
        trainer = self.create_trainer(model, tokenizer, train_dataset)

        # Train model
        logger.info("Beginning training...")
        trainer.train()

        # Save final model
        logger.info("Saving final model...")
        trainer.save_model()
        tokenizer.save_pretrained(self.config.output_dir)

        # For a PEFT model, save_model() writes the adapter weights into
        # output_dir (adapter_model.safetensors in recent peft releases);
        # noted here for later use with Ollama
        adapter_path = Path(self.config.output_dir) / "adapter_model.safetensors"
        if adapter_path.exists():
            logger.info(f"Adapter saved to {adapter_path}")

        logger.info("Training completed successfully!")

        return trainer, model


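# A minimal sketch of loading the trained adapter for inference (illustrative,
# not called by this script; assumes the adapter was saved to output_dir above):
def load_for_inference(base_model: str, adapter_dir: str):
    from peft import PeftModel  # local import keeps the sketch self-contained

    model = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    return PeftModel.from_pretrained(model, adapter_dir)

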
def main():
    """Main training function"""
    # Configuration
    config = TrainingConfig()

    # Create trainer
    trainer_instance = TheTrialTrainer(config)

    # Execute training
    try:
        trainer, model = trainer_instance.train()
        logger.info("=" * 50)
        logger.info("TRAINING COMPLETED SUCCESSFULLY!")
        logger.info(f"Model saved to: {config.output_dir}")
        logger.info("=" * 50)

    except Exception as e:
        logger.error(f"Training failed: {e}")
        raise

    finally:
        # Cleanup wandb
        if config.report_to == "wandb":
            wandb.finish()


if __name__ == "__main__":
    main()