#!/usr/bin/env python3
"""
Test script to verify our training setup.

Run directly: checks that the required Python packages import, that a CUDA
GPU is visible to PyTorch, and that the combined training data file exists
and parses. Exits with status 0 when every check passes, 1 otherwise.
"""

import importlib
import json
import os
import sys
from collections import Counter

# (importable module name, human-readable name) for every package training needs.
REQUIRED_PACKAGES = (
    ("torch", "PyTorch"),
    ("transformers", "Transformers"),
    ("peft", "PEFT"),
    ("datasets", "Datasets"),
)

# Relative path: resolved against the current working directory, so run the
# script from the project root.
DEFAULT_DATA_FILE = "data/training/monte_cristo_combined.json"


def test_imports():
    """Test if required packages are available.

    Every package in ``REQUIRED_PACKAGES`` is tried, so a single run reports
    all missing packages instead of stopping at the first one.

    Returns:
        True if every package imported, False otherwise.
    """
    print("Testing package imports...")
    all_ok = True
    for module_name, display_name in REQUIRED_PACKAGES:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            print(f"X {display_name} not installed")
            all_ok = False
            continue
        # Not every package is guaranteed to expose __version__.
        version = getattr(module, "__version__", "unknown")
        print(f"+ {display_name} {version}")
    return all_ok


def test_gpu():
    """Test GPU availability through PyTorch's CUDA backend.

    Returns:
        True if a CUDA device is visible; False if none is found, PyTorch is
        not installed, or querying the device raises.
    """
    print("\nTesting GPU...")
    try:
        import torch

        if not torch.cuda.is_available():
            print("! No GPU detected - training will be very slow")
            return False
        # Query name and memory for the SAME device so the report is coherent.
        device = torch.cuda.current_device()
        name = torch.cuda.get_device_name(device)
        vram_gb = torch.cuda.get_device_properties(device).total_memory / 1e9
    except Exception as exc:  # diagnostic script: report the cause, don't crash
        print(f"X Cannot check GPU ({exc})")
        return False
    print(f"+ GPU detected: {name}")
    print(f"+ VRAM: {vram_gb:.1f} GB")
    return True


def test_data(data_file=DEFAULT_DATA_FILE):
    """Test if training data exists and is well-formed.

    Args:
        data_file: Path to the training JSON file. It must contain a JSON
            list of objects; each object's optional ``"category"`` key is
            tallied (missing categories count as ``"unknown"``).

    Returns:
        True if the file exists, parses, and has the expected shape;
        False otherwise (the reason is printed).
    """
    print("\nTesting training data...")
    if not os.path.exists(data_file):
        print("X Training data not found")
        return False

    try:
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # OSError: unreadable path (e.g. a directory); ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError.
        print(f"X Cannot read training data: {exc}")
        return False

    if not isinstance(data, list):
        print(
            "X Training data must be a JSON list of examples, "
            f"got {type(data).__name__}"
        )
        return False
    non_objects = sum(1 for item in data if not isinstance(item, dict))
    if non_objects:
        print(f"X {non_objects} training examples are not JSON objects")
        return False

    print(f"+ Training data found: {len(data)} examples")
    # Counter keeps first-seen order; print it as a plain dict.
    categories = Counter(item.get("category", "unknown") for item in data)
    print(f"+ Categories: {dict(categories)}")
    return True


def main():
    """Run every environment check and print a summary.

    Returns:
        True if all checks passed, False otherwise.
    """
    print("The Trial SLM - Environment Test")
    print("=" * 50)

    # Evaluate every check up front (no short-circuit) so that one run
    # reports all problems.
    results = [test_imports(), test_gpu(), test_data()]
    passed = all(results)

    print("\n" + "=" * 50)
    if passed:
        print("+ ALL TESTS PASSED - Ready for training!")
    else:
        print("! Some tests failed - fix issues before training")
    print("=" * 50)
    return passed


if __name__ == "__main__":
    sys.exit(0 if main() else 1)