""" Tests for checkpoint system. """ import pytest import tempfile import os from pathlib import Path from src.core.checkpoint import ( CheckpointManager, ShotCheckpoint, ProjectCheckpoint, ShotStatus, ProjectStatus ) class TestCheckpointManager: """Test checkpoint management.""" @pytest.fixture def temp_db(self): """Create a temporary database for testing.""" with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f: db_path = f.name yield db_path # Cleanup if os.path.exists(db_path): os.unlink(db_path) @pytest.fixture def manager(self, temp_db): """Create a checkpoint manager with temp database.""" return CheckpointManager(db_path=temp_db) def test_create_project(self, manager): """Test creating a project checkpoint.""" checkpoint = manager.create_project( project_id="test_001", storyboard_path="/path/to/storyboard.json", output_dir="/path/to/output", backend_name="wan_t2v_14b" ) assert checkpoint.project_id == "test_001" assert checkpoint.status == ProjectStatus.INITIALIZED assert checkpoint.backend_name == "wan_t2v_14b" def test_get_project(self, manager): """Test retrieving a project checkpoint.""" manager.create_project( project_id="test_001", storyboard_path="/path/to/storyboard.json", output_dir="/path/to/output", backend_name="wan_t2v_14b" ) retrieved = manager.get_project("test_001") assert retrieved is not None assert retrieved.project_id == "test_001" missing = manager.get_project("nonexistent") assert missing is None def test_update_project_status(self, manager): """Test updating project status.""" manager.create_project( project_id="test_001", storyboard_path="/path/to/storyboard.json", output_dir="/path/to/output", backend_name="wan_t2v_14b" ) manager.update_project_status("test_001", ProjectStatus.RUNNING) retrieved = manager.get_project("test_001") assert retrieved.status == ProjectStatus.RUNNING def test_create_shot(self, manager): """Test creating a shot checkpoint.""" manager.create_project( project_id="test_001", storyboard_path="/path/storyboard.json", output_dir="/path/output", backend_name="wan" ) shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING) assert shot.shot_id == "S01" assert shot.status == ShotStatus.PENDING def test_update_shot(self, manager): """Test updating shot checkpoint.""" manager.create_project( project_id="test_001", storyboard_path="/path/storyboard.json", output_dir="/path/output", backend_name="wan" ) manager.create_shot("test_001", "S01") manager.update_shot( "test_001", "S01", status=ShotStatus.IN_PROGRESS, output_path="/path/output.mp4", metadata={"seed": 12345} ) shot = manager.get_shot("test_001", "S01") assert shot.status == ShotStatus.IN_PROGRESS assert shot.output_path == "/path/output.mp4" assert shot.metadata["seed"] == 12345 def test_get_pending_shots(self, manager): """Test retrieving pending shots.""" manager.create_project( project_id="test_001", storyboard_path="/path/storyboard.json", output_dir="/path/output", backend_name="wan" ) manager.create_shot("test_001", "S01", ShotStatus.PENDING) manager.create_shot("test_001", "S02", ShotStatus.COMPLETED) manager.create_shot("test_001", "S03", ShotStatus.PENDING) pending = manager.get_pending_shots("test_001") assert len(pending) == 2 assert "S01" in pending assert "S03" in pending def test_can_resume(self, manager): """Test checking if project can be resumed.""" manager.create_project( project_id="test_001", storyboard_path="/path/storyboard.json", output_dir="/path/output", backend_name="wan" ) # No shots yet - can't resume assert not manager.can_resume("test_001") # Add pending shot - can resume manager.create_shot("test_001", "S01", ShotStatus.PENDING) assert manager.can_resume("test_001") # Complete all shots - can't resume manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED) assert not manager.can_resume("test_001") def test_delete_project(self, manager): """Test deleting a project.""" manager.create_project( project_id="test_001", storyboard_path="/path/storyboard.json", output_dir="/path/output", backend_name="wan" ) manager.create_shot("test_001", "S01") manager.delete_project("test_001") assert manager.get_project("test_001") is None assert manager.get_shot("test_001", "S01") is None def test_list_projects(self, manager): """Test listing all projects.""" manager.create_project( project_id="test_001", storyboard_path="/path/1.json", output_dir="/out/1", backend_name="wan" ) manager.create_project( project_id="test_002", storyboard_path="/path/2.json", output_dir="/out/2", backend_name="svd" ) projects = manager.list_projects() assert len(projects) == 2 project_ids = [p.project_id for p in projects] assert "test_001" in project_ids assert "test_002" in project_ids