117 lines
2.6 KiB
Python
117 lines
2.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify our training setup
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
|
|
def test_imports(packages=None):
    """Check that the packages required for training can be imported.

    Every package is checked, rather than stopping at the first failure, so a
    single run reports everything that still needs to be installed.

    Args:
        packages: Optional sequence of ``(module_name, display_name)`` pairs to
            check. Defaults to PyTorch, Transformers, PEFT and Datasets.

    Returns:
        True if every package imported successfully, False otherwise.
    """
    import importlib

    if packages is None:
        packages = (
            ("torch", "PyTorch"),
            ("transformers", "Transformers"),
            ("peft", "PEFT"),
            ("datasets", "Datasets"),
        )

    print("Testing package imports...")

    all_ok = True
    for module_name, display_name in packages:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(f"X {display_name} not installed")
            all_ok = False
            continue
        # Not every module exposes __version__; don't crash the check over it.
        version = getattr(module, "__version__", "unknown version")
        print(f"+ {display_name} {version}")

    return all_ok
|
|
|
|
|
|
def test_gpu():
    """Report whether PyTorch can see a CUDA GPU.

    Returns:
        True if ``torch.cuda.is_available()`` reports a GPU. False if no GPU is
        detected or the check itself fails (for example, PyTorch is not
        installed or the CUDA runtime raises).
    """
    print("\nTesting GPU...")

    try:
        import torch

        if not torch.cuda.is_available():
            print("! No GPU detected - training will be very slow")
            return False

        # Query one device for both values so the name and the VRAM
        # figure always describe the same GPU.
        device = torch.cuda.current_device()
        name = torch.cuda.get_device_name(device)
        total_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
    except Exception as exc:  # diagnostic boundary: ImportError, CUDA errors, ...
        print(f"X Cannot check GPU: {exc}")
        return False

    print(f"+ GPU detected: {name}")
    print(f"+ VRAM: {total_gb:.1f} GB")
    return True
|
|
|
|
|
|
def test_data(data_file="data/training/monte_cristo_combined.json"):
    """Check that the training data file exists and can be parsed.

    Prints the number of examples and a count of examples per category.

    Args:
        data_file: Path to the JSON training file. The file must contain a
            JSON list of examples; each example may carry a ``"category"`` key.

    Returns:
        True if the file exists, parses, and holds a JSON list.
        False otherwise.
    """
    import json
    from collections import Counter

    print("\nTesting training data...")

    if not os.path.exists(data_file):
        print("X Training data not found")
        return False

    try:
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers both json.JSONDecodeError and invalid UTF-8.
        print(f"X Training data could not be read: {exc}")
        return False

    if not isinstance(data, list):
        print(
            "X Training data must be a JSON list of examples, "
            f"got {type(data).__name__}"
        )
        return False

    print(f"+ Training data found: {len(data)} examples")

    # Examples without a category, or that are not JSON objects at all, are
    # counted as "unknown" rather than crashing the check. Counter keeps
    # first-seen order, so the printed dict matches a manual count.
    categories = dict(
        Counter(
            item.get("category", "unknown") if isinstance(item, dict) else "unknown"
            for item in data
        )
    )
    print(f"+ Categories: {categories}")
    return True
|
|
|
|
|
|
def main():
    """Run every environment check and print a pass/fail summary.

    All three checks always run (the results are gathered in a list, not
    short-circuited), so the full report is printed even when one fails.

    Returns:
        True when every check passed, False otherwise.
    """
    print("The Trial SLM - Environment Test")
    print("=" * 50)

    results = [test_imports(), test_gpu(), test_data()]
    passed = all(results)

    banner = "=" * 50
    print("\n" + banner)
    if passed:
        print("+ ALL TESTS PASSED - Ready for training!")
    else:
        print("! Some tests failed - fix issues before training")
    print(banner)
    return passed
|
|
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 on success and 1 on any failed check, so shell
    # scripts and CI jobs can gate on this script.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)
|