The Trial - Initial commit

This commit is contained in:
2026-01-17 14:59:35 -05:00
commit c401cf655d
27 changed files with 132452 additions and 0 deletions

116
scripts/test_environment.py Normal file
View File

@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""
Test script to verify our training setup
"""
import os
import sys
def test_imports():
    """Check that every package the training setup needs can be imported.

    ``torch``, ``transformers``, ``peft`` and ``datasets`` are imported in
    turn and their versions printed. Every package is probed even after a
    failure, so a single run lists all missing dependencies.

    Returns:
        bool: ``True`` if all packages imported, ``False`` if any is missing.
    """
    # Function-local import, matching how this file imports json where used.
    import importlib

    print("Testing package imports...")
    # (importable module name, label used in the printed report)
    required = (
        ("torch", "PyTorch"),
        ("transformers", "Transformers"),
        ("peft", "PEFT"),
        ("datasets", "Datasets"),
    )
    all_present = True
    for module_name, label in required:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(f"X {label} not installed")
            # Keep going so the remaining packages are still reported.
            all_present = False
            continue
        # A module without __version__ should not abort the whole check.
        version = getattr(module, "__version__", "unknown")
        print(f"+ {label} {version}")
    return all_present
def test_gpu():
    """Report whether a CUDA GPU is usable through PyTorch.

    When a GPU is available, prints the name and total memory of the current
    CUDA device.

    Returns:
        bool: ``True`` if ``torch.cuda.is_available()`` reports a GPU;
        ``False`` if it does not, or if the check itself failed (including
        PyTorch not being importable).
    """
    print("\nTesting GPU...")
    try:
        import torch

        if not torch.cuda.is_available():
            print("! No GPU detected - training will be very slow")
            return False

        # Query name and memory for the same device so the two lines agree.
        device = torch.cuda.current_device()
        print(f"+ GPU detected: {torch.cuda.get_device_name(device)}")
        total_bytes = torch.cuda.get_device_properties(device).total_memory
        print(f"+ VRAM: {total_bytes / 1e9:.1f} GB")
        return True
    except Exception as exc:
        # Exception rather than a bare except, so that KeyboardInterrupt and
        # SystemExit still propagate; the reason is reported, not hidden.
        print(f"X Cannot check GPU: {exc}")
        return False
def test_data(data_file="data/training/monte_cristo_combined.json"):
    """Check that the training data file loads and summarise its contents.

    Prints the number of examples and a per-category count, where entries
    without a ``"category"`` key are counted as ``"unknown"``.

    Args:
        data_file: Path to the JSON training file. Defaults to the combined
            dataset path, resolved against the current working directory.

    Returns:
        bool: ``True`` if the file was loaded as a JSON list; ``False`` if it
        is missing, unreadable, not valid JSON, or not a list.
    """
    import json
    from collections import Counter

    print("\nTesting training data...")
    # Open directly instead of checking existence first: this also covers a
    # directory at the path or a file that disappears between check and open.
    try:
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print("X Training data not found")
        return False
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"X Training data unreadable: {exc}")
        return False

    if not isinstance(data, list):
        print(f"X Training data must be a JSON list, got {type(data).__name__}")
        return False

    print(f"+ Training data found: {len(data)} examples")
    # Show categories; non-dict entries have no category to read.
    categories = Counter(
        item.get("category", "unknown") if isinstance(item, dict) else "unknown"
        for item in data
    )
    # Print as a plain dict to keep the report format unchanged.
    print(f"+ Categories: {dict(categories)}")
    return True
def main():
    """Run every environment check and print an overall verdict.

    Returns:
        bool: True when all checks succeeded, otherwise False.
    """
    divider = "=" * 50
    print("The Trial SLM - Environment Test")
    print(divider)

    # Run all checks before judging, so each one prints its own report
    # even when an earlier check has already failed.
    imports_ok = test_imports()
    gpu_ok = test_gpu()
    data_ok = test_data()
    ready = all((imports_ok, gpu_ok, data_ok))

    if ready:
        verdict = "+ ALL TESTS PASSED - Ready for training!"
    else:
        verdict = "! Some tests failed - fix issues before training"

    print("\n" + divider)
    print(verdict)
    print(divider)
    return ready
if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 otherwise.
    sys.exit(0 if main() else 1)