"""Tests for storyboard schema validation and storyboard JSON loading."""

import json
import os
import tempfile
from pathlib import Path

import pytest

from src.storyboard.schema import (
    Storyboard, ProjectSettings, Shot, Character, Location,
    GlobalStyle, CameraSettings, GenerationSettings, OutputSettings, Resolution
)
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError


def _make_project(**overrides):
    """Build the ProjectSettings shared by most tests in this module.

    Defaults are title "Test", a 1920x1080 resolution and a 10 second target
    duration. Any keyword argument overrides (or adds to) those defaults,
    e.g. ``_make_project(title="Test Video")`` or ``_make_project(fps=24)``.
    """
    params = {
        "title": "Test",
        "resolution": Resolution(width=1920, height=1080),
        "target_duration_s": 10,
    }
    params.update(overrides)
    return ProjectSettings(**params)


def _write_file(directory: Path, name: str, content: str) -> str:
    """Write *content* to ``directory / name`` and return that path as a str.

    Tests pass the path to ``StoryboardValidator.load`` as a string, matching
    how the loader was originally exercised (with ``NamedTemporaryFile.name``).
    """
    path = directory / name
    path.write_text(content, encoding="utf-8")
    return str(path)


class TestStoryboardSchema:
    """Test storyboard schema validation."""

    def test_create_minimal_storyboard(self):
        """Test creating a minimal valid storyboard."""
        storyboard = Storyboard(
            project=_make_project(title="Test Video"),
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="A test shot"
                )
            ]
        )

        assert storyboard.project.title == "Test Video"
        assert len(storyboard.shots) == 1
        assert storyboard.shots[0].id == "S01"

    def test_shot_validation_unique_ids(self):
        """Test that duplicate shot IDs are rejected."""
        with pytest.raises(ValueError, match="Shot IDs must be unique"):
            Storyboard(
                project=_make_project(),
                shots=[
                    Shot(id="S01", duration_s=5, prompt="Shot 1"),
                    Shot(id="S01", duration_s=5, prompt="Shot 2")
                ]
            )

    def test_character_validation_unique_ids(self):
        """Test that duplicate character IDs are rejected."""
        with pytest.raises(ValueError, match="Character IDs must be unique"):
            Storyboard(
                project=_make_project(),
                characters=[
                    Character(id="C01", name="Character 1"),
                    Character(id="C01", name="Character 2")
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")]
            )

    def test_location_validation_unique_ids(self):
        """Test that duplicate location IDs are rejected."""
        with pytest.raises(ValueError, match="Location IDs must be unique"):
            Storyboard(
                project=_make_project(),
                locations=[
                    Location(id="L01", name="Location 1"),
                    Location(id="L01", name="Location 2")
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")]
            )

    def test_get_character_by_id(self):
        """Test retrieving character by ID."""
        storyboard = Storyboard(
            project=_make_project(),
            characters=[
                Character(id="C01", name="Hero", description="The protagonist")
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")]
        )

        char = storyboard.get_character("C01")
        assert char is not None
        assert char.name == "Hero"

        # Unknown IDs are reported as None rather than raising.
        missing = storyboard.get_character("C99")
        assert missing is None

    def test_get_location_by_id(self):
        """Test retrieving location by ID."""
        storyboard = Storyboard(
            project=_make_project(),
            locations=[
                Location(id="L01", name="Street", description="A city street")
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")]
        )

        loc = storyboard.get_location("L01")
        assert loc is not None
        assert loc.name == "Street"

        # Unknown IDs are reported as None rather than raising.
        missing = storyboard.get_location("L99")
        assert missing is None

    def test_total_duration_calculation(self):
        """Test total duration calculation."""
        storyboard = Storyboard(
            project=_make_project(),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2")
            ]
        )

        assert storyboard.get_total_duration() == 10

    def test_total_frames_calculation(self):
        """Test total frames calculation."""
        storyboard = Storyboard(
            project=_make_project(fps=24),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2")
            ]
        )

        assert storyboard.get_total_frames() == 240  # 10 seconds * 24 fps


class TestStoryboardLoader:
    """Test storyboard loading functionality.

    Files are written under pytest's per-test ``tmp_path`` directory, which
    pytest creates and cleans up itself, so no manual unlink is needed and
    nothing leaks if writing the fixture file fails.
    """

    def test_load_valid_storyboard(self, tmp_path):
        """Test loading a valid storyboard JSON file."""
        data = {
            "schema_version": "1.0",
            "project": {
                "title": "Test Video",
                "fps": 24,
                "target_duration_s": 10,
                "resolution": {"width": 1920, "height": 1080},
                "aspect_ratio": "16:9",
                "global_style": {
                    "visual_style": "cinematic",
                    "negative_prompt": "blurry"
                },
                "audio": {"add_music": False}
            },
            "characters": [],
            "locations": [],
            "shots": [
                {
                    "id": "S01",
                    "duration_s": 5,
                    "prompt": "A test shot",
                    "camera": {"framing": "wide"},
                    "generation": {"seed": 12345, "steps": 30}
                }
            ],
            "output": {"container": "mp4", "codec": "h264"}
        }
        path = _write_file(tmp_path, "storyboard.json", json.dumps(data))

        storyboard = StoryboardValidator.load(path)

        assert storyboard.project.title == "Test Video"
        assert len(storyboard.shots) == 1
        assert storyboard.shots[0].id == "S01"

    def test_load_nonexistent_file(self, tmp_path):
        """Test loading a file that doesn't exist."""
        # A path inside a fresh temp dir is guaranteed to be absent, unlike a
        # hard-coded absolute path whose existence depends on the host.
        missing = tmp_path / "nonexistent" / "storyboard.json"
        with pytest.raises(StoryboardLoadError, match="not found"):
            StoryboardValidator.load(str(missing))

    def test_load_invalid_json(self, tmp_path):
        """Test loading invalid JSON."""
        path = _write_file(tmp_path, "storyboard.json", "{invalid json}")

        with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
            StoryboardValidator.load(path)

    def test_load_wrong_extension(self, tmp_path):
        """Test loading a file with wrong extension."""
        path = _write_file(tmp_path, "storyboard.txt", "{}")

        with pytest.raises(StoryboardLoadError, match="must be JSON"):
            StoryboardValidator.load(path)

    def test_validate_references(self):
        """Test reference validation."""
        storyboard = Storyboard(
            project=_make_project(),
            characters=[
                Character(id="C01", name="Hero")
            ],
            locations=[
                Location(id="L01", name="Street")
            ],
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="Test",
                    characters=["C01", "C99"],  # C99 doesn't exist
                    location_id="L99"  # Doesn't exist
                )
            ]
        )

        issues = StoryboardValidator.validate_references(storyboard)

        # Exactly one issue per dangling reference; valid C01 is not reported.
        assert len(issues) == 2
        assert any("C99" in issue for issue in issues)
        assert any("L99" in issue for issue in issues)