commit 46b10fb69b1799c08964b8a76958f8ff60fddaa4 Author: Santhosh Janardhanan Date: Tue Feb 3 23:06:28 2026 -0500 Initial Commit diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..19658de --- /dev/null +++ b/.env.example @@ -0,0 +1,23 @@ +# Environment configuration for storyboard-video +# Copy this file to .env and customize as needed + +# Active backend selection (must match a key in config/models.yaml) +ACTIVE_BACKEND=wan_t2v_14b + +# Model cache directory (where downloaded models are stored) +MODEL_CACHE_DIR=~/.cache/storyboard-video/models + +# HuggingFace Hub cache (if using HF models) +HF_HOME=~/.cache/huggingface + +# CUDA settings +# CUDA_VISIBLE_DEVICES=0 + +# Logging level (DEBUG, INFO, WARNING, ERROR) +LOG_LEVEL=INFO + +# Output directory base path +OUTPUT_BASE_DIR=./outputs + +# FFmpeg path (leave empty to use system PATH) +FFMPEG_PATH= diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a490955 --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +venv/ +env/ +ENV/ +env.bak/ +venv.bak/ +.venv + +# Conda +conda-env/ +*.conda + +# IDEs +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Jupyter Notebook +.ipynb_checkpoints + +# Pytest +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# Environment variables +.env +.env.local +.env.*.local + +# Model caches and outputs +models/ +*.ckpt +*.safetensors +*.bin +*.pth +*.onnx +outputs/ +checkpoints.db + +# Temporary files +tmp/ +temp/ +*.tmp +*.temp + +# Logs +*.log +logs/ + +# Video files (generated) +*.mp4 +*.mov +*.avi +*.mkv +*.webm + +# Image files (generated) +*.png +*.jpg +*.jpeg +*.gif +*.bmp +*.tiff + +# Audio files (generated) +*.wav +*.mp3 +*.aac +*.flac +*.ogg + +# HuggingFace cache +.cache/ 
+transformers_cache/ +diffusers_cache/ + +# Windows +Thumbs.db +ehthumbs.db +Desktop.ini +$RECYCLE.BIN/ + +# macOS +.AppleDouble +.LSOverride +Icon +._* +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Project specific +# Keep templates but ignore generated outputs +!/templates/*.json +!/templates/*.yaml +!/templates/*.yml + +# Allow example files +!.env.example +!config/*.yaml +!config/*.yml + +# Keep docs +docs/*.md diff --git a/AGENTS.MD b/AGENTS.MD new file mode 100644 index 0000000..e33e6d8 --- /dev/null +++ b/AGENTS.MD @@ -0,0 +1,210 @@ +# agents.md +## Project: Local AI Video Generation from Text Storyboards (Windows + RTX 5070 12GB) + +### 0) Who this is for +The owner (user) is not an ML expert. The system must: +- be reproducible (conda + requirements) +- have guardrails (configs, logs, validation) +- be test-driven (pytest) +- maintain docs (developer + user) + +--- + +## 1) High-Level Goal +Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by: +1) converting storyboard -> shot plan +2) generating shot clips (T2V or I2V when possible) +3) assembling clips into a final MP4 +4) upscaling to 2K/4K if desired + +This is a **shot-based** system, not “one prompt makes a whole movie”. + +--- + +## 2) Hard Constraints (Hardware & OS) +Target system: +- Windows 11 +- NVIDIA RTX 5070 (12GB VRAM) - Must use GPU. +- 32GB RAM +- 2TB SSD +- Anaconda available + +Design must be stable under 12GB VRAM using: +- fp16/bf16 +- attention slicing +- xFormers / SDPA where supported +- optional CPU offload + +--- + +## 3) Output Targets (Realistic) +- Native generation: 720p–1080p (preferred) +- Final delivery: 1080p required; 2K/4K via upscaling +- Duration: 15–30s per video (may be segmented) +- FPS: 24 default +- Output: MP4 (H.264/H.265) + +--- + +## 4) CUDA 13.1 Reality & PyTorch Plan (Critical) +User has CUDA Toolkit 13.1 installed. 
Current PyTorch builds generally ship with and target CUDA 12.x runtimes. +We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit. + +**Plan:** +- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124). +- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version. +- Avoid compiling custom CUDA extensions unless necessary. + +Implementation notes: +- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds. +- If xFormers causes build issues, use PyTorch SDPA and disable xFormers. + +--- + +## 5) Approved Stack (Do Not Deviate) +### Core +- Python 3.10 or 3.11 (conda env) +- PyTorch (CUDA 12.x build: cu121 or cu124) +- diffusers + transformers + accelerate + safetensors +- ffmpeg for assembly +- opencv-python for frame IO (if needed) +- pydantic for config/schema validation +- rich / loguru for logs + +### Testing +- pytest +- pytest-cov +- snapshot-ish tests where feasible (metadata + shapes, not visual perfection) + +### Docs +- /docs/developer.md (developer documentation) +- /docs/user.md (user manual) +- Keep docs updated alongside code changes. + +--- + +## 6) Video Models (Pragmatic Choices) +### Primary (target) +- WAN 2.x family (T2V; optional I2V if supported in chosen pipeline) +Goal: best possible quality on consumer VRAM with chunking. + +### Secondary / fallback +- Stable Video Diffusion (SVD) if WAN is unstable +- LTX-Video (only if it fits and is stable in our stack) + +All model backends must implement the same interface: +- generate_shot(shot_spec) -> video_file + metadata + +--- + +## 7) Canonical Input: Storyboard JSON +Storyboard source is text-only (often AI-generated). We will store and validate it as JSON. 
+ +A template exists at: `templates/storyboard.template.json` + +We will later build a utility script: +- input: plain text fields or a simple text format +- output: valid storyboard JSON + +--- + +## 8) Pipeline Modules (Required) +### A) Storyboard parsing & validation +- Load storyboard JSON +- Validate schema +- Expand defaults (fps, resolution, global style) +- Produce normalized shot list + +### B) Prompt compilation +- Merge global style + shot prompt + camera notes +- Produce positive + negative prompts +- Keep deterministic via seeds + +### C) Generation runner (per shot) +- For each shot: generate clip +- Support: + - seed control + - chunking (e.g., generate 4–6 seconds then continue) + - optional init frame handoff between shots + +### D) Assembly +- Use ffmpeg concat to build final video +- Optionally add: + - transitions + - temp audio + - burn-in shot IDs for debugging mode + +### E) Upscaling (optional) +- Upscale final to 2K/4K (post step) +- Keep this modular so user can skip. + +--- + +## 9) Determinism & Logging (Must Have) +For each shot and final render, save: +- prompts (positive/negative) +- seed(s) +- model + revision/hash info if available +- inference params (steps, cfg, sampler, resolution, fps, frames) +- timing + VRAM notes if possible + +Every run produces a folder: +- outputs/<project_id>/<run_id>/ + - shots/ + - assembled/ + - metadata/ + +--- + +## 10) Testing Rules (Hard Requirement) +- Tests must be written alongside features. +- Whenever a file/function is modified, corresponding tests MUST be updated. +- Prefer tests that verify: + - schema validation works + - prompt compiler output is stable + - shot planner expands durations -> frame counts + - assembly command lines are correct + - metadata is generated correctly + +Do not require “visual quality” assertions. Test structure and determinism. 
+ +--- + +## 11) Documentation Rules (Hard Requirement) +Maintain these continuously: +- docs/developer.md + - architecture + - install steps + - how to run tests + - how to add a new model backend +- docs/user.md + - quickstart + - how to create storyboard JSON + - how to run generation + - where outputs go + - troubleshooting (VRAM, drivers, ffmpeg) + +Docs must be updated whenever CLI flags, file formats, or workflows change. + +--- + +## 12) Project Files to Maintain +Required: +- requirements.txt (pip deps) +- environment.yml (conda env) +- templates/storyboard.template.json +- docs/developer.md +- docs/user.md +- src/ (implementation) +- tests/ (pytest) + +--- + +## 13) Definition of Done +A feature is “done” only if: +- implemented +- tests added/updated +- docs updated +- reproducible install instructions remain valid + +End of file. diff --git a/TODO.MD b/TODO.MD new file mode 100644 index 0000000..4f80170 --- /dev/null +++ b/TODO.MD @@ -0,0 +1,35 @@ +# TODO + +## Repo bootstrap +- [x] Define project direction and constraints in `agents.md` +- [x] Add `requirements.txt` (pip dependencies) +- [x] Add `environment.yml` (conda environment; PyTorch CUDA 12.x runtime strategy) +- [x] Add storyboard JSON template at `templates/storyboard.template.json` + +## Core implementation (next) +- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/` +- [ ] Implement storyboard schema validator (pydantic) + loader +- [ ] Implement prompt compiler (global style + shot + camera) +- [ ] Implement shot planning (duration -> frame count, chunk plan) +- [ ] Implement model backend interface (`BaseVideoBackend`) +- [ ] Implement WAN backend (primary) with VRAM-safe defaults +- [ ] Implement fallback backend (SVD) for reliability testing +- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in) +- [ ] Implement optional upscaling module (post-process) + +## Utilities +- [ ] Write storyboard “plain text → JSON” utility script (fills 
`storyboard.template.json`) +- [ ] Add config file support (YAML/JSON) for global defaults + +## Testing (parallel work; required) +- [ ] Add `pytest` scaffolding +- [ ] Add tests for schema validation +- [ ] Add tests for prompt compilation determinism +- [ ] Add tests for shot planning (frames/chunks) +- [ ] Add tests for ffmpeg command generation (no actual render needed) +- [ ] Ensure every code change includes a corresponding test update + +## Documentation (maintained continuously) +- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends) +- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting) +- [ ] Keep docs updated whenever CLI/config/schema changes diff --git a/config/models.yaml b/config/models.yaml new file mode 100644 index 0000000..698d0f0 --- /dev/null +++ b/config/models.yaml @@ -0,0 +1,61 @@ +# Model configurations for video generation backends +# Backend selection is controlled via ACTIVE_BACKEND environment variable +# or --backend CLI flag + +backends: + wan_t2v_14b: + name: "Wan 2.1 T2V 14B" + class: "generation.backends.wan.WanBackend" + model_id: "Wan-AI/Wan2.1-T2V-14B" + vram_gb: 12 + dtype: "fp16" + enable_vae_slicing: true + enable_vae_tiling: true + chunking: + enabled: true + mode: "sequential" # Options: sequential, overlapping + max_chunk_seconds: 4 + overlap_seconds: 1 # Only used when mode is overlapping + + wan_t2v_1_3b: + name: "Wan 2.1 T2V 1.3B" + class: "generation.backends.wan.WanBackend" + model_id: "Wan-AI/Wan2.1-T2V-1.3B" + vram_gb: 8 + dtype: "fp16" + enable_vae_slicing: true + enable_vae_tiling: false + chunking: + enabled: true + mode: "sequential" + max_chunk_seconds: 6 + overlap_seconds: 1 + + svd: + name: "Stable Video Diffusion" + class: "generation.backends.svd.SVDBackend" + model_id: "stabilityai/stable-video-diffusion-img2vid-xt" + vram_gb: 10 + dtype: "fp16" + enable_vae_slicing: true + enable_vae_tiling: true + chunking: + enabled: false + +# 
Default backend selection +# Override with ACTIVE_BACKEND environment variable +active_backend: "wan_t2v_14b" + +# Global generation defaults +defaults: + fps: 24 + resolution: + width: 1920 + height: 1080 + +# Checkpoint database settings +checkpoint_db: "outputs/checkpoints.db" + +# Model cache directory +# Override with MODEL_CACHE_DIR environment variable +model_cache_dir: "~/.cache/storyboard-video/models" diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..d0a2907 --- /dev/null +++ b/environment.yml @@ -0,0 +1,24 @@ +# environment.yml +name: storyboard-video +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.11 + - pip=24.3.1 + + # FFmpeg in env is easiest on Windows if you rely on conda-forge + - ffmpeg=7.1 + + # PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit) + # Using compatible versions for PyTorch 2.5.1 with CUDA 12.4 + - pytorch=2.5.1 + - pytorch-cuda=12.4 + - torchvision=0.20.1 + - torchaudio=2.5.1 + + # pip deps (mirrors requirements.txt) + - pip: + - -r requirements.txt diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..ad2a83a --- /dev/null +++ b/pytest.ini @@ -0,0 +1,7 @@ +# pytest configuration +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = -v --tb=short diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b97ffe0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +# requirements.txt +# Install PyTorch separately (see environment.yml and docs). 
+ +diffusers==0.32.2 +transformers==4.48.3 +accelerate==1.3.0 +safetensors==0.5.2 + +pydantic==2.10.6 +pyyaml==6.0.2 +numpy==2.1.3 +pillow==11.1.0 +opencv-python==4.11.0.86 + +rich==13.9.4 +loguru==0.7.3 + +tqdm==4.67.1 +requests==2.32.3 + +# CLI & tooling +typer==0.15.1 + +# Testing +pytest==8.3.4 +pytest-cov==6.0.0 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..3b79e6e --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,5 @@ +""" +Storyboard video generation pipeline. +""" + +__version__ = "0.1.0" diff --git a/src/assembly/__init__.py b/src/assembly/__init__.py new file mode 100644 index 0000000..5fabb51 --- /dev/null +++ b/src/assembly/__init__.py @@ -0,0 +1,3 @@ +""" +Video assembly and post-processing. +""" diff --git a/src/cli/__init__.py b/src/cli/__init__.py new file mode 100644 index 0000000..d84e788 --- /dev/null +++ b/src/cli/__init__.py @@ -0,0 +1,3 @@ +""" +CLI entry points. +""" diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..d23f7e9 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,25 @@ +""" +Core utilities and shared components. +""" + +from .config import Config, BackendConfig, ConfigLoader, get_config, reload_config +from .checkpoint import ( + CheckpointManager, + ShotCheckpoint, + ProjectCheckpoint, + ShotStatus, + ProjectStatus +) + +__all__ = [ + 'Config', + 'BackendConfig', + 'ConfigLoader', + 'get_config', + 'reload_config', + 'CheckpointManager', + 'ShotCheckpoint', + 'ProjectCheckpoint', + 'ShotStatus', + 'ProjectStatus' +] diff --git a/src/core/checkpoint.py b/src/core/checkpoint.py new file mode 100644 index 0000000..5f858c6 --- /dev/null +++ b/src/core/checkpoint.py @@ -0,0 +1,435 @@ +""" +Checkpoint and resume system using SQLite. +Tracks generation progress and enables resuming from failures. 
+""" + +import sqlite3 +import json +from pathlib import Path +from datetime import datetime +from typing import Optional, List, Dict, Any +from dataclasses import dataclass +from enum import Enum + + +class ShotStatus(str, Enum): + """Status of a shot generation.""" + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + + +class ProjectStatus(str, Enum): + """Status of a project.""" + INITIALIZED = "initialized" + RUNNING = "running" + PAUSED = "paused" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class ShotCheckpoint: + """Checkpoint data for a single shot.""" + shot_id: str + status: ShotStatus + output_path: Optional[str] = None + error_message: Optional[str] = None + started_at: Optional[str] = None + completed_at: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None + + def __post_init__(self): + if self.metadata is None: + self.metadata = {} + + +@dataclass +class ProjectCheckpoint: + """Checkpoint data for a project.""" + project_id: str + storyboard_path: str + output_dir: str + status: ProjectStatus + backend_name: str + started_at: str + completed_at: Optional[str] = None + error_message: Optional[str] = None + + +class CheckpointManager: + """Manages checkpoints using SQLite database.""" + + def __init__(self, db_path: str = "outputs/checkpoints.db"): + """ + Initialize checkpoint manager. 
+ + Args: + db_path: Path to SQLite database file + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_db() + + def _init_db(self): + """Initialize database schema.""" + conn = sqlite3.connect(self.db_path) + try: + conn.execute(""" + CREATE TABLE IF NOT EXISTS projects ( + project_id TEXT PRIMARY KEY, + storyboard_path TEXT NOT NULL, + output_dir TEXT NOT NULL, + status TEXT NOT NULL, + backend_name TEXT NOT NULL, + started_at TEXT NOT NULL, + completed_at TEXT, + error_message TEXT + ) + """) + + conn.execute(""" + CREATE TABLE IF NOT EXISTS shots ( + project_id TEXT NOT NULL, + shot_id TEXT NOT NULL, + status TEXT NOT NULL, + output_path TEXT, + error_message TEXT, + started_at TEXT, + completed_at TEXT, + metadata TEXT, + PRIMARY KEY (project_id, shot_id), + FOREIGN KEY (project_id) REFERENCES projects(project_id) + ) + """) + + conn.commit() + finally: + conn.close() + + def _get_connection(self): + """Get a new database connection.""" + return sqlite3.connect(self.db_path) + + def create_project( + self, + project_id: str, + storyboard_path: str, + output_dir: str, + backend_name: str + ) -> ProjectCheckpoint: + """ + Create a new project checkpoint. + + Args: + project_id: Unique project identifier + storyboard_path: Path to storyboard JSON file + output_dir: Output directory for generated files + backend_name: Name of the generation backend + + Returns: + ProjectCheckpoint object + """ + checkpoint = ProjectCheckpoint( + project_id=project_id, + storyboard_path=storyboard_path, + output_dir=output_dir, + status=ProjectStatus.INITIALIZED, + backend_name=backend_name, + started_at=datetime.now().isoformat() + ) + + conn = self._get_connection() + try: + conn.execute( + """ + INSERT INTO projects + (project_id, storyboard_path, output_dir, status, backend_name, started_at) + VALUES (?, ?, ?, ?, ?, ?) 
+ """, + ( + checkpoint.project_id, + checkpoint.storyboard_path, + checkpoint.output_dir, + checkpoint.status.value, + checkpoint.backend_name, + checkpoint.started_at + ) + ) + conn.commit() + finally: + conn.close() + + return checkpoint + + def get_project(self, project_id: str) -> Optional[ProjectCheckpoint]: + """Get project checkpoint by ID.""" + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM projects WHERE project_id = ?", + (project_id,) + ) + row = cursor.fetchone() + + if row is None: + return None + + return ProjectCheckpoint( + project_id=row[0], + storyboard_path=row[1], + output_dir=row[2], + status=ProjectStatus(row[3]), + backend_name=row[4], + started_at=row[5], + completed_at=row[6], + error_message=row[7] + ) + finally: + conn.close() + + def update_project_status( + self, + project_id: str, + status: ProjectStatus, + error_message: Optional[str] = None + ): + """Update project status.""" + completed_at = None + if status in [ProjectStatus.COMPLETED, ProjectStatus.FAILED]: + completed_at = datetime.now().isoformat() + + conn = self._get_connection() + try: + conn.execute( + """ + UPDATE projects + SET status = ?, error_message = ?, completed_at = ? + WHERE project_id = ? + """, + (status.value, error_message, completed_at, project_id) + ) + conn.commit() + finally: + conn.close() + + def create_shot( + self, + project_id: str, + shot_id: str, + status: ShotStatus = ShotStatus.PENDING + ) -> ShotCheckpoint: + """Create a shot checkpoint.""" + checkpoint = ShotCheckpoint( + shot_id=shot_id, + status=status + ) + + conn = self._get_connection() + try: + conn.execute( + """ + INSERT INTO shots + (project_id, shot_id, status, metadata) + VALUES (?, ?, ?, ?) 
+ """, + ( + project_id, + shot_id, + status.value, + json.dumps(checkpoint.metadata) + ) + ) + conn.commit() + finally: + conn.close() + + return checkpoint + + def get_shot(self, project_id: str, shot_id: str) -> Optional[ShotCheckpoint]: + """Get shot checkpoint.""" + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM shots WHERE project_id = ? AND shot_id = ?", + (project_id, shot_id) + ) + row = cursor.fetchone() + + if row is None: + return None + + return ShotCheckpoint( + shot_id=row[1], + status=ShotStatus(row[2]), + output_path=row[3], + error_message=row[4], + started_at=row[5], + completed_at=row[6], + metadata=json.loads(row[7]) if row[7] else {} + ) + finally: + conn.close() + + def get_project_shots(self, project_id: str) -> List[ShotCheckpoint]: + """Get all shots for a project.""" + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM shots WHERE project_id = ? ORDER BY shot_id", + (project_id,) + ) + rows = cursor.fetchall() + + return [ + ShotCheckpoint( + shot_id=row[1], + status=ShotStatus(row[2]), + output_path=row[3], + error_message=row[4], + started_at=row[5], + completed_at=row[6], + metadata=json.loads(row[7]) if row[7] else {} + ) + for row in rows + ] + finally: + conn.close() + + def update_shot( + self, + project_id: str, + shot_id: str, + status: Optional[ShotStatus] = None, + output_path: Optional[str] = None, + error_message: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ): + """Update shot checkpoint.""" + updates = [] + params = [] + + if status is not None: + updates.append("status = ?") + params.append(status.value) + + if status == ShotStatus.IN_PROGRESS: + updates.append("started_at = ?") + params.append(datetime.now().isoformat()) + elif status in [ShotStatus.COMPLETED, ShotStatus.FAILED]: + updates.append("completed_at = ?") + params.append(datetime.now().isoformat()) + + if output_path is not None: + updates.append("output_path = ?") + 
params.append(output_path) + + if error_message is not None: + updates.append("error_message = ?") + params.append(error_message) + + if metadata is not None: + updates.append("metadata = ?") + params.append(json.dumps(metadata)) + + if not updates: + return + + params.extend([project_id, shot_id]) + + conn = self._get_connection() + try: + conn.execute( + f""" + UPDATE shots + SET {', '.join(updates)} + WHERE project_id = ? AND shot_id = ? + """, + params + ) + conn.commit() + finally: + conn.close() + + def get_pending_shots(self, project_id: str) -> List[str]: + """Get list of pending shot IDs for a project.""" + conn = self._get_connection() + try: + cursor = conn.execute( + """ + SELECT shot_id FROM shots + WHERE project_id = ? AND status = ? + ORDER BY shot_id + """, + (project_id, ShotStatus.PENDING.value) + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + def get_failed_shots(self, project_id: str) -> List[str]: + """Get list of failed shot IDs for a project.""" + conn = self._get_connection() + try: + cursor = conn.execute( + """ + SELECT shot_id FROM shots + WHERE project_id = ? AND status = ? 
+ ORDER BY shot_id + """, + (project_id, ShotStatus.FAILED.value) + ) + return [row[0] for row in cursor.fetchall()] + finally: + conn.close() + + def can_resume(self, project_id: str) -> bool: + """Check if a project can be resumed.""" + project = self.get_project(project_id) + if project is None: + return False + + if project.status == ProjectStatus.COMPLETED: + return False + + pending = self.get_pending_shots(project_id) + failed = self.get_failed_shots(project_id) + + return len(pending) > 0 or len(failed) > 0 + + def delete_project(self, project_id: str): + """Delete a project and all its shots.""" + conn = self._get_connection() + try: + conn.execute("DELETE FROM shots WHERE project_id = ?", (project_id,)) + conn.execute("DELETE FROM projects WHERE project_id = ?", (project_id,)) + conn.commit() + finally: + conn.close() + + def list_projects(self) -> List[ProjectCheckpoint]: + """List all projects.""" + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM projects ORDER BY started_at DESC" + ) + rows = cursor.fetchall() + + return [ + ProjectCheckpoint( + project_id=row[0], + storyboard_path=row[1], + output_dir=row[2], + status=ProjectStatus(row[3]), + backend_name=row[4], + started_at=row[5], + completed_at=row[6], + error_message=row[7] + ) + for row in rows + ] + finally: + conn.close() diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000..ee39309 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,212 @@ +""" +Configuration management system. +Handles YAML configs and environment variables. 
+""" + +import os +import yaml +from pathlib import Path +from typing import Any, Dict, Optional +from dataclasses import dataclass, field + + +@dataclass +class BackendConfig: + """Configuration for a generation backend.""" + name: str + model_class: str + model_id: str + vram_gb: int + dtype: str + enable_vae_slicing: bool = True + enable_vae_tiling: bool = False + chunking_enabled: bool = True + chunking_mode: str = "sequential" + max_chunk_seconds: int = 4 + overlap_seconds: int = 1 + + +@dataclass +class Config: + """Application configuration.""" + active_backend: str = "wan_t2v_14b" + backends: Dict[str, BackendConfig] = field(default_factory=dict) + defaults: Dict[str, Any] = field(default_factory=dict) + checkpoint_db: str = "outputs/checkpoints.db" + model_cache_dir: str = "~/.cache/storyboard-video/models" + log_level: str = "INFO" + output_base_dir: str = "./outputs" + + def get_backend(self, name: Optional[str] = None) -> BackendConfig: + """Get backend configuration by name.""" + name = name or self.active_backend + if name not in self.backends: + raise ValueError(f"Backend '{name}' not found in configuration") + return self.backends[name] + + +class ConfigLoader: + """Loads configuration from YAML files and environment variables.""" + + @staticmethod + def load(config_path: Optional[Path] = None, env_file: Optional[Path] = None) -> Config: + """ + Load configuration from files and environment. + + Priority (highest to lowest): + 1. Environment variables + 2. .env file + 3. YAML config file + 4. 
Defaults + + Args: + config_path: Path to YAML config file (default: config/models.yaml) + env_file: Path to .env file (default: .env if exists) + + Returns: + Config object + """ + # Start with defaults + config = Config() + + # Load YAML config + if config_path is None: + config_path = Path("config/models.yaml") + + if config_path.exists(): + config = ConfigLoader._load_yaml(config_path, config) + + # Load .env file + if env_file is None: + env_file = Path(".env") + + if env_file.exists(): + config = ConfigLoader._load_env_file(env_file, config) + + # Override with environment variables + config = ConfigLoader._load_env_vars(config) + + # Expand paths + config.model_cache_dir = str(Path(config.model_cache_dir).expanduser()) + config.checkpoint_db = str(Path(config.checkpoint_db).expanduser()) + config.output_base_dir = str(Path(config.output_base_dir).expanduser()) + + return config + + @staticmethod + def _load_yaml(path: Path, config: Config) -> Config: + """Load configuration from YAML file.""" + with open(path, 'r') as f: + data = yaml.safe_load(f) + + if not data: + return config + + # Load backends + if 'backends' in data: + for backend_id, backend_data in data['backends'].items(): + chunking = backend_data.get('chunking', {}) + config.backends[backend_id] = BackendConfig( + name=backend_data.get('name', backend_id), + model_class=backend_data.get('class', ''), + model_id=backend_data.get('model_id', ''), + vram_gb=backend_data.get('vram_gb', 12), + dtype=backend_data.get('dtype', 'fp16'), + enable_vae_slicing=backend_data.get('enable_vae_slicing', True), + enable_vae_tiling=backend_data.get('enable_vae_tiling', False), + chunking_enabled=chunking.get('enabled', True), + chunking_mode=chunking.get('mode', 'sequential'), + max_chunk_seconds=chunking.get('max_chunk_seconds', 4), + overlap_seconds=chunking.get('overlap_seconds', 1) + ) + + # Load active backend + if 'active_backend' in data: + config.active_backend = data['active_backend'] + + # Load defaults 
+ if 'defaults' in data: + config.defaults = data['defaults'] + + # Load checkpoint DB path + if 'checkpoint_db' in data: + config.checkpoint_db = data['checkpoint_db'] + + # Load model cache dir + if 'model_cache_dir' in data: + config.model_cache_dir = data['model_cache_dir'] + + return config + + @staticmethod + def _load_env_file(path: Path, config: Config) -> Config: + """Load configuration from .env file.""" + env_vars = {} + with open(path, 'r') as f: + for line in f: + line = line.strip() + if not line or line.startswith('#'): + continue + + if '=' in line: + key, value = line.split('=', 1) + key = key.strip() + value = value.strip().strip('"\'') + env_vars[key] = value + + # Apply env vars directly to config (don't set in os.environ) + if 'ACTIVE_BACKEND' in env_vars: + config.active_backend = env_vars['ACTIVE_BACKEND'] + + if 'MODEL_CACHE_DIR' in env_vars: + config.model_cache_dir = env_vars['MODEL_CACHE_DIR'] + + if 'LOG_LEVEL' in env_vars: + config.log_level = env_vars['LOG_LEVEL'] + + if 'OUTPUT_BASE_DIR' in env_vars: + config.output_base_dir = env_vars['OUTPUT_BASE_DIR'] + + if 'CHECKPOINT_DB' in env_vars: + config.checkpoint_db = env_vars['CHECKPOINT_DB'] + + return config + + @staticmethod + def _load_env_vars(config: Config) -> Config: + """Override config with environment variables.""" + if 'ACTIVE_BACKEND' in os.environ: + config.active_backend = os.environ['ACTIVE_BACKEND'] + + if 'MODEL_CACHE_DIR' in os.environ: + config.model_cache_dir = os.environ['MODEL_CACHE_DIR'] + + if 'LOG_LEVEL' in os.environ: + config.log_level = os.environ['LOG_LEVEL'] + + if 'OUTPUT_BASE_DIR' in os.environ: + config.output_base_dir = os.environ['OUTPUT_BASE_DIR'] + + if 'CHECKPOINT_DB' in os.environ: + config.checkpoint_db = os.environ['CHECKPOINT_DB'] + + return config + + +# Global config instance (lazy-loaded) +_config: Optional[Config] = None + + +def get_config() -> Config: + """Get the global configuration instance.""" + global _config + if _config is None: + 
"""
Base interface for video generation backends.
All backends (WAN, SVD, etc.) must implement this interface.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional


@dataclass
class GenerationResult:
    """Outcome of a single video-generation call.

    ``metadata`` is normalized in ``__post_init__`` so callers can always
    treat it as a mapping, even when the backend supplied none.
    """
    success: bool
    output_path: Optional[Path] = None
    metadata: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None
    vram_usage_gb: Optional[float] = None
    generation_time_s: Optional[float] = None

    def __post_init__(self):
        # A missing metadata mapping becomes an empty dict.
        self.metadata = {} if self.metadata is None else self.metadata


@dataclass
class GenerationSpec:
    """Input parameters describing one video clip to generate."""
    prompt: str
    negative_prompt: str
    width: int
    height: int
    num_frames: int
    fps: int
    seed: int
    steps: int
    cfg_scale: float
    output_path: Path
    init_frame_path: Optional[Path] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the spec to plain JSON-friendly types for metadata storage."""
        init_frame = str(self.init_frame_path) if self.init_frame_path else None
        return {
            "prompt": self.prompt,
            "negative_prompt": self.negative_prompt,
            "width": self.width,
            "height": self.height,
            "num_frames": self.num_frames,
            "fps": self.fps,
            "seed": self.seed,
            "steps": self.steps,
            "cfg_scale": self.cfg_scale,
            "output_path": str(self.output_path),
            "init_frame_path": init_frame,
        }
class BaseVideoBackend(ABC):
    """
    Abstract base class for video generation backends.

    All backends must implement this interface to be used in the pipeline.
    Subclasses provide model loading/unloading and single-clip generation;
    chunked generation has a safe fallback default here.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the backend with configuration.

        Args:
            config: Backend-specific configuration dictionary
        """
        self.config = config
        self.model = None
        self._is_loaded = False

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the backend name."""

    @property
    @abstractmethod
    def supports_chunking(self) -> bool:
        """Whether this backend supports chunked generation."""

    @property
    @abstractmethod
    def supports_init_frame(self) -> bool:
        """Whether this backend supports init frame conditioning."""

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        Must be called before generate(). Implementations should set
        self._is_loaded = True when complete.
        """

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Implementations should set self._is_loaded = False when complete.
        """

    @abstractmethod
    def generate(self, spec: "GenerationSpec") -> "GenerationResult":
        """
        Generate a video clip.

        Args:
            spec: Generation specification

        Returns:
            GenerationResult with output path and metadata
        """

    def generate_chunked(
        self,
        spec: "GenerationSpec",
        chunk_duration_s: int,
        overlap_s: int = 0,
        mode: str = "sequential"
    ) -> "GenerationResult":
        """
        Generate a video in chunks.

        Default implementation for backends that don't natively support
        chunking: falls back to a single generate() call. Backends with
        native support should override this.

        Args:
            spec: Generation specification
            chunk_duration_s: Duration of each chunk in seconds
            overlap_s: Overlap between chunks in seconds (for blending)
            mode: "sequential" or "overlapping"

        Returns:
            GenerationResult with concatenated video

        Raises:
            NotImplementedError: if supports_chunking is True but the
                subclass did not override this method.
        """
        if not self.supports_chunking:
            # Fall back to a single un-chunked generation.
            return self.generate(spec)

        raise NotImplementedError(
            "Chunked generation not implemented for this backend"
        )

    def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
        """
        Estimate VRAM usage for a generation.

        Args:
            width: Video width
            height: Video height
            num_frames: Number of frames

        Returns:
            Estimated VRAM usage in GB (defaults to the configured
            "vram_gb" value; backends should override with real estimates).
        """
        return self.config.get("vram_gb", 12.0)

    def check_vram_available(self, width: int, height: int, num_frames: int) -> bool:
        """
        Check if sufficient VRAM is available for a generation.

        Best-effort: if torch or CUDA is unavailable, or the probe fails,
        we optimistically assume the generation will fit.

        Args:
            width: Video width
            height: Video height
            num_frames: Number of frames

        Returns:
            True if the generation should fit in VRAM
        """
        try:
            import torch  # local import: torch is an optional runtime dependency
            if torch.cuda.is_available():
                total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                estimated = self.estimate_vram_usage(width, height, num_frames)
                # Leave 10% headroom for fragmentation/activations.
                return estimated < total_vram * 0.9
        except Exception:
            # FIX: was a bare `except:` which also swallowed SystemExit and
            # KeyboardInterrupt; the best-effort check still degrades to "OK".
            pass

        # If we can't check, assume it's OK.
        return True

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded

    def __enter__(self):
        """Context manager entry: load the model and return self."""
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: always unload, even on error."""
        self.unload()


class BackendFactory:
    """Registry-based factory for backend implementations.

    Backends register themselves under a string key and are later
    instantiated by name from configuration.
    """

    _backends: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str, backend_class: type) -> None:
        """
        Register a backend class under ``name``.

        Raises:
            ValueError: if the class is not a BaseVideoBackend subclass.
        """
        if not issubclass(backend_class, BaseVideoBackend):
            # FIX: was an f-string with no placeholders.
            raise ValueError("Backend must inherit from BaseVideoBackend")
        cls._backends[name] = backend_class

    @classmethod
    def create(cls, name: str, config: Dict[str, Any]) -> "BaseVideoBackend":
        """
        Create a backend instance by registered name.

        Raises:
            ValueError: if ``name`` was never registered.
        """
        if name not in cls._backends:
            raise ValueError(f"Unknown backend: {name}. "
                             f"Available: {list(cls._backends.keys())}")
        return cls._backends[name](config)

    @classmethod
    def list_backends(cls) -> List[str]:
        """List available backend names."""
        return list(cls._backends.keys())
+""" + +import json +from pathlib import Path +from typing import Union + +from .schema import Storyboard + + +class StoryboardLoadError(Exception): + """Raised when storyboard loading fails.""" + pass + + +class StoryboardValidator: + """Validates and loads storyboard JSON files.""" + + @staticmethod + def load(path: Union[str, Path]) -> Storyboard: + """ + Load and validate a storyboard JSON file. + + Args: + path: Path to the JSON file + + Returns: + Validated Storyboard object + + Raises: + StoryboardLoadError: If file doesn't exist or is invalid + """ + path = Path(path) + + if not path.exists(): + raise StoryboardLoadError(f"Storyboard file not found: {path}") + + if not path.suffix.lower() == '.json': + raise StoryboardLoadError(f"Storyboard file must be JSON: {path}") + + try: + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f) + except json.JSONDecodeError as e: + raise StoryboardLoadError(f"Invalid JSON in {path}: {e}") + except Exception as e: + raise StoryboardLoadError(f"Failed to read {path}: {e}") + + try: + storyboard = Storyboard(**data) + except Exception as e: + raise StoryboardLoadError(f"Storyboard validation failed: {e}") + + return storyboard + + @staticmethod + def validate_references(storyboard: Storyboard) -> list[str]: + """ + Validate that all references (location_id, characters) exist. 
"""
Prompt compiler for merging global style with shot-specific prompts.
"""

from dataclasses import dataclass
from typing import List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from ..storyboard.schema import Storyboard, Shot, GlobalStyle


@dataclass
class CompiledPrompt:
    """A fully assembled prompt, ready to hand to a generation backend."""
    positive: str
    negative: str
    camera_notes: str
    full_prompt: str  # combined positive + style + camera

    def to_dict(self) -> dict:
        """Convert to dictionary for metadata storage."""
        return {
            "positive": self.positive,
            "negative": self.negative,
            "camera_notes": self.camera_notes,
            "full_prompt": self.full_prompt
        }


class PromptCompiler:
    """Merges the storyboard's global style into each shot's prompt text."""

    def __init__(self, storyboard: "Storyboard"):
        """
        Initialize with storyboard.

        Args:
            storyboard: Loaded and validated storyboard
        """
        self.storyboard = storyboard
        self.global_style = storyboard.project.global_style

    def compile_shot(self, shot: "Shot") -> CompiledPrompt:
        """
        Compile the prompt for a single shot.

        Args:
            shot: Shot to compile prompt for

        Returns:
            CompiledPrompt with merged content
        """
        style = self.global_style

        # Character descriptions, keeping only resolvable character IDs.
        char_descriptions = [
            f"{char.name}: {char.description}"
            for char_id in shot.characters
            if (char := self.storyboard.get_character(char_id))
        ]

        # Location description (empty when unset or unresolvable).
        location_desc = ""
        if shot.location_id and (loc := self.storyboard.get_location(shot.location_id)):
            location_desc = loc.description

        # Camera fragments, in fixed order, dropping empty entries.
        camera_parts = [
            part for part in (
                shot.camera.framing,
                style.lens,
                shot.camera.movement,
                style.motion_style,
                shot.camera.notes,
            ) if part
        ]

        positive_parts: List[str] = []
        if style.visual_style:
            positive_parts.append(style.visual_style)  # aesthetic comes first
        if location_desc:
            positive_parts.append(f"Setting: {location_desc}")
        if char_descriptions:
            positive_parts.append("Characters: " + "; ".join(char_descriptions))
        positive_parts.append(shot.prompt)  # the core shot content, always present
        if camera_parts:
            positive_parts.append(f"Camera: {', '.join(camera_parts)}")
        if style.lighting:
            positive_parts.append(f"Lighting: {style.lighting}")
        if style.color_grade:
            positive_parts.append(f"Color: {style.color_grade}")

        positive = ". ".join(positive_parts)

        # The only negative source today is the global style.
        negative = style.negative_prompt if style.negative_prompt else ""

        camera_notes = "; ".join(camera_parts) if camera_parts else ""

        return CompiledPrompt(
            positive=positive,
            negative=negative,
            camera_notes=camera_notes,
            full_prompt=positive  # full prompt mirrors the positive for now
        )

    def compile_all(self) -> List[CompiledPrompt]:
        """
        Compile prompts for every shot.

        Returns:
            List of CompiledPrompt objects in shot order
        """
        return [self.compile_shot(shot) for shot in self.storyboard.shots]

    def get_shot_prompt(self, shot_id: str) -> Optional[CompiledPrompt]:
        """
        Get the compiled prompt for one shot, looked up by ID.

        Args:
            shot_id: Shot identifier

        Returns:
            CompiledPrompt, or None if no shot has that ID
        """
        match = next((s for s in self.storyboard.shots if s.id == shot_id), None)
        return None if match is None else self.compile_shot(match)
+""" + +from typing import List, Optional, Literal +from pydantic import BaseModel, Field, validator + + +class Resolution(BaseModel): + """Video resolution dimensions.""" + width: int = Field(..., ge=256, le=4096, description="Width in pixels") + height: int = Field(..., ge=256, le=4096, description="Height in pixels") + + +class GlobalStyle(BaseModel): + """Global visual style settings applied to all shots.""" + visual_style: str = Field(default="", description="Overall visual aesthetic") + color_grade: str = Field(default="", description="Color grading description") + lens: str = Field(default="", description="Lens type/characteristics") + lighting: str = Field(default="", description="Lighting setup description") + motion_style: str = Field(default="", description="Camera motion characteristics") + negative_prompt: str = Field(default="", description="What to avoid in generation") + + +class AudioSettings(BaseModel): + """Audio configuration for the final video.""" + add_music: bool = Field(default=False) + music_path: str = Field(default="") + add_voiceover: bool = Field(default=False) + voiceover_path: str = Field(default="") + + +class ProjectSettings(BaseModel): + """Top-level project configuration.""" + title: str = Field(..., min_length=1, description="Project title") + fps: int = Field(default=24, ge=1, le=120, description="Frames per second") + target_duration_s: int = Field(default=20, ge=1, le=300, description="Target duration in seconds") + resolution: Resolution + aspect_ratio: str = Field(default="16:9", description="Aspect ratio (e.g., 16:9, 4:3)") + global_style: GlobalStyle = Field(default_factory=GlobalStyle) + audio: AudioSettings = Field(default_factory=AudioSettings) + + +class Character(BaseModel): + """Character definition for consistency across shots.""" + id: str = Field(..., min_length=1, description="Unique character identifier") + name: str = Field(..., min_length=1, description="Character name") + description: str = Field(default="", 
description="Visual description") + consistency_notes: str = Field(default="", description="Notes for maintaining consistency") + + +class Location(BaseModel): + """Location/setting definition.""" + id: str = Field(..., min_length=1, description="Unique location identifier") + name: str = Field(..., min_length=1, description="Location name") + description: str = Field(default="", description="Visual description of the setting") + + +class CameraSettings(BaseModel): + """Camera configuration for a shot.""" + framing: str = Field(default="", description="Shot framing (wide, medium, close-up, etc.)") + movement: str = Field(default="", description="Camera movement description") + notes: str = Field(default="", description="Additional camera notes") + + +class GenerationSettings(BaseModel): + """AI generation parameters for a shot.""" + seed: int = Field(default=-1, description="Random seed (-1 for random)") + steps: int = Field(default=30, ge=1, le=100, description="Inference steps") + cfg_scale: float = Field(default=6.0, ge=1.0, le=20.0, description="Classifier-free guidance scale") + sampler: str = Field(default="default", description="Sampler type") + chunk_seconds: int = Field(default=4, ge=1, le=10, description="Duration per chunk in seconds") + use_init_frame_from_prev: bool = Field(default=False, description="Use last frame of previous shot as init") + + +class Shot(BaseModel): + """Individual shot definition.""" + id: str = Field(..., min_length=1, description="Unique shot identifier") + duration_s: int = Field(..., ge=1, le=60, description="Shot duration in seconds") + location_id: Optional[str] = Field(default=None, description="Reference to location") + characters: List[str] = Field(default_factory=list, description="List of character IDs") + prompt: str = Field(..., min_length=1, description="Generation prompt for this shot") + camera: CameraSettings = Field(default_factory=CameraSettings) + generation: GenerationSettings = 
Field(default_factory=GenerationSettings) + + +class UpscaleSettings(BaseModel): + """Post-processing upscaling configuration.""" + enabled: bool = Field(default=False) + target_height: int = Field(default=2160, ge=720, le=4320, description="Target height in pixels") + method: str = Field(default="default", description="Upscaling method") + + +class OutputSettings(BaseModel): + """Final video output configuration.""" + container: Literal["mp4", "mov", "mkv"] = Field(default="mp4") + codec: Literal["h264", "h265", "vp9"] = Field(default="h264") + crf: int = Field(default=18, ge=0, le=51, description="Constant rate factor (lower = higher quality)") + preset: str = Field(default="medium", description="Encoding preset") + upscale: UpscaleSettings = Field(default_factory=UpscaleSettings) + + +class Storyboard(BaseModel): + """Complete storyboard document.""" + schema_version: str = Field(default="1.0", description="Schema version for compatibility") + project: ProjectSettings + characters: List[Character] = Field(default_factory=list) + locations: List[Location] = Field(default_factory=list) + shots: List[Shot] = Field(..., description="List of shots to generate") + output: OutputSettings = Field(default_factory=OutputSettings) + + @validator('shots') + def validate_shot_ids_unique(cls, v): + """Ensure all shot IDs are unique and at least one shot exists.""" + if len(v) < 1: + raise ValueError("At least one shot is required") + ids = [shot.id for shot in v] + if len(ids) != len(set(ids)): + raise ValueError("Shot IDs must be unique") + return v + + @validator('characters') + def validate_character_ids_unique(cls, v): + """Ensure all character IDs are unique.""" + ids = [char.id for char in v] + if len(ids) != len(set(ids)): + raise ValueError("Character IDs must be unique") + return v + + @validator('locations') + def validate_location_ids_unique(cls, v): + """Ensure all location IDs are unique.""" + ids = [loc.id for loc in v] + if len(ids) != len(set(ids)): + raise 
ValueError("Location IDs must be unique") + return v + + def get_character(self, char_id: str) -> Optional[Character]: + """Get character by ID.""" + for char in self.characters: + if char.id == char_id: + return char + return None + + def get_location(self, loc_id: str) -> Optional[Location]: + """Get location by ID.""" + for loc in self.locations: + if loc.id == loc_id: + return loc + return None + + def get_total_duration(self) -> int: + """Calculate total duration in seconds.""" + return sum(shot.duration_s for shot in self.shots) + + def get_total_frames(self) -> int: + """Calculate total frame count.""" + return self.get_total_duration() * self.project.fps diff --git a/src/storyboard/shot_planner.py b/src/storyboard/shot_planner.py new file mode 100644 index 0000000..51c9276 --- /dev/null +++ b/src/storyboard/shot_planner.py @@ -0,0 +1,252 @@ +""" +Shot planner for converting duration to frames and chunking strategy. +""" + +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +from pathlib import Path + +from ..storyboard.schema import Storyboard, Shot + + +@dataclass +class Chunk: + """A single chunk of a shot.""" + chunk_index: int + start_frame: int + end_frame: int + num_frames: int + duration_s: float + use_init_frame: bool + init_frame_source: Optional[Path] # Path to frame file or previous chunk + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for metadata.""" + return { + "chunk_index": self.chunk_index, + "start_frame": self.start_frame, + "end_frame": self.end_frame, + "num_frames": self.num_frames, + "duration_s": self.duration_s, + "use_init_frame": self.use_init_frame, + "init_frame_source": str(self.init_frame_source) if self.init_frame_source else None + } + + +@dataclass +class ShotPlan: + """Complete plan for generating a shot.""" + shot_id: str + total_frames: int + total_duration_s: int + fps: int + chunks: List[Chunk] + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for 
metadata.""" + return { + "shot_id": self.shot_id, + "total_frames": self.total_frames, + "total_duration_s": self.total_duration_s, + "fps": self.fps, + "chunks": [c.to_dict() for c in self.chunks] + } + + +class ShotPlanner: + """ + Plans shot generation including frame counts and chunking strategy. + """ + + def __init__( + self, + fps: int = 24, + max_chunk_seconds: int = 4, + overlap_seconds: int = 0, + chunking_mode: str = "sequential" + ): + """ + Initialize shot planner. + + Args: + fps: Frames per second + max_chunk_seconds: Maximum duration per chunk + overlap_seconds: Overlap between chunks (for blending mode) + chunking_mode: "sequential" or "overlapping" + """ + self.fps = fps + self.max_chunk_seconds = max_chunk_seconds + self.overlap_seconds = overlap_seconds + self.chunking_mode = chunking_mode + + def duration_to_frames(self, duration_s: int) -> int: + """ + Convert duration in seconds to frame count. + + Args: + duration_s: Duration in seconds + + Returns: + Number of frames + """ + return duration_s * self.fps + + def frames_to_duration(self, num_frames: int) -> float: + """ + Convert frame count to duration in seconds. + + Args: + num_frames: Number of frames + + Returns: + Duration in seconds + """ + return num_frames / self.fps + + def plan_shot( + self, + shot: Shot, + fps: Optional[int] = None, + use_init_from_prev: bool = False + ) -> ShotPlan: + """ + Create a generation plan for a shot. 
+ + Args: + shot: Shot to plan + fps: Override FPS (uses storyboard default if None) + use_init_from_prev: Whether to use previous shot's last frame + + Returns: + ShotPlan with chunk breakdown + """ + fps = fps or self.fps + total_frames = shot.duration_s * fps + + # Determine chunk size + chunk_frames = self.max_chunk_seconds * fps + overlap_frames = self.overlap_seconds * fps if self.chunking_mode == "overlapping" else 0 + + chunks = [] + + if total_frames <= chunk_frames: + # Single chunk - no need to split + chunk = Chunk( + chunk_index=0, + start_frame=0, + end_frame=total_frames - 1, + num_frames=total_frames, + duration_s=shot.duration_s, + use_init_frame=use_init_from_prev and shot.generation.use_init_frame_from_prev, + init_frame_source=None # Will be set by pipeline + ) + chunks.append(chunk) + else: + # Multiple chunks needed + remaining_frames = total_frames + current_frame = 0 + chunk_index = 0 + + while remaining_frames > 0: + # Calculate frames for this chunk + if self.chunking_mode == "sequential": + # Sequential: Each chunk starts where previous ended + this_chunk_frames = min(chunk_frames, remaining_frames) + + chunk = Chunk( + chunk_index=chunk_index, + start_frame=current_frame, + end_frame=current_frame + this_chunk_frames - 1, + num_frames=this_chunk_frames, + duration_s=this_chunk_frames / fps, + use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0, + init_frame_source=None + ) + chunks.append(chunk) + + current_frame += this_chunk_frames + remaining_frames -= this_chunk_frames + + else: # overlapping + # Overlapping: Chunks overlap for blending + if remaining_frames <= chunk_frames: + # Last chunk - take all remaining frames + this_chunk_frames = remaining_frames + else: + # Not last chunk - include overlap + this_chunk_frames = min(chunk_frames + overlap_frames, remaining_frames) + + chunk = Chunk( + chunk_index=chunk_index, + start_frame=current_frame, + 
end_frame=current_frame + this_chunk_frames - 1, + num_frames=this_chunk_frames, + duration_s=this_chunk_frames / fps, + use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0, + init_frame_source=None + ) + chunks.append(chunk) + + # Move forward by chunk size (minus overlap) + current_frame += chunk_frames - overlap_frames + remaining_frames -= chunk_frames - overlap_frames + + chunk_index += 1 + + return ShotPlan( + shot_id=shot.id, + total_frames=total_frames, + total_duration_s=shot.duration_s, + fps=fps, + chunks=chunks + ) + + def plan_storyboard( + self, + storyboard: Storyboard, + override_chunk_config: Optional[Dict[str, Any]] = None + ) -> List[ShotPlan]: + """ + Create generation plans for all shots in a storyboard. + + Args: + storyboard: Storyboard to plan + override_chunk_config: Optional config to override defaults + + Returns: + List of ShotPlan objects + """ + # Apply overrides if provided + fps = override_chunk_config.get("fps", storyboard.project.fps) if override_chunk_config else storyboard.project.fps + max_chunk = override_chunk_config.get("max_chunk_seconds", self.max_chunk_seconds) if override_chunk_config else self.max_chunk_seconds + overlap = override_chunk_config.get("overlap_seconds", self.overlap_seconds) if override_chunk_config else self.overlap_seconds + mode = override_chunk_config.get("mode", self.chunking_mode) if override_chunk_config else self.chunking_mode + + planner = ShotPlanner( + fps=fps, + max_chunk_seconds=max_chunk, + overlap_seconds=overlap, + chunking_mode=mode + ) + + plans = [] + prev_shot_id = None + + for i, shot in enumerate(storyboard.shots): + # Determine if we should use init frame from previous shot + use_init = i > 0 and shot.generation.use_init_frame_from_prev + + plan = planner.plan_shot(shot, fps=fps, use_init_from_prev=use_init) + plans.append(plan) + + prev_shot_id = shot.id + + return plans + + def get_total_frames(self, plans: 
class TestCheckpointManager:
    """Behavioral tests for CheckpointManager backed by a throwaway SQLite file."""

    @pytest.fixture
    def temp_db(self):
        """Yield a path to a fresh temporary database file; remove it afterwards."""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
            db_path = f.name

        yield db_path

        # Cleanup (the manager may have re-created the file, so re-check).
        if os.path.exists(db_path):
            os.unlink(db_path)

    @pytest.fixture
    def manager(self, temp_db):
        """Checkpoint manager wired to the temporary database."""
        return CheckpointManager(db_path=temp_db)

    @staticmethod
    def _seed(manager, project_id="test_001", backend="wan"):
        """Create a minimal project row; most tests only need one to exist."""
        return manager.create_project(
            project_id=project_id,
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name=backend,
        )

    def test_create_project(self, manager):
        """A freshly created project starts in the INITIALIZED state."""
        checkpoint = self._seed(manager, backend="wan_t2v_14b")
        assert checkpoint.project_id == "test_001"
        assert checkpoint.status == ProjectStatus.INITIALIZED
        assert checkpoint.backend_name == "wan_t2v_14b"

    def test_get_project(self, manager):
        """Round-trip a project and confirm unknown IDs come back as None."""
        self._seed(manager)
        fetched = manager.get_project("test_001")
        assert fetched is not None
        assert fetched.project_id == "test_001"
        assert manager.get_project("nonexistent") is None

    def test_update_project_status(self, manager):
        """Status updates are persisted and visible on re-read."""
        self._seed(manager)
        manager.update_project_status("test_001", ProjectStatus.RUNNING)
        assert manager.get_project("test_001").status == ProjectStatus.RUNNING

    def test_create_shot(self, manager):
        """Shots are created under a project with an explicit status."""
        self._seed(manager)
        shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)
        assert shot.shot_id == "S01"
        assert shot.status == ShotStatus.PENDING

    def test_update_shot(self, manager):
        """Shot status, output path and metadata survive an update round-trip."""
        self._seed(manager)
        manager.create_shot("test_001", "S01")
        manager.update_shot(
            "test_001",
            "S01",
            status=ShotStatus.IN_PROGRESS,
            output_path="/path/output.mp4",
            metadata={"seed": 12345},
        )
        shot = manager.get_shot("test_001", "S01")
        assert shot.status == ShotStatus.IN_PROGRESS
        assert shot.output_path == "/path/output.mp4"
        assert shot.metadata["seed"] == 12345

    def test_get_pending_shots(self, manager):
        """Only shots still in the PENDING state are reported as pending."""
        self._seed(manager)
        manager.create_shot("test_001", "S01", ShotStatus.PENDING)
        manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
        manager.create_shot("test_001", "S03", ShotStatus.PENDING)
        pending = manager.get_pending_shots("test_001")
        assert len(pending) == 2
        assert "S01" in pending
        assert "S03" in pending

    def test_can_resume(self, manager):
        """Resume is possible exactly while incomplete shots remain."""
        self._seed(manager)
        assert not manager.can_resume("test_001")  # no shots yet
        manager.create_shot("test_001", "S01", ShotStatus.PENDING)
        assert manager.can_resume("test_001")      # pending work exists
        manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
        assert not manager.can_resume("test_001")  # everything done

    def test_delete_project(self, manager):
        """Deleting a project removes both the project and its shots."""
        self._seed(manager)
        manager.create_shot("test_001", "S01")
        manager.delete_project("test_001")
        assert manager.get_project("test_001") is None
        assert manager.get_shot("test_001", "S01") is None

    def test_list_projects(self, manager):
        """All created projects are listed, regardless of backend."""
        self._seed(manager, project_id="test_001", backend="wan")
        self._seed(manager, project_id="test_002", backend="svd")
        projects = manager.list_projects()
        assert len(projects) == 2
        ids = [p.project_id for p in projects]
        assert "test_001" in ids
        assert "test_002" in ids
class TestConfigLoader:
    """Configuration loading from YAML files, .env files and the environment."""

    @staticmethod
    def _write_temp(content, suffix):
        """Write content to a NamedTemporaryFile and return its path."""
        with tempfile.NamedTemporaryFile(mode='w', suffix=suffix, delete=False) as f:
            f.write(content)
            return f.name

    def test_load_default_config(self):
        """With no config files present, the built-in defaults apply."""
        config = ConfigLoader.load(
            config_path=Path("nonexistent.yaml"),
            env_file=Path("nonexistent.env")
        )

        assert isinstance(config, Config)
        assert config.active_backend == "wan_t2v_14b"
        assert len(config.backends) == 0  # no YAML was loaded

    def test_load_yaml_config(self):
        """Backends, defaults and paths are all read from the YAML file."""
        yaml_content = """
backends:
  test_backend:
    name: "Test Backend"
    class: "test.TestBackend"
    model_id: "test/model"
    vram_gb: 8
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: false
    chunking:
      enabled: true
      mode: "sequential"
      max_chunk_seconds: 4
      overlap_seconds: 1

active_backend: "test_backend"
defaults:
  fps: 30
checkpoint_db: "test.db"
model_cache_dir: "~/test_cache"
"""
        cfg_path = self._write_temp(yaml_content, '.yaml')
        try:
            config = ConfigLoader.load(config_path=Path(cfg_path))

            assert config.active_backend == "test_backend"
            assert "test_backend" in config.backends

            backend = config.get_backend("test_backend")
            assert backend.name == "Test Backend"
            assert backend.vram_gb == 8
            assert backend.chunking_mode == "sequential"
            assert config.defaults["fps"] == 30
        finally:
            os.unlink(cfg_path)

    def test_load_env_file(self, monkeypatch):
        """Values in a .env file are picked up when no real env vars shadow them."""
        # Start from a clean environment for the keys under test.
        for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
            monkeypatch.delenv(key, raising=False)

        env_content = """
ACTIVE_BACKEND=env_backend
MODEL_CACHE_DIR=/env/cache
LOG_LEVEL=DEBUG
"""
        env_path = self._write_temp(env_content, '.env')
        try:
            config = ConfigLoader.load(
                config_path=Path("nonexistent.yaml"),
                env_file=Path(env_path)
            )

            assert config.active_backend == "env_backend"
            # Path separators differ between POSIX and Windows.
            assert "/env/cache" in config.model_cache_dir or "\\env\\cache" in config.model_cache_dir
            assert config.log_level == "DEBUG"
        finally:
            os.unlink(env_path)
            for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
                monkeypatch.delenv(key, raising=False)

    def test_env_variable_override(self, monkeypatch):
        """Real environment variables win over values from config files."""
        monkeypatch.delenv("ACTIVE_BACKEND", raising=False)

        yaml_content = """
active_backend: yaml_backend
model_cache_dir: ~/yaml_cache
"""
        cfg_path = self._write_temp(yaml_content, '.yaml')
        try:
            monkeypatch.setenv("ACTIVE_BACKEND", "env_override")

            config = ConfigLoader.load(config_path=Path(cfg_path))

            # The environment variable overrides the YAML value...
            assert config.active_backend == "env_override"
            # ...while unrelated YAML settings remain in effect.
            assert "yaml_cache" in config.model_cache_dir
        finally:
            os.unlink(cfg_path)
            monkeypatch.delenv("ACTIVE_BACKEND", raising=False)

    def test_path_expansion(self):
        """Default paths come back with '~' already expanded."""
        config = ConfigLoader.load(
            config_path=Path("nonexistent.yaml"),
            env_file=Path("nonexistent.env")
        )

        assert not config.model_cache_dir.startswith("~")
        assert not config.checkpoint_db.startswith("~")

    def test_get_backend_not_found(self):
        """Requesting an unknown backend raises a descriptive ValueError."""
        config = Config()

        with pytest.raises(ValueError, match="not found"):
            config.get_backend("nonexistent")


class TestBackendConfig:
    """Backend configuration defaults."""

    def test_backend_config_defaults(self):
        """Unspecified fields fall back to the documented defaults."""
        config = BackendConfig(
            name="Test",
            model_class="test.Test",
            model_id="test/model",
            vram_gb=12,
            dtype="fp16"
        )

        assert config.enable_vae_slicing is True
        assert config.enable_vae_tiling is False
        assert config.chunking_enabled is True
        assert config.chunking_mode == "sequential"
match="Character IDs must be unique"): + Storyboard( + project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + characters=[ + Character(id="C01", name="Character 1"), + Character(id="C01", name="Character 2") + ], + shots=[Shot(id="S01", duration_s=5, prompt="Test")] + ) + + def test_location_validation_unique_ids(self): + """Test that duplicate location IDs are rejected.""" + with pytest.raises(ValueError, match="Location IDs must be unique"): + Storyboard( + project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + locations=[ + Location(id="L01", name="Location 1"), + Location(id="L01", name="Location 2") + ], + shots=[Shot(id="S01", duration_s=5, prompt="Test")] + ) + + def test_get_character_by_id(self): + """Test retrieving character by ID.""" + storyboard = Storyboard( + project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + characters=[ + Character(id="C01", name="Hero", description="The protagonist") + ], + shots=[Shot(id="S01", duration_s=5, prompt="Test")] + ) + + char = storyboard.get_character("C01") + assert char is not None + assert char.name == "Hero" + + missing = storyboard.get_character("C99") + assert missing is None + + def test_get_location_by_id(self): + """Test retrieving location by ID.""" + storyboard = Storyboard( + project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + locations=[ + Location(id="L01", name="Street", description="A city street") + ], + shots=[Shot(id="S01", duration_s=5, prompt="Test")] + ) + + loc = storyboard.get_location("L01") + assert loc is not None + assert loc.name == "Street" + + missing = storyboard.get_location("L99") + assert missing is None + + def test_total_duration_calculation(self): + """Test total duration calculation.""" + storyboard = Storyboard( 
+ project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + shots=[ + Shot(id="S01", duration_s=4, prompt="Shot 1"), + Shot(id="S02", duration_s=6, prompt="Shot 2") + ] + ) + + assert storyboard.get_total_duration() == 10 + + def test_total_frames_calculation(self): + """Test total frames calculation.""" + storyboard = Storyboard( + project=ProjectSettings( + title="Test", + fps=24, + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + shots=[ + Shot(id="S01", duration_s=4, prompt="Shot 1"), + Shot(id="S02", duration_s=6, prompt="Shot 2") + ] + ) + + assert storyboard.get_total_frames() == 240 # 10 seconds * 24 fps + + +class TestStoryboardLoader: + """Test storyboard loading functionality.""" + + def test_load_valid_storyboard(self): + """Test loading a valid storyboard JSON file.""" + data = { + "schema_version": "1.0", + "project": { + "title": "Test Video", + "fps": 24, + "target_duration_s": 10, + "resolution": {"width": 1920, "height": 1080}, + "aspect_ratio": "16:9", + "global_style": { + "visual_style": "cinematic", + "negative_prompt": "blurry" + }, + "audio": {"add_music": False} + }, + "characters": [], + "locations": [], + "shots": [ + { + "id": "S01", + "duration_s": 5, + "prompt": "A test shot", + "camera": {"framing": "wide"}, + "generation": {"seed": 12345, "steps": 30} + } + ], + "output": {"container": "mp4", "codec": "h264"} + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(data, f) + temp_path = f.name + + try: + storyboard = StoryboardValidator.load(temp_path) + assert storyboard.project.title == "Test Video" + assert len(storyboard.shots) == 1 + assert storyboard.shots[0].id == "S01" + finally: + os.unlink(temp_path) + + def test_load_nonexistent_file(self): + """Test loading a file that doesn't exist.""" + with pytest.raises(StoryboardLoadError, match="not found"): + 
StoryboardValidator.load("/nonexistent/path/storyboard.json") + + def test_load_invalid_json(self): + """Test loading invalid JSON.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write("{invalid json}") + temp_path = f.name + + try: + with pytest.raises(StoryboardLoadError, match="Invalid JSON"): + StoryboardValidator.load(temp_path) + finally: + os.unlink(temp_path) + + def test_load_wrong_extension(self): + """Test loading a file with wrong extension.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write('{}') + temp_path = f.name + + try: + with pytest.raises(StoryboardLoadError, match="must be JSON"): + StoryboardValidator.load(temp_path) + finally: + os.unlink(temp_path) + + def test_validate_references(self): + """Test reference validation.""" + storyboard = Storyboard( + project=ProjectSettings( + title="Test", + resolution=Resolution(width=1920, height=1080), + target_duration_s=10 + ), + characters=[ + Character(id="C01", name="Hero") + ], + locations=[ + Location(id="L01", name="Street") + ], + shots=[ + Shot( + id="S01", + duration_s=5, + prompt="Test", + characters=["C01", "C99"], # C99 doesn't exist + location_id="L99" # Doesn't exist + ) + ] + ) + + issues = StoryboardValidator.validate_references(storyboard) + assert len(issues) == 2 + assert any("C99" in issue for issue in issues) + assert any("L99" in issue for issue in issues)