269 lines
9.1 KiB
Python
269 lines
9.1 KiB
Python
"""
|
|
Tests for storyboard schema validation.
|
|
"""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
import json
|
|
import tempfile
|
|
import os
|
|
|
|
from src.storyboard.schema import (
|
|
Storyboard, ProjectSettings, Shot, Character, Location,
|
|
GlobalStyle, CameraSettings, GenerationSettings, OutputSettings,
|
|
Resolution
|
|
)
|
|
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError
|
|
|
|
|
|
class TestStoryboardSchema:
    """Test storyboard schema validation."""

    @staticmethod
    def _project(**overrides) -> ProjectSettings:
        """Build the ProjectSettings shared by the tests in this class.

        Defaults to title "Test", a 1920x1080 resolution and a 10 second
        target duration. Any keyword argument replaces (or adds) the field
        of the same name, e.g. ``_project(fps=24)``.
        """
        fields = {
            "title": "Test",
            # Created per call so separate storyboards never share one instance.
            "resolution": Resolution(width=1920, height=1080),
            "target_duration_s": 10,
        }
        fields.update(overrides)
        return ProjectSettings(**fields)

    def test_create_minimal_storyboard(self):
        """Test creating a minimal valid storyboard."""
        storyboard = Storyboard(
            project=self._project(title="Test Video"),
            shots=[Shot(id="S01", duration_s=5, prompt="A test shot")],
        )

        assert storyboard.project.title == "Test Video"
        assert len(storyboard.shots) == 1
        assert storyboard.shots[0].id == "S01"

    def test_shot_validation_unique_ids(self):
        """Test that duplicate shot IDs are rejected."""
        with pytest.raises(ValueError, match="Shot IDs must be unique"):
            Storyboard(
                project=self._project(),
                shots=[
                    Shot(id="S01", duration_s=5, prompt="Shot 1"),
                    Shot(id="S01", duration_s=5, prompt="Shot 2"),
                ],
            )

    def test_character_validation_unique_ids(self):
        """Test that duplicate character IDs are rejected."""
        with pytest.raises(ValueError, match="Character IDs must be unique"):
            Storyboard(
                project=self._project(),
                characters=[
                    Character(id="C01", name="Character 1"),
                    Character(id="C01", name="Character 2"),
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")],
            )

    def test_location_validation_unique_ids(self):
        """Test that duplicate location IDs are rejected."""
        with pytest.raises(ValueError, match="Location IDs must be unique"):
            Storyboard(
                project=self._project(),
                locations=[
                    Location(id="L01", name="Location 1"),
                    Location(id="L01", name="Location 2"),
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")],
            )

    def test_get_character_by_id(self):
        """Test retrieving character by ID (and None for an unknown ID)."""
        storyboard = Storyboard(
            project=self._project(),
            characters=[
                Character(id="C01", name="Hero", description="The protagonist"),
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")],
        )

        char = storyboard.get_character("C01")
        assert char is not None
        assert char.name == "Hero"

        missing = storyboard.get_character("C99")
        assert missing is None

    def test_get_location_by_id(self):
        """Test retrieving location by ID (and None for an unknown ID)."""
        storyboard = Storyboard(
            project=self._project(),
            locations=[
                Location(id="L01", name="Street", description="A city street"),
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")],
        )

        loc = storyboard.get_location("L01")
        assert loc is not None
        assert loc.name == "Street"

        missing = storyboard.get_location("L99")
        assert missing is None

    def test_total_duration_calculation(self):
        """Test total duration calculation."""
        storyboard = Storyboard(
            project=self._project(),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2"),
            ],
        )

        # 4 s + 6 s of shots.
        assert storyboard.get_total_duration() == 10

    def test_total_frames_calculation(self):
        """Test total frames calculation."""
        storyboard = Storyboard(
            project=self._project(fps=24),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2"),
            ],
        )

        assert storyboard.get_total_frames() == 240  # 10 seconds * 24 fps
|
|
|
|
|
|
class TestStoryboardLoader:
    """Test storyboard loading functionality."""

    @staticmethod
    def _write_temp_file(content: str, suffix: str) -> str:
        """Write *content* to a new temporary file and return its path.

        The file is created with ``tempfile.mkstemp`` (so it is not deleted
        automatically) and written as UTF-8. The caller owns the file and must
        ``os.unlink`` it. If writing fails, the file is removed before the
        error propagates, so a failed setup never leaks a temp file.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                handle.write(content)
        except BaseException:
            os.unlink(path)
            raise
        return path

    def test_load_valid_storyboard(self):
        """Test loading a valid storyboard JSON file."""
        data = {
            "schema_version": "1.0",
            "project": {
                "title": "Test Video",
                "fps": 24,
                "target_duration_s": 10,
                "resolution": {"width": 1920, "height": 1080},
                "aspect_ratio": "16:9",
                "global_style": {
                    "visual_style": "cinematic",
                    "negative_prompt": "blurry",
                },
                "audio": {"add_music": False},
            },
            "characters": [],
            "locations": [],
            "shots": [
                {
                    "id": "S01",
                    "duration_s": 5,
                    "prompt": "A test shot",
                    "camera": {"framing": "wide"},
                    "generation": {"seed": 12345, "steps": 30},
                }
            ],
            "output": {"container": "mp4", "codec": "h264"},
        }

        temp_path = self._write_temp_file(json.dumps(data), ".json")
        try:
            storyboard = StoryboardValidator.load(temp_path)
            assert storyboard.project.title == "Test Video"
            assert len(storyboard.shots) == 1
            assert storyboard.shots[0].id == "S01"
        finally:
            os.unlink(temp_path)

    def test_load_nonexistent_file(self):
        """Test loading a file that doesn't exist."""
        with pytest.raises(StoryboardLoadError, match="not found"):
            StoryboardValidator.load("/nonexistent/path/storyboard.json")

    def test_load_invalid_json(self):
        """Test loading invalid JSON."""
        temp_path = self._write_temp_file("{invalid json}", ".json")
        try:
            with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
                StoryboardValidator.load(temp_path)
        finally:
            os.unlink(temp_path)

    def test_load_wrong_extension(self):
        """Test loading a file with wrong extension."""
        # Valid JSON content, so only the ".txt" suffix can trigger the error.
        temp_path = self._write_temp_file("{}", ".txt")
        try:
            with pytest.raises(StoryboardLoadError, match="must be JSON"):
                StoryboardValidator.load(temp_path)
        finally:
            os.unlink(temp_path)

    def test_validate_references(self):
        """Test that dangling character/location references are reported."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10,
            ),
            characters=[Character(id="C01", name="Hero")],
            locations=[Location(id="L01", name="Street")],
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="Test",
                    characters=["C01", "C99"],  # C99 doesn't exist
                    location_id="L99",  # Doesn't exist
                )
            ],
        )

        issues = StoryboardValidator.validate_references(storyboard)
        assert len(issues) == 2
        assert any("C99" in issue for issue in issues)
        assert any("L99" in issue for issue in issues)
|