Initial Commit
This commit is contained in:
189
tests/unit/test_checkpoint.py
Normal file
189
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Tests for checkpoint system.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from src.core.checkpoint import (
|
||||
CheckpointManager, ShotCheckpoint, ProjectCheckpoint,
|
||||
ShotStatus, ProjectStatus
|
||||
)
|
||||
|
||||
|
||||
class TestCheckpointManager:
|
||||
"""Test checkpoint management."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create a temporary database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
if os.path.exists(db_path):
|
||||
os.unlink(db_path)
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self, temp_db):
|
||||
"""Create a checkpoint manager with temp database."""
|
||||
return CheckpointManager(db_path=temp_db)
|
||||
|
||||
def test_create_project(self, manager):
|
||||
"""Test creating a project checkpoint."""
|
||||
checkpoint = manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/to/storyboard.json",
|
||||
output_dir="/path/to/output",
|
||||
backend_name="wan_t2v_14b"
|
||||
)
|
||||
|
||||
assert checkpoint.project_id == "test_001"
|
||||
assert checkpoint.status == ProjectStatus.INITIALIZED
|
||||
assert checkpoint.backend_name == "wan_t2v_14b"
|
||||
|
||||
def test_get_project(self, manager):
|
||||
"""Test retrieving a project checkpoint."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/to/storyboard.json",
|
||||
output_dir="/path/to/output",
|
||||
backend_name="wan_t2v_14b"
|
||||
)
|
||||
|
||||
retrieved = manager.get_project("test_001")
|
||||
assert retrieved is not None
|
||||
assert retrieved.project_id == "test_001"
|
||||
|
||||
missing = manager.get_project("nonexistent")
|
||||
assert missing is None
|
||||
|
||||
def test_update_project_status(self, manager):
|
||||
"""Test updating project status."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/to/storyboard.json",
|
||||
output_dir="/path/to/output",
|
||||
backend_name="wan_t2v_14b"
|
||||
)
|
||||
|
||||
manager.update_project_status("test_001", ProjectStatus.RUNNING)
|
||||
|
||||
retrieved = manager.get_project("test_001")
|
||||
assert retrieved.status == ProjectStatus.RUNNING
|
||||
|
||||
def test_create_shot(self, manager):
|
||||
"""Test creating a shot checkpoint."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/storyboard.json",
|
||||
output_dir="/path/output",
|
||||
backend_name="wan"
|
||||
)
|
||||
|
||||
shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||
|
||||
assert shot.shot_id == "S01"
|
||||
assert shot.status == ShotStatus.PENDING
|
||||
|
||||
def test_update_shot(self, manager):
|
||||
"""Test updating shot checkpoint."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/storyboard.json",
|
||||
output_dir="/path/output",
|
||||
backend_name="wan"
|
||||
)
|
||||
manager.create_shot("test_001", "S01")
|
||||
|
||||
manager.update_shot(
|
||||
"test_001",
|
||||
"S01",
|
||||
status=ShotStatus.IN_PROGRESS,
|
||||
output_path="/path/output.mp4",
|
||||
metadata={"seed": 12345}
|
||||
)
|
||||
|
||||
shot = manager.get_shot("test_001", "S01")
|
||||
assert shot.status == ShotStatus.IN_PROGRESS
|
||||
assert shot.output_path == "/path/output.mp4"
|
||||
assert shot.metadata["seed"] == 12345
|
||||
|
||||
def test_get_pending_shots(self, manager):
|
||||
"""Test retrieving pending shots."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/storyboard.json",
|
||||
output_dir="/path/output",
|
||||
backend_name="wan"
|
||||
)
|
||||
|
||||
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||
manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
|
||||
manager.create_shot("test_001", "S03", ShotStatus.PENDING)
|
||||
|
||||
pending = manager.get_pending_shots("test_001")
|
||||
assert len(pending) == 2
|
||||
assert "S01" in pending
|
||||
assert "S03" in pending
|
||||
|
||||
def test_can_resume(self, manager):
|
||||
"""Test checking if project can be resumed."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/storyboard.json",
|
||||
output_dir="/path/output",
|
||||
backend_name="wan"
|
||||
)
|
||||
|
||||
# No shots yet - can't resume
|
||||
assert not manager.can_resume("test_001")
|
||||
|
||||
# Add pending shot - can resume
|
||||
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||
assert manager.can_resume("test_001")
|
||||
|
||||
# Complete all shots - can't resume
|
||||
manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
|
||||
assert not manager.can_resume("test_001")
|
||||
|
||||
def test_delete_project(self, manager):
|
||||
"""Test deleting a project."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/storyboard.json",
|
||||
output_dir="/path/output",
|
||||
backend_name="wan"
|
||||
)
|
||||
manager.create_shot("test_001", "S01")
|
||||
|
||||
manager.delete_project("test_001")
|
||||
|
||||
assert manager.get_project("test_001") is None
|
||||
assert manager.get_shot("test_001", "S01") is None
|
||||
|
||||
def test_list_projects(self, manager):
|
||||
"""Test listing all projects."""
|
||||
manager.create_project(
|
||||
project_id="test_001",
|
||||
storyboard_path="/path/1.json",
|
||||
output_dir="/out/1",
|
||||
backend_name="wan"
|
||||
)
|
||||
manager.create_project(
|
||||
project_id="test_002",
|
||||
storyboard_path="/path/2.json",
|
||||
output_dir="/out/2",
|
||||
backend_name="svd"
|
||||
)
|
||||
|
||||
projects = manager.list_projects()
|
||||
assert len(projects) == 2
|
||||
|
||||
project_ids = [p.project_id for p in projects]
|
||||
assert "test_001" in project_ids
|
||||
assert "test_002" in project_ids
|
||||
Reference in New Issue
Block a user