The Trial - Initial commit
This commit is contained in:
116
scripts/test_environment.py
Normal file
116
scripts/test_environment.py
Normal file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify our training setup
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def test_imports():
    """Test if required packages are available.

    Tries to import each package the training setup depends on (torch,
    transformers, peft, datasets) and prints one status line per package.
    Every package is checked even after a failure, so a single run lists
    all missing dependencies instead of only the first one.

    Returns:
        True if every package imported successfully, False otherwise.
    """
    # Local import keeps this function self-contained.
    import importlib

    print("Testing package imports...")

    # (display label, importable module name)
    required = (
        ("PyTorch", "torch"),
        ("Transformers", "transformers"),
        ("PEFT", "peft"),
        ("Datasets", "datasets"),
    )

    all_available = True
    for label, module_name in required:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(f"X {label} not installed")
            all_available = False
            continue
        # Fall back to a placeholder rather than crash if a build of the
        # package does not expose __version__.
        version = getattr(module, "__version__", "unknown")
        print(f"+ {label} {version}")

    return all_available
|
||||
|
||||
|
||||
def test_gpu():
    """Test GPU availability.

    Prints the name and total memory of the current CUDA device when one
    is available.

    Returns:
        True if a CUDA device is available and could be queried. False if
        torch cannot be imported, no CUDA device is available, or querying
        the device raised an error.
    """
    print("\nTesting GPU...")

    try:
        import torch
    except ImportError:
        print("X Cannot check GPU")
        return False

    try:
        if not torch.cuda.is_available():
            print("! No GPU detected - training will be very slow")
            return False

        # Query name and memory from the same device so the two lines
        # always describe one GPU.
        device = torch.cuda.current_device()
        print(f"+ GPU detected: {torch.cuda.get_device_name(device)}")
        total_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
        print(f"+ VRAM: {total_gb:.1f} GB")
        return True
    except Exception as exc:
        # Best-effort diagnostic: a driver/runtime problem should be reported
        # as a failed check, not abort the script. Unlike a bare `except`,
        # this lets KeyboardInterrupt and SystemExit propagate.
        print(f"X Cannot check GPU: {exc}")
        return False
|
||||
|
||||
|
||||
def test_data(data_file="data/training/monte_cristo_combined.json"):
    """Test if training data exists and can be loaded.

    Args:
        data_file: Path to the JSON training file. Defaults to the combined
            dataset path this check has always used.

    Returns:
        True if the file was read and parsed as a JSON list. False if the
        file is missing, cannot be read, is not valid JSON, or its top-level
        value is not a list.
    """
    import json
    from collections import Counter

    print("\nTesting training data...")

    # Open directly instead of checking existence first: this avoids the
    # check-then-open race and lets us report unreadable files cleanly.
    try:
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print("X Training data not found")
        return False
    except (OSError, ValueError) as exc:
        # OSError: directory / permission problems.
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"X Training data could not be read: {exc}")
        return False

    if not isinstance(data, list):
        print(
            "X Training data has unexpected format: "
            f"expected a list, got {type(data).__name__}"
        )
        return False

    print(f"+ Training data found: {len(data)} examples")

    # Show categories. Entries that are not dicts have no category field,
    # so they are counted under "unknown" rather than crashing the check.
    categories = Counter(
        item.get("category", "unknown") if isinstance(item, dict) else "unknown"
        for item in data
    )

    # dict(...) keeps first-seen order and the plain-dict display format.
    print(f"+ Categories: {dict(categories)}")
    return True
|
||||
|
||||
|
||||
def main():
    """Run every environment check and print an overall verdict.

    Runs the import, GPU and data checks in that order (all three always
    run), then prints a banner saying whether the environment is ready.

    Returns:
        True if all checks passed, False otherwise.
    """
    print("The Trial SLM - Environment Test")
    print("=" * 50)

    # Build the list eagerly so every check runs and prints its report,
    # even when an earlier one has already failed.
    outcomes = [test_imports(), test_gpu(), test_data()]
    passed = all(outcomes)

    rule = "=" * 50
    print("\n" + rule)
    if passed:
        print("+ ALL TESTS PASSED - Ready for training!")
    else:
        print("! Some tests failed - fix issues before training")
    print(rule)

    return passed
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Map the overall result to a process exit status:
    # 0 when every check passed, 1 otherwise.
    if main():
        sys.exit(0)
    sys.exit(1)
|
||||
Reference in New Issue
Block a user