136 lines
3.8 KiB
Python
136 lines
3.8 KiB
Python
#!/usr/bin/env python3
"""
Model Setup and Training Launcher

Prepares the environment and starts training
"""

import os
import subprocess
import sys
from pathlib import Path

from dotenv import load_dotenv

# Load environment variables from .env file into os.environ at import time
# (e.g. HF_TOKEN, which setup_environment() later reads via os.getenv).
load_dotenv()
|
|
|
|
|
|
def check_gpu():
    """Check whether a CUDA-capable GPU is visible to PyTorch.

    Returns:
        bool: True if a CUDA device is available, False if no GPU is
        detected or PyTorch is not installed.
    """
    try:
        # Imported lazily so the launcher still runs when torch is absent.
        import torch

        if torch.cuda.is_available():
            # Use device index 0 explicitly for both the name and the VRAM
            # report so the two lines always describe the same device
            # (the original mixed get_device_name() with properties(0)).
            print(f"✓ GPU detected: {torch.cuda.get_device_name(0)}")
            print(
                f"✓ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
            )
            return True
        print("⚠ No GPU detected. Training will be very slow.")
        return False
    except ImportError:
        print("⚠ PyTorch not installed")
        return False
|
|
|
|
|
|
def check_model_access():
    """Check that the gated base model's tokenizer can be loaded.

    Loading the tokenizer is a cheap proxy for "we have access to the
    meta-llama/Llama-3.2-3B-Instruct repo".

    Returns:
        bool: True if the tokenizer loads, False on any failure (missing
        transformers install, no network, or no HuggingFace access grant).
    """
    try:
        # Imported lazily so this check degrades gracefully when the
        # training dependencies are not installed.
        from transformers import AutoTokenizer

        # The tokenizer object itself is not needed — only whether the
        # download succeeds — so the return value is intentionally dropped.
        AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
        print("✓ Base model accessible")
        return True
    except Exception as e:
        # Broad catch is deliberate: any failure mode (ImportError, auth,
        # network) should be reported, not crash the launcher.
        print(f"⚠ Cannot access base model: {e}")
        print("Note: You may need to request access to Llama models on HuggingFace")
        return False
|
|
|
|
|
|
def setup_environment():
    """Create output directories and set training-related env vars.

    Side effects:
        - Creates models/, models/monte_cristo_qlora/ and logs/ if missing.
        - Sets TOKENIZERS_PARALLELISM and WANDB_PROJECT in os.environ.
    """
    print("Setting up training environment...")

    # Create output directories up front so training can write immediately.
    for dir_path in ("models", "models/monte_cristo_qlora", "logs"):
        Path(dir_path).mkdir(parents=True, exist_ok=True)

    # Silence the HF tokenizers fork warning and tag Weights & Biases runs.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB_PROJECT"] = "the-trial-slm"

    # load_dotenv() already exported HF_TOKEN into os.environ if it was in
    # .env, and os.getenv reads os.environ — so a non-None value is already
    # set; the original write-back was a no-op and has been removed.
    if os.getenv("HF_TOKEN"):
        print("✓ Hugging Face token set")

    print("✓ Environment setup complete")
|
|
|
|
|
|
def start_training():
    """Run pre-flight checks, prepare the environment, and launch training.

    Spawns scripts/train_qlora.py with the current interpreter.

    Returns:
        bool: True if the training script exited successfully; False on
        missing model access, a failed training run, or user interrupt.
    """
    print("\n" + "=" * 60)
    print("The Trial SLM TRAINING")
    print("=" * 60)

    # Pre-flight checks. The GPU check is informational only — training may
    # proceed (slowly) on CPU — so its result is not stored.
    print("Running pre-flight checks...")
    check_gpu()
    model_accessible = check_model_access()

    if not model_accessible:
        print("\n⚠ Cannot proceed without model access")
        print("Please ensure you have:")
        print("1. A HuggingFace account")
        print("2. Requested access to meta-llama/Llama-3.2-3B-Instruct")
        print("3. Set your HuggingFace token: hf_token=<your_token>")
        return False

    # Prepare directories and environment variables for the training run.
    setup_environment()

    # Launch the training script as a subprocess.
    print("\nStarting training...")
    try:
        # check=True raises CalledProcessError on any non-zero exit status,
        # so reaching the code after this call means the script succeeded.
        # (The original's `if result.returncode == 0 ... else` branch was
        # unreachable dead code and has been removed.)
        subprocess.run([sys.executable, "scripts/train_qlora.py"], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Training failed: {e}")
        return False
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        return False

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print("Next steps:")
    print("1. Check the model in models/monte_cristo_qlora/")
    print("2. Run the integration script to prepare for Ollama")
    print("3. Test the model with sample queries")
    return True
|
|
|
|
|
|
def main():
    """Entry point: print the launcher banner, run training, report status.

    Exits the process with status 1 if training did not complete.
    """
    print("The Trial Literary Analysis SLM - Training Launcher")
    print("=" * 60)

    # Branch directly on the launcher's outcome instead of holding it in a
    # temporary variable.
    if start_training():
        print("\n✓ Ready for next phase: Ollama Integration")
    else:
        print("\n⚠ Training failed. Check logs for details.")
        sys.exit(1)
|
|
|
|
|
|
# Run the launcher only when executed as a script, so the module can be
# imported (e.g. for its check_* helpers) without starting training.
if __name__ == "__main__":
    main()
|