The Trial - Initial commit
This commit is contained in:
135
scripts/start_training.py
Normal file
135
scripts/start_training.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
"""
Model Setup and Training Launcher

Prepares the environment and starts training
"""

import os
import subprocess
import sys
from pathlib import Path

from dotenv import load_dotenv

# Load environment variables from .env file into os.environ (e.g. HF_TOKEN),
# so later os.getenv() calls see values from either the shell or the .env file.
load_dotenv()
|
||||
|
||||
|
||||
def check_gpu():
    """Report whether a CUDA-capable GPU is visible to PyTorch.

    Prints a human-readable status line and returns True when a GPU is
    available, False when PyTorch is missing or no GPU is detected.
    """
    # Import lazily so the launcher still runs (and can warn) without torch.
    try:
        import torch
    except ImportError:
        print("⚠ PyTorch not installed")
        return False

    # Guard clause: no CUDA device -> warn and bail out early.
    if not torch.cuda.is_available():
        print("⚠ No GPU detected. Training will be very slow.")
        return False

    print(f"✓ GPU detected: {torch.cuda.get_device_name()}")
    print(
        f"✓ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB"
    )
    return True
|
||||
|
||||
|
||||
def check_model_access():
    """Probe whether the gated Llama base model can be loaded.

    Tries to fetch the tokenizer for meta-llama/Llama-3.2-3B-Instruct and
    returns True on success. Any failure (missing transformers install,
    no network, no HuggingFace access grant) is reported and returns False.
    """
    try:
        from transformers import AutoTokenizer

        AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
    except Exception as err:
        # Deliberately broad: this is a best-effort pre-flight probe, and any
        # failure mode here should be surfaced as guidance, not a crash.
        print(f"⚠ Cannot access base model: {err}")
        print("Note: You may need to request access to Llama models on HuggingFace")
        return False

    print("✓ Base model accessible")
    return True
|
||||
|
||||
|
||||
def setup_environment():
    """Create output directories and export training-related env vars.

    Side effects:
        - Creates models/, models/monte_cristo_qlora/ and logs/ (idempotent).
        - Sets TOKENIZERS_PARALLELISM=false (silences tokenizer fork warnings).
        - Sets WANDB_PROJECT so Weights & Biases runs group under one project.
        - Reports whether an HF_TOKEN is available in the environment.
    """
    print("Setting up training environment...")

    # Create necessary directories; parents=True + exist_ok=True makes this
    # safe to re-run on every launch.
    for dir_path in ("models", "models/monte_cristo_qlora", "logs"):
        Path(dir_path).mkdir(parents=True, exist_ok=True)

    # Set environment variables for the child training process.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["WANDB_PROJECT"] = "the-trial-slm"

    # HF_TOKEN already lives in os.environ when present (load_dotenv() merged
    # the .env file into the environment), so re-assigning it back into
    # os.environ was a no-op — just report whether the token is available.
    if os.getenv("HF_TOKEN"):
        print("✓ Hugging Face token set")

    print("✓ Environment setup complete")
|
||||
|
||||
|
||||
def start_training():
    """Run pre-flight checks, prepare the environment, and launch training.

    Runs scripts/train_qlora.py as a child process with the current
    interpreter.

    Returns:
        True when the training subprocess exits successfully, False when the
        base model is inaccessible, the subprocess fails, or the user aborts.
    """
    print("\n" + "=" * 60)
    print("The Trial SLM TRAINING")
    print("=" * 60)

    # Pre-flight checks. The GPU check is informational only — training is
    # allowed (if slow) without one — so its result is not captured.
    print("Running pre-flight checks...")
    check_gpu()
    model_accessible = check_model_access()

    if not model_accessible:
        print("\n⚠ Cannot proceed without model access")
        print("Please ensure you have:")
        print("1. A HuggingFace account")
        print("2. Requested access to meta-llama/Llama-3.2-3B-Instruct")
        print("3. Set your HuggingFace token: hf_token=<your_token>")
        return False

    # Setup environment (directories + env vars for the child process).
    setup_environment()

    # Start training
    print("\nStarting training...")
    try:
        # check=True makes subprocess.run raise CalledProcessError for ANY
        # non-zero exit code, so the original `if returncode == 0 ... else`
        # fallback was unreachable dead code; a normal return here already
        # guarantees success.
        subprocess.run([sys.executable, "scripts/train_qlora.py"], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Training failed: {e}")
        return False
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        return False

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 60)
    print("Next steps:")
    print("1. Check the model in models/monte_cristo_qlora/")
    print("2. Run the integration script to prepare for Ollama")
    print("3. Test the model with sample queries")
    return True
|
||||
|
||||
|
||||
def main():
    """Entry point: print the banner, run training, and report the outcome.

    Exits with status 1 when training fails.
    """
    print("The Trial Literary Analysis SLM - Training Launcher")
    print("=" * 60)

    if start_training():
        print("\n✓ Ready for next phase: Ollama Integration")
    else:
        print("\n⚠ Training failed. Check logs for details.")
        sys.exit(1)
|
||||
|
||||
|
||||
# Allow running directly: python scripts/start_training.py
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user