Initial Commit

This commit is contained in:
2026-02-03 23:06:28 -05:00
commit 46b10fb69b
25 changed files with 2770 additions and 0 deletions

View File

@@ -0,0 +1,189 @@
"""
Tests for checkpoint system.
"""
import pytest
import tempfile
import os
from pathlib import Path
from src.core.checkpoint import (
CheckpointManager, ShotCheckpoint, ProjectCheckpoint,
ShotStatus, ProjectStatus
)
class TestCheckpointManager:
"""Test checkpoint management."""
@pytest.fixture
def temp_db(self):
"""Create a temporary database for testing."""
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
db_path = f.name
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.fixture
def manager(self, temp_db):
"""Create a checkpoint manager with temp database."""
return CheckpointManager(db_path=temp_db)
def test_create_project(self, manager):
"""Test creating a project checkpoint."""
checkpoint = manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
assert checkpoint.project_id == "test_001"
assert checkpoint.status == ProjectStatus.INITIALIZED
assert checkpoint.backend_name == "wan_t2v_14b"
def test_get_project(self, manager):
"""Test retrieving a project checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
retrieved = manager.get_project("test_001")
assert retrieved is not None
assert retrieved.project_id == "test_001"
missing = manager.get_project("nonexistent")
assert missing is None
def test_update_project_status(self, manager):
"""Test updating project status."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
manager.update_project_status("test_001", ProjectStatus.RUNNING)
retrieved = manager.get_project("test_001")
assert retrieved.status == ProjectStatus.RUNNING
def test_create_shot(self, manager):
"""Test creating a shot checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)
assert shot.shot_id == "S01"
assert shot.status == ShotStatus.PENDING
def test_update_shot(self, manager):
"""Test updating shot checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01")
manager.update_shot(
"test_001",
"S01",
status=ShotStatus.IN_PROGRESS,
output_path="/path/output.mp4",
metadata={"seed": 12345}
)
shot = manager.get_shot("test_001", "S01")
assert shot.status == ShotStatus.IN_PROGRESS
assert shot.output_path == "/path/output.mp4"
assert shot.metadata["seed"] == 12345
def test_get_pending_shots(self, manager):
"""Test retrieving pending shots."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
manager.create_shot("test_001", "S03", ShotStatus.PENDING)
pending = manager.get_pending_shots("test_001")
assert len(pending) == 2
assert "S01" in pending
assert "S03" in pending
def test_can_resume(self, manager):
"""Test checking if project can be resumed."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
# No shots yet - can't resume
assert not manager.can_resume("test_001")
# Add pending shot - can resume
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
assert manager.can_resume("test_001")
# Complete all shots - can't resume
manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
assert not manager.can_resume("test_001")
def test_delete_project(self, manager):
"""Test deleting a project."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01")
manager.delete_project("test_001")
assert manager.get_project("test_001") is None
assert manager.get_shot("test_001", "S01") is None
def test_list_projects(self, manager):
"""Test listing all projects."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/1.json",
output_dir="/out/1",
backend_name="wan"
)
manager.create_project(
project_id="test_002",
storyboard_path="/path/2.json",
output_dir="/out/2",
backend_name="svd"
)
projects = manager.list_projects()
assert len(projects) == 2
project_ids = [p.project_id for p in projects]
assert "test_001" in project_ids
assert "test_002" in project_ids

163
tests/unit/test_config.py Normal file
View File

@@ -0,0 +1,163 @@
"""
Tests for configuration system.
"""
import pytest
import tempfile
import os
from pathlib import Path
from src.core.config import Config, BackendConfig, ConfigLoader, get_config, reload_config
class TestConfigLoader:
"""Test configuration loading."""
def test_load_default_config(self):
"""Test loading with no config files (uses defaults)."""
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path("nonexistent.env")
)
assert isinstance(config, Config)
assert config.active_backend == "wan_t2v_14b"
assert len(config.backends) == 0 # No YAML loaded
def test_load_yaml_config(self):
"""Test loading from YAML file."""
yaml_content = """
backends:
test_backend:
name: "Test Backend"
class: "test.TestBackend"
model_id: "test/model"
vram_gb: 8
dtype: "fp16"
enable_vae_slicing: true
enable_vae_tiling: false
chunking:
enabled: true
mode: "sequential"
max_chunk_seconds: 4
overlap_seconds: 1
active_backend: "test_backend"
defaults:
fps: 30
checkpoint_db: "test.db"
model_cache_dir: "~/test_cache"
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write(yaml_content)
yaml_path = f.name
try:
config = ConfigLoader.load(config_path=Path(yaml_path))
assert config.active_backend == "test_backend"
assert "test_backend" in config.backends
backend = config.get_backend("test_backend")
assert backend.name == "Test Backend"
assert backend.vram_gb == 8
assert backend.chunking_mode == "sequential"
assert config.defaults["fps"] == 30
finally:
os.unlink(yaml_path)
def test_load_env_file(self, monkeypatch):
"""Test loading from .env file."""
# Clear any existing env vars first
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
monkeypatch.delenv(key, raising=False)
env_content = """
ACTIVE_BACKEND=env_backend
MODEL_CACHE_DIR=/env/cache
LOG_LEVEL=DEBUG
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.env', delete=False) as f:
f.write(env_content)
env_path = f.name
try:
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path(env_path)
)
assert config.active_backend == "env_backend"
assert "/env/cache" in config.model_cache_dir or "\\env\\cache" in config.model_cache_dir
assert config.log_level == "DEBUG"
finally:
os.unlink(env_path)
# Clean up env vars
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
monkeypatch.delenv(key, raising=False)
def test_env_variable_override(self, monkeypatch):
"""Test that environment variables override config files."""
# Clear env var first
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
yaml_content = """
active_backend: yaml_backend
model_cache_dir: ~/yaml_cache
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write(yaml_content)
yaml_path = f.name
try:
monkeypatch.setenv("ACTIVE_BACKEND", "env_override")
config = ConfigLoader.load(config_path=Path(yaml_path))
# Environment variable should override
assert config.active_backend == "env_override"
# But other settings from YAML should remain
assert "yaml_cache" in config.model_cache_dir
finally:
os.unlink(yaml_path)
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
def test_path_expansion(self):
"""Test that paths are properly expanded."""
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path("nonexistent.env")
)
# Default paths should be expanded
assert not config.model_cache_dir.startswith("~")
assert not config.checkpoint_db.startswith("~")
def test_get_backend_not_found(self):
"""Test getting a non-existent backend."""
config = Config()
with pytest.raises(ValueError, match="not found"):
config.get_backend("nonexistent")
class TestBackendConfig:
"""Test backend configuration."""
def test_backend_config_defaults(self):
"""Test backend config with defaults."""
config = BackendConfig(
name="Test",
model_class="test.Test",
model_id="test/model",
vram_gb=12,
dtype="fp16"
)
assert config.enable_vae_slicing is True
assert config.enable_vae_tiling is False
assert config.chunking_enabled is True
assert config.chunking_mode == "sequential"

View File

@@ -0,0 +1,268 @@
"""
Tests for storyboard schema validation.
"""
import pytest
from pathlib import Path
import json
import tempfile
import os
from src.storyboard.schema import (
Storyboard, ProjectSettings, Shot, Character, Location,
GlobalStyle, CameraSettings, GenerationSettings, OutputSettings,
Resolution
)
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError
class TestStoryboardSchema:
"""Test storyboard schema validation."""
def test_create_minimal_storyboard(self):
"""Test creating a minimal valid storyboard."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test Video",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(
id="S01",
duration_s=5,
prompt="A test shot"
)
]
)
assert storyboard.project.title == "Test Video"
assert len(storyboard.shots) == 1
assert storyboard.shots[0].id == "S01"
def test_shot_validation_unique_ids(self):
"""Test that duplicate shot IDs are rejected."""
with pytest.raises(ValueError, match="Shot IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=5, prompt="Shot 1"),
Shot(id="S01", duration_s=5, prompt="Shot 2")
]
)
def test_character_validation_unique_ids(self):
"""Test that duplicate character IDs are rejected."""
with pytest.raises(ValueError, match="Character IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Character 1"),
Character(id="C01", name="Character 2")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
def test_location_validation_unique_ids(self):
"""Test that duplicate location IDs are rejected."""
with pytest.raises(ValueError, match="Location IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
locations=[
Location(id="L01", name="Location 1"),
Location(id="L01", name="Location 2")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
def test_get_character_by_id(self):
"""Test retrieving character by ID."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Hero", description="The protagonist")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
char = storyboard.get_character("C01")
assert char is not None
assert char.name == "Hero"
missing = storyboard.get_character("C99")
assert missing is None
def test_get_location_by_id(self):
"""Test retrieving location by ID."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
locations=[
Location(id="L01", name="Street", description="A city street")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
loc = storyboard.get_location("L01")
assert loc is not None
assert loc.name == "Street"
missing = storyboard.get_location("L99")
assert missing is None
def test_total_duration_calculation(self):
"""Test total duration calculation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=4, prompt="Shot 1"),
Shot(id="S02", duration_s=6, prompt="Shot 2")
]
)
assert storyboard.get_total_duration() == 10
def test_total_frames_calculation(self):
"""Test total frames calculation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
fps=24,
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=4, prompt="Shot 1"),
Shot(id="S02", duration_s=6, prompt="Shot 2")
]
)
assert storyboard.get_total_frames() == 240 # 10 seconds * 24 fps
class TestStoryboardLoader:
"""Test storyboard loading functionality."""
def test_load_valid_storyboard(self):
"""Test loading a valid storyboard JSON file."""
data = {
"schema_version": "1.0",
"project": {
"title": "Test Video",
"fps": 24,
"target_duration_s": 10,
"resolution": {"width": 1920, "height": 1080},
"aspect_ratio": "16:9",
"global_style": {
"visual_style": "cinematic",
"negative_prompt": "blurry"
},
"audio": {"add_music": False}
},
"characters": [],
"locations": [],
"shots": [
{
"id": "S01",
"duration_s": 5,
"prompt": "A test shot",
"camera": {"framing": "wide"},
"generation": {"seed": 12345, "steps": 30}
}
],
"output": {"container": "mp4", "codec": "h264"}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(data, f)
temp_path = f.name
try:
storyboard = StoryboardValidator.load(temp_path)
assert storyboard.project.title == "Test Video"
assert len(storyboard.shots) == 1
assert storyboard.shots[0].id == "S01"
finally:
os.unlink(temp_path)
def test_load_nonexistent_file(self):
"""Test loading a file that doesn't exist."""
with pytest.raises(StoryboardLoadError, match="not found"):
StoryboardValidator.load("/nonexistent/path/storyboard.json")
def test_load_invalid_json(self):
"""Test loading invalid JSON."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write("{invalid json}")
temp_path = f.name
try:
with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
StoryboardValidator.load(temp_path)
finally:
os.unlink(temp_path)
def test_load_wrong_extension(self):
"""Test loading a file with wrong extension."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write('{}')
temp_path = f.name
try:
with pytest.raises(StoryboardLoadError, match="must be JSON"):
StoryboardValidator.load(temp_path)
finally:
os.unlink(temp_path)
def test_validate_references(self):
"""Test reference validation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Hero")
],
locations=[
Location(id="L01", name="Street")
],
shots=[
Shot(
id="S01",
duration_s=5,
prompt="Test",
characters=["C01", "C99"], # C99 doesn't exist
location_id="L99" # Doesn't exist
)
]
)
issues = StoryboardValidator.validate_references(storyboard)
assert len(issues) == 2
assert any("C99" in issue for issue in issues)
assert any("L99" in issue for issue in issues)