Initial Commit

2026-02-03 23:06:28 -05:00
commit 46b10fb69b
25 changed files with 2770 additions and 0 deletions

23
.env.example Normal file

@@ -0,0 +1,23 @@
# Environment configuration for storyboard-video
# Copy this file to .env and customize as needed
# Active backend selection (must match a key in config/models.yaml)
ACTIVE_BACKEND=wan_t2v_14b
# Model cache directory (where downloaded models are stored)
MODEL_CACHE_DIR=~/.cache/storyboard-video/models
# HuggingFace Hub cache (if using HF models)
HF_HOME=~/.cache/huggingface
# CUDA settings
# CUDA_VISIBLE_DEVICES=0
# Logging level (DEBUG, INFO, WARNING, ERROR)
LOG_LEVEL=INFO
# Output directory base path
OUTPUT_BASE_DIR=./outputs
# FFmpeg path (leave empty to use system PATH)
FFMPEG_PATH=

138
.gitignore vendored Normal file

@@ -0,0 +1,138 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# Virtual environments
venv/
env/
ENV/
env.bak/
venv.bak/
.venv
# Conda
conda-env/
*.conda
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
# Jupyter Notebook
.ipynb_checkpoints
# Pytest
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
# Environment variables
.env
.env.local
.env.*.local
# Model caches and outputs
models/
*.ckpt
*.safetensors
*.bin
*.pth
*.onnx
outputs/
checkpoints.db
# Temporary files
tmp/
temp/
*.tmp
*.temp
# Logs
*.log
logs/
# Video files (generated)
*.mp4
*.mov
*.avi
*.mkv
*.webm
# Image files (generated)
*.png
*.jpg
*.jpeg
*.gif
*.bmp
*.tiff
# Audio files (generated)
*.wav
*.mp3
*.aac
*.flac
*.ogg
# HuggingFace cache
.cache/
transformers_cache/
diffusers_cache/
# Windows
Thumbs.db
ehthumbs.db
Desktop.ini
$RECYCLE.BIN/
# macOS
.AppleDouble
.LSOverride
Icon
._*
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Project specific
# Keep templates but ignore generated outputs
!/templates/*.json
!/templates/*.yaml
!/templates/*.yml
# Allow example files
!.env.example
!config/*.yaml
!config/*.yml
# Keep docs
!docs/*.md

210
AGENTS.MD Normal file

@@ -0,0 +1,210 @@
# agents.md
## Project: Local AI Video Generation from Text Storyboards (Windows + RTX 5070 12GB)
### 0) Who this is for
The owner (user) is not an ML expert. The system must:
- be reproducible (conda + requirements)
- have guardrails (configs, logs, validation)
- be test-driven (pytest)
- maintain docs (developer + user)
---
## 1) High-Level Goal
Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by:
1) converting storyboard -> shot plan
2) generating shot clips (T2V or I2V when possible)
3) assembling clips into a final MP4
4) upscaling to 2K/4K if desired
This is a **shot-based** system, not “one prompt makes a whole movie”.
---
## 2) Hard Constraints (Hardware & OS)
Target system:
- Windows 11
- NVIDIA RTX 5070 (12GB VRAM); GPU use is required
- 32GB RAM
- 2TB SSD
- Anaconda available
Design must be stable under 12GB VRAM using:
- fp16/bf16
- attention slicing
- xFormers / SDPA where supported
- optional CPU offload
---
## 3) Output Targets (Realistic)
- Native generation: 720p–1080p (preferred)
- Final delivery: 1080p required; 2K/4K via upscaling
- Duration: 15–30s per video (may be segmented)
- FPS: 24 default
- Output: MP4 (H.264/H.265)
---
## 4) CUDA 13.1 Reality & PyTorch Plan (Critical)
User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit.
**Plan:**
- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124).
- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version.
- Avoid compiling custom CUDA extensions unless necessary.
Implementation notes:
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds.
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers.
---
## 5) Approved Stack (Do Not Deviate)
### Core
- Python 3.10 or 3.11 (conda env)
- PyTorch (CUDA 12.x build: cu121 or cu124)
- diffusers + transformers + accelerate + safetensors
- ffmpeg for assembly
- opencv-python for frame IO (if needed)
- pydantic for config/schema validation
- rich / loguru for logs
### Testing
- pytest
- pytest-cov
- snapshot-ish tests where feasible (metadata + shapes, not visual perfection)
### Docs
- /docs/developer.md (developer documentation)
- /docs/user.md (user manual)
- Keep docs updated alongside code changes.
---
## 6) Video Models (Pragmatic Choices)
### Primary (target)
- WAN 2.x family (T2V; optional I2V if supported in chosen pipeline)
Goal: best possible quality on consumer VRAM with chunking.
### Secondary / fallback
- Stable Video Diffusion (SVD) if WAN is unstable
- LTX-Video (only if it fits and is stable in our stack)
All model backends must implement the same interface:
- generate_shot(shot_spec) -> video_file + metadata
---
## 7) Canonical Input: Storyboard JSON
Storyboard source is text-only (often AI-generated). We will store and validate it as JSON.
A template exists at: `templates/storyboard.template.json`
We will later build a utility script:
- input: plain text fields or a simple text format
- output: valid storyboard JSON
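A minimal storyboard document might look like the sketch below. Field names here are illustrative assumptions, not the final schema; the authoritative shape lives in `templates/storyboard.template.json`.

```python
import json

# Hypothetical minimal storyboard (illustrative field names only)
storyboard = {
    "project": {
        "title": "demo",
        "fps": 24,
        "resolution": {"width": 1280, "height": 720},
    },
    "global_style": "cinematic, soft morning light",
    "shots": [
        {
            "id": "shot_001",
            "duration_s": 5,
            "prompt": "a lighthouse at dawn, slow push-in",
            "camera": "dolly in",
            "seed": 42,
        },
    ],
}

# Round-trip through JSON, as the loader would do
text = json.dumps(storyboard, indent=2)
loaded = json.loads(text)
```

The validator's job is then to reject documents missing required keys and to fill in project-level defaults (fps, resolution) on each shot.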
---
## 8) Pipeline Modules (Required)
### A) Storyboard parsing & validation
- Load storyboard JSON
- Validate schema
- Expand defaults (fps, resolution, global style)
- Produce normalized shot list
### B) Prompt compilation
- Merge global style + shot prompt + camera notes
- Produce positive + negative prompts
- Keep deterministic via seeds
### C) Generation runner (per shot)
- For each shot: generate clip
- Support:
- seed control
- chunking (e.g., generate 46 seconds then continue)
- optional init frame handoff between shots
### D) Assembly
- Use ffmpeg concat to build final video
- Optionally add:
  - transitions
  - temp audio
  - burn-in shot IDs for debugging mode
### E) Upscaling (optional)
- Upscale final to 2K/4K (post step)
- Keep this modular so user can skip.
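The assembly step in (D) can be tested without rendering anything: build the ffmpeg command line and assert on it. A minimal sketch (concat-demuxer flags only; no subprocess is launched, and the helper names are ours, not a fixed API):

```python
from pathlib import Path
from typing import List


def build_concat_list(clips: List[str]) -> str:
    """Body of the concat-demuxer list file: one "file '...'" line per clip."""
    return "\n".join(f"file '{Path(c).as_posix()}'" for c in clips)


def build_concat_command(list_file: str, output: str) -> List[str]:
    """ffmpeg concat command; '-c copy' avoids re-encoding the shot clips."""
    return ["ffmpeg", "-y", "-f", "concat", "-safe", "0",
            "-i", list_file, "-c", "copy", output]


listing = build_concat_list(["shots/shot_001.mp4", "shots/shot_002.mp4"])
cmd = build_concat_command("concat.txt", "assembled/final.mp4")
```

This keeps the ffmpeg invocation pure and deterministic, which is exactly what the testing rules in section 10 call for.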
---
## 9) Determinism & Logging (Must Have)
For each shot and final render, save:
- prompts (positive/negative)
- seed(s)
- model + revision/hash info if available
- inference params (steps, cfg, sampler, resolution, fps, frames)
- timing + VRAM notes if possible
Every run produces a folder:
- outputs/&lt;project&gt;/&lt;timestamp&gt;/
  - shots/
  - assembled/
  - metadata/
---
## 10) Testing Rules (Hard Requirement)
- Tests must be written alongside features.
- Whenever a file/function is modified, corresponding tests MUST be updated.
- Prefer tests that verify:
  - schema validation works
  - prompt compiler output is stable
  - shot planner expands durations -> frame counts
  - assembly command lines are correct
  - metadata is generated correctly
Do not require “visual quality” assertions. Test structure and determinism.
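For instance, a shot-planner test of the duration -> frame-count expansion could look like this (the `frames_for` helper is a hypothetical stand-in for the real planner function):

```python
import math


def frames_for(duration_s: float, fps: int) -> int:
    # Hypothetical planner rule: round up so a shot never comes out short.
    return math.ceil(duration_s * fps)


def test_duration_to_frames():
    assert frames_for(5, 24) == 120
    assert frames_for(4.5, 24) == 108
    assert frames_for(0.1, 24) == 3  # partial seconds round up


test_duration_to_frames()
```

Pure functions like this are trivially deterministic, so the tests need no GPU, model weights, or fixtures.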
---
## 11) Documentation Rules (Hard Requirement)
Maintain these continuously:
- docs/developer.md
  - architecture
  - install steps
  - how to run tests
  - how to add a new model backend
- docs/user.md
  - quickstart
  - how to create storyboard JSON
  - how to run generation
  - where outputs go
  - troubleshooting (VRAM, drivers, ffmpeg)
Docs must be updated whenever CLI flags, file formats, or workflows change.
---
## 12) Project Files to Maintain
Required:
- requirements.txt (pip deps)
- environment.yml (conda env)
- templates/storyboard.template.json
- docs/developer.md
- docs/user.md
- src/ (implementation)
- tests/ (pytest)
---
## 13) Definition of Done
A feature is “done” only if:
- implemented
- tests added/updated
- docs updated
- reproducible install instructions remain valid
End of file.

35
TODO.MD Normal file

@@ -0,0 +1,35 @@
# TODO
## Repo bootstrap
- [x] Define project direction and constraints in `agents.md`
- [x] Add `requirements.txt` (pip dependencies)
- [x] Add `environment.yml` (conda environment; PyTorch CUDA 12.x runtime strategy)
- [x] Add storyboard JSON template at `templates/storyboard.template.json`
## Core implementation (next)
- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
- [ ] Implement storyboard schema validator (pydantic) + loader
- [ ] Implement prompt compiler (global style + shot + camera)
- [ ] Implement shot planning (duration -> frame count, chunk plan)
- [ ] Implement model backend interface (`BaseVideoBackend`)
- [ ] Implement WAN backend (primary) with VRAM-safe defaults
- [ ] Implement fallback backend (SVD) for reliability testing
- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
- [ ] Implement optional upscaling module (post-process)
## Utilities
- [ ] Write storyboard “plain text → JSON” utility script (fills `storyboard.template.json`)
- [ ] Add config file support (YAML/JSON) for global defaults
## Testing (parallel work; required)
- [ ] Add `pytest` scaffolding
- [ ] Add tests for schema validation
- [ ] Add tests for prompt compilation determinism
- [ ] Add tests for shot planning (frames/chunks)
- [ ] Add tests for ffmpeg command generation (no actual render needed)
- [ ] Ensure every code change includes a corresponding test update
## Documentation (maintained continuously)
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
- [ ] Keep docs updated whenever CLI/config/schema changes

61
config/models.yaml Normal file

@@ -0,0 +1,61 @@
# Model configurations for video generation backends
# Backend selection is controlled via ACTIVE_BACKEND environment variable
# or --backend CLI flag
backends:
  wan_t2v_14b:
    name: "Wan 2.1 T2V 14B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-14B"
    vram_gb: 12
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: true
    chunking:
      enabled: true
      mode: "sequential"  # Options: sequential, overlapping
      max_chunk_seconds: 4
      overlap_seconds: 1  # Only used when mode is overlapping
  wan_t2v_1_3b:
    name: "Wan 2.1 T2V 1.3B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-1.3B"
    vram_gb: 8
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: false
    chunking:
      enabled: true
      mode: "sequential"
      max_chunk_seconds: 6
      overlap_seconds: 1
  svd:
    name: "Stable Video Diffusion"
    class: "generation.backends.svd.SVDBackend"
    model_id: "stabilityai/stable-video-diffusion-img2vid-xt"
    vram_gb: 10
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: true
    chunking:
      enabled: false

# Default backend selection
# Override with ACTIVE_BACKEND environment variable
active_backend: "wan_t2v_14b"

# Global generation defaults
defaults:
  fps: 24
  resolution:
    width: 1920
    height: 1080

# Checkpoint database settings
checkpoint_db: "outputs/checkpoints.db"

# Model cache directory
# Override with MODEL_CACHE_DIR environment variable
model_cache_dir: "~/.cache/storyboard-video/models"

24
environment.yml Normal file

@@ -0,0 +1,24 @@
# environment.yml
name: storyboard-video
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - pip=24.3.1
  # FFmpeg in env is easiest on Windows if you rely on conda-forge
  - ffmpeg=7.1
  # PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
  # Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
  - pytorch=2.5.1
  - pytorch-cuda=12.4
  - torchvision=0.20.1
  - torchaudio=2.5.1
  # pip deps (mirrors requirements.txt)
  - pip:
      - -r requirements.txt

7
pytest.ini Normal file

@@ -0,0 +1,7 @@
# pytest configuration
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short

26
requirements.txt Normal file

@@ -0,0 +1,26 @@
# requirements.txt
# Install PyTorch separately (see environment.yml and docs).
diffusers==0.32.2
transformers==4.48.3
accelerate==1.3.0
safetensors==0.5.2
pydantic==2.10.6
pyyaml==6.0.2
numpy==2.1.3
pillow==11.1.0
opencv-python==4.11.0.86
rich==13.9.4
loguru==0.7.3
tqdm==4.67.1
requests==2.32.3
# CLI & tooling
typer==0.15.1
# Testing
pytest==8.3.4
pytest-cov==6.0.0

5
src/__init__.py Normal file

@@ -0,0 +1,5 @@
"""
Storyboard video generation pipeline.
"""
__version__ = "0.1.0"

3
src/assembly/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
Video assembly and post-processing.
"""

3
src/cli/__init__.py Normal file

@@ -0,0 +1,3 @@
"""
CLI entry points.
"""

25
src/core/__init__.py Normal file

@@ -0,0 +1,25 @@
"""
Core utilities and shared components.
"""
from .config import Config, BackendConfig, ConfigLoader, get_config, reload_config
from .checkpoint import (
    CheckpointManager,
    ShotCheckpoint,
    ProjectCheckpoint,
    ShotStatus,
    ProjectStatus,
)

__all__ = [
    'Config',
    'BackendConfig',
    'ConfigLoader',
    'get_config',
    'reload_config',
    'CheckpointManager',
    'ShotCheckpoint',
    'ProjectCheckpoint',
    'ShotStatus',
    'ProjectStatus',
]

435
src/core/checkpoint.py Normal file

@@ -0,0 +1,435 @@
"""
Checkpoint and resume system using SQLite.
Tracks generation progress and enables resuming from failures.
"""
import sqlite3
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum
class ShotStatus(str, Enum):
    """Status of a shot generation."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


class ProjectStatus(str, Enum):
    """Status of a project."""
    INITIALIZED = "initialized"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class ShotCheckpoint:
    """Checkpoint data for a single shot."""
    shot_id: str
    status: ShotStatus
    output_path: Optional[str] = None
    error_message: Optional[str] = None
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


@dataclass
class ProjectCheckpoint:
    """Checkpoint data for a project."""
    project_id: str
    storyboard_path: str
    output_dir: str
    status: ProjectStatus
    backend_name: str
    started_at: str
    completed_at: Optional[str] = None
    error_message: Optional[str] = None


class CheckpointManager:
    """Manages checkpoints using SQLite database."""

    def __init__(self, db_path: str = "outputs/checkpoints.db"):
        """
        Initialize checkpoint manager.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Initialize database schema."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS projects (
                    project_id TEXT PRIMARY KEY,
                    storyboard_path TEXT NOT NULL,
                    output_dir TEXT NOT NULL,
                    status TEXT NOT NULL,
                    backend_name TEXT NOT NULL,
                    started_at TEXT NOT NULL,
                    completed_at TEXT,
                    error_message TEXT
                )
            """)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS shots (
                    project_id TEXT NOT NULL,
                    shot_id TEXT NOT NULL,
                    status TEXT NOT NULL,
                    output_path TEXT,
                    error_message TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    metadata TEXT,
                    PRIMARY KEY (project_id, shot_id),
                    FOREIGN KEY (project_id) REFERENCES projects(project_id)
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def _get_connection(self):
        """Get a new database connection."""
        return sqlite3.connect(self.db_path)

    def create_project(
        self,
        project_id: str,
        storyboard_path: str,
        output_dir: str,
        backend_name: str
    ) -> ProjectCheckpoint:
        """
        Create a new project checkpoint.

        Args:
            project_id: Unique project identifier
            storyboard_path: Path to storyboard JSON file
            output_dir: Output directory for generated files
            backend_name: Name of the generation backend

        Returns:
            ProjectCheckpoint object
        """
        checkpoint = ProjectCheckpoint(
            project_id=project_id,
            storyboard_path=storyboard_path,
            output_dir=output_dir,
            status=ProjectStatus.INITIALIZED,
            backend_name=backend_name,
            started_at=datetime.now().isoformat()
        )
        conn = self._get_connection()
        try:
            conn.execute(
                """
                INSERT INTO projects
                (project_id, storyboard_path, output_dir, status, backend_name, started_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    checkpoint.project_id,
                    checkpoint.storyboard_path,
                    checkpoint.output_dir,
                    checkpoint.status.value,
                    checkpoint.backend_name,
                    checkpoint.started_at
                )
            )
            conn.commit()
        finally:
            conn.close()
        return checkpoint

    def get_project(self, project_id: str) -> Optional[ProjectCheckpoint]:
        """Get project checkpoint by ID."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM projects WHERE project_id = ?",
                (project_id,)
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return ProjectCheckpoint(
                project_id=row[0],
                storyboard_path=row[1],
                output_dir=row[2],
                status=ProjectStatus(row[3]),
                backend_name=row[4],
                started_at=row[5],
                completed_at=row[6],
                error_message=row[7]
            )
        finally:
            conn.close()

    def update_project_status(
        self,
        project_id: str,
        status: ProjectStatus,
        error_message: Optional[str] = None
    ):
        """Update project status."""
        completed_at = None
        if status in [ProjectStatus.COMPLETED, ProjectStatus.FAILED]:
            completed_at = datetime.now().isoformat()
        conn = self._get_connection()
        try:
            conn.execute(
                """
                UPDATE projects
                SET status = ?, error_message = ?, completed_at = ?
                WHERE project_id = ?
                """,
                (status.value, error_message, completed_at, project_id)
            )
            conn.commit()
        finally:
            conn.close()

    def create_shot(
        self,
        project_id: str,
        shot_id: str,
        status: ShotStatus = ShotStatus.PENDING
    ) -> ShotCheckpoint:
        """Create a shot checkpoint."""
        checkpoint = ShotCheckpoint(
            shot_id=shot_id,
            status=status
        )
        conn = self._get_connection()
        try:
            conn.execute(
                """
                INSERT INTO shots
                (project_id, shot_id, status, metadata)
                VALUES (?, ?, ?, ?)
                """,
                (
                    project_id,
                    shot_id,
                    status.value,
                    json.dumps(checkpoint.metadata)
                )
            )
            conn.commit()
        finally:
            conn.close()
        return checkpoint

    def get_shot(self, project_id: str, shot_id: str) -> Optional[ShotCheckpoint]:
        """Get shot checkpoint."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM shots WHERE project_id = ? AND shot_id = ?",
                (project_id, shot_id)
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return ShotCheckpoint(
                shot_id=row[1],
                status=ShotStatus(row[2]),
                output_path=row[3],
                error_message=row[4],
                started_at=row[5],
                completed_at=row[6],
                metadata=json.loads(row[7]) if row[7] else {}
            )
        finally:
            conn.close()

    def get_project_shots(self, project_id: str) -> List[ShotCheckpoint]:
        """Get all shots for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM shots WHERE project_id = ? ORDER BY shot_id",
                (project_id,)
            )
            rows = cursor.fetchall()
            return [
                ShotCheckpoint(
                    shot_id=row[1],
                    status=ShotStatus(row[2]),
                    output_path=row[3],
                    error_message=row[4],
                    started_at=row[5],
                    completed_at=row[6],
                    metadata=json.loads(row[7]) if row[7] else {}
                )
                for row in rows
            ]
        finally:
            conn.close()

    def update_shot(
        self,
        project_id: str,
        shot_id: str,
        status: Optional[ShotStatus] = None,
        output_path: Optional[str] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Update shot checkpoint."""
        updates = []
        params = []
        if status is not None:
            updates.append("status = ?")
            params.append(status.value)
            if status == ShotStatus.IN_PROGRESS:
                updates.append("started_at = ?")
                params.append(datetime.now().isoformat())
            elif status in [ShotStatus.COMPLETED, ShotStatus.FAILED]:
                updates.append("completed_at = ?")
                params.append(datetime.now().isoformat())
        if output_path is not None:
            updates.append("output_path = ?")
            params.append(output_path)
        if error_message is not None:
            updates.append("error_message = ?")
            params.append(error_message)
        if metadata is not None:
            updates.append("metadata = ?")
            params.append(json.dumps(metadata))
        if not updates:
            return
        params.extend([project_id, shot_id])
        conn = self._get_connection()
        try:
            conn.execute(
                f"""
                UPDATE shots
                SET {', '.join(updates)}
                WHERE project_id = ? AND shot_id = ?
                """,
                params
            )
            conn.commit()
        finally:
            conn.close()

    def get_pending_shots(self, project_id: str) -> List[str]:
        """Get list of pending shot IDs for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT shot_id FROM shots
                WHERE project_id = ? AND status = ?
                ORDER BY shot_id
                """,
                (project_id, ShotStatus.PENDING.value)
            )
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_failed_shots(self, project_id: str) -> List[str]:
        """Get list of failed shot IDs for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT shot_id FROM shots
                WHERE project_id = ? AND status = ?
                ORDER BY shot_id
                """,
                (project_id, ShotStatus.FAILED.value)
            )
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def can_resume(self, project_id: str) -> bool:
        """Check if a project can be resumed."""
        project = self.get_project(project_id)
        if project is None:
            return False
        if project.status == ProjectStatus.COMPLETED:
            return False
        pending = self.get_pending_shots(project_id)
        failed = self.get_failed_shots(project_id)
        return len(pending) > 0 or len(failed) > 0

    def delete_project(self, project_id: str):
        """Delete a project and all its shots."""
        conn = self._get_connection()
        try:
            conn.execute("DELETE FROM shots WHERE project_id = ?", (project_id,))
            conn.execute("DELETE FROM projects WHERE project_id = ?", (project_id,))
            conn.commit()
        finally:
            conn.close()

    def list_projects(self) -> List[ProjectCheckpoint]:
        """List all projects."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM projects ORDER BY started_at DESC"
            )
            rows = cursor.fetchall()
            return [
                ProjectCheckpoint(
                    project_id=row[0],
                    storyboard_path=row[1],
                    output_dir=row[2],
                    status=ProjectStatus(row[3]),
                    backend_name=row[4],
                    started_at=row[5],
                    completed_at=row[6],
                    error_message=row[7]
                )
                for row in rows
            ]
        finally:
            conn.close()
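The intended usage is a create -> update -> resume cycle. A condensed, self-contained sketch of the same resume query the class runs (using an in-memory SQLite table rather than `CheckpointManager` itself):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE shots (project_id TEXT, shot_id TEXT, status TEXT, "
    "PRIMARY KEY (project_id, shot_id))"
)
for sid in ("shot_001", "shot_002", "shot_003"):
    conn.execute("INSERT INTO shots VALUES (?, ?, 'pending')", ("demo", sid))

# shot_001 succeeds, shot_002 fails mid-run, shot_003 is never started
conn.execute("UPDATE shots SET status = 'completed' WHERE shot_id = 'shot_001'")
conn.execute("UPDATE shots SET status = 'failed' WHERE shot_id = 'shot_002'")

# A resume picks up everything still pending or failed, in shot order
resumable = [row[0] for row in conn.execute(
    "SELECT shot_id FROM shots WHERE status IN ('pending', 'failed') "
    "ORDER BY shot_id")]
```

This mirrors `can_resume()`: a project is resumable exactly when that list is non-empty and the project itself is not marked completed.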

212
src/core/config.py Normal file

@@ -0,0 +1,212 @@
"""
Configuration management system.
Handles YAML configs and environment variables.
"""
import os
import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from dataclasses import dataclass, field
@dataclass
class BackendConfig:
    """Configuration for a generation backend."""
    name: str
    model_class: str
    model_id: str
    vram_gb: int
    dtype: str
    enable_vae_slicing: bool = True
    enable_vae_tiling: bool = False
    chunking_enabled: bool = True
    chunking_mode: str = "sequential"
    max_chunk_seconds: int = 4
    overlap_seconds: int = 1


@dataclass
class Config:
    """Application configuration."""
    active_backend: str = "wan_t2v_14b"
    backends: Dict[str, BackendConfig] = field(default_factory=dict)
    defaults: Dict[str, Any] = field(default_factory=dict)
    checkpoint_db: str = "outputs/checkpoints.db"
    model_cache_dir: str = "~/.cache/storyboard-video/models"
    log_level: str = "INFO"
    output_base_dir: str = "./outputs"

    def get_backend(self, name: Optional[str] = None) -> BackendConfig:
        """Get backend configuration by name."""
        name = name or self.active_backend
        if name not in self.backends:
            raise ValueError(f"Backend '{name}' not found in configuration")
        return self.backends[name]


class ConfigLoader:
    """Loads configuration from YAML files and environment variables."""

    @staticmethod
    def load(config_path: Optional[Path] = None, env_file: Optional[Path] = None) -> Config:
        """
        Load configuration from files and environment.

        Priority (highest to lowest):
        1. Environment variables
        2. .env file
        3. YAML config file
        4. Defaults

        Args:
            config_path: Path to YAML config file (default: config/models.yaml)
            env_file: Path to .env file (default: .env if exists)

        Returns:
            Config object
        """
        # Start with defaults
        config = Config()
        # Load YAML config
        if config_path is None:
            config_path = Path("config/models.yaml")
        if config_path.exists():
            config = ConfigLoader._load_yaml(config_path, config)
        # Load .env file
        if env_file is None:
            env_file = Path(".env")
        if env_file.exists():
            config = ConfigLoader._load_env_file(env_file, config)
        # Override with environment variables
        config = ConfigLoader._load_env_vars(config)
        # Expand paths
        config.model_cache_dir = str(Path(config.model_cache_dir).expanduser())
        config.checkpoint_db = str(Path(config.checkpoint_db).expanduser())
        config.output_base_dir = str(Path(config.output_base_dir).expanduser())
        return config

    @staticmethod
    def _load_yaml(path: Path, config: Config) -> Config:
        """Load configuration from YAML file."""
        with open(path, 'r') as f:
            data = yaml.safe_load(f)
        if not data:
            return config
        # Load backends
        if 'backends' in data:
            for backend_id, backend_data in data['backends'].items():
                chunking = backend_data.get('chunking', {})
                config.backends[backend_id] = BackendConfig(
                    name=backend_data.get('name', backend_id),
                    model_class=backend_data.get('class', ''),
                    model_id=backend_data.get('model_id', ''),
                    vram_gb=backend_data.get('vram_gb', 12),
                    dtype=backend_data.get('dtype', 'fp16'),
                    enable_vae_slicing=backend_data.get('enable_vae_slicing', True),
                    enable_vae_tiling=backend_data.get('enable_vae_tiling', False),
                    chunking_enabled=chunking.get('enabled', True),
                    chunking_mode=chunking.get('mode', 'sequential'),
                    max_chunk_seconds=chunking.get('max_chunk_seconds', 4),
                    overlap_seconds=chunking.get('overlap_seconds', 1)
                )
        # Load active backend
        if 'active_backend' in data:
            config.active_backend = data['active_backend']
        # Load defaults
        if 'defaults' in data:
            config.defaults = data['defaults']
        # Load checkpoint DB path
        if 'checkpoint_db' in data:
            config.checkpoint_db = data['checkpoint_db']
        # Load model cache dir
        if 'model_cache_dir' in data:
            config.model_cache_dir = data['model_cache_dir']
        return config

    @staticmethod
    def _load_env_file(path: Path, config: Config) -> Config:
        """Load configuration from .env file."""
        env_vars = {}
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                if '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    value = value.strip().strip('"\'')
                    env_vars[key] = value
        # Apply env vars directly to config (don't set in os.environ)
        if 'ACTIVE_BACKEND' in env_vars:
            config.active_backend = env_vars['ACTIVE_BACKEND']
        if 'MODEL_CACHE_DIR' in env_vars:
            config.model_cache_dir = env_vars['MODEL_CACHE_DIR']
        if 'LOG_LEVEL' in env_vars:
            config.log_level = env_vars['LOG_LEVEL']
        if 'OUTPUT_BASE_DIR' in env_vars:
            config.output_base_dir = env_vars['OUTPUT_BASE_DIR']
        if 'CHECKPOINT_DB' in env_vars:
            config.checkpoint_db = env_vars['CHECKPOINT_DB']
        return config

    @staticmethod
    def _load_env_vars(config: Config) -> Config:
        """Override config with environment variables."""
        if 'ACTIVE_BACKEND' in os.environ:
            config.active_backend = os.environ['ACTIVE_BACKEND']
        if 'MODEL_CACHE_DIR' in os.environ:
            config.model_cache_dir = os.environ['MODEL_CACHE_DIR']
        if 'LOG_LEVEL' in os.environ:
            config.log_level = os.environ['LOG_LEVEL']
        if 'OUTPUT_BASE_DIR' in os.environ:
            config.output_base_dir = os.environ['OUTPUT_BASE_DIR']
        if 'CHECKPOINT_DB' in os.environ:
            config.checkpoint_db = os.environ['CHECKPOINT_DB']
        return config


# Global config instance (lazy-loaded)
_config: Optional[Config] = None


def get_config() -> Config:
    """Get the global configuration instance."""
    global _config
    if _config is None:
        _config = ConfigLoader.load()
    return _config


def reload_config() -> Config:
    """Reload configuration from files."""
    global _config
    _config = ConfigLoader.load()
    return _config
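The precedence chain (environment variable over `.env` over YAML over default) can be modeled as a tiny pure function, which is also how one would unit-test it without touching `os.environ`:

```python
def resolve(key, env, dotenv, yaml_cfg, default):
    """Toy model of ConfigLoader precedence: env > .env > YAML > default."""
    for source in (env, dotenv, yaml_cfg):
        if key in source:
            return source[key]
    return default


level = resolve(
    "LOG_LEVEL",
    {"LOG_LEVEL": "DEBUG"},    # os.environ wins...
    {"LOG_LEVEL": "WARNING"},  # ...over the .env file...
    {"LOG_LEVEL": "ERROR"},    # ...over models.yaml
    "INFO",                    # ...over the dataclass default
)
```

In `ConfigLoader.load` the same ordering is achieved by applying the sources lowest-priority first, so each later source simply overwrites the earlier one.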


@@ -0,0 +1,17 @@
"""
Video generation backends.
"""
from .base import (
    BaseVideoBackend,
    GenerationResult,
    GenerationSpec,
    BackendFactory,
)

__all__ = [
    'BaseVideoBackend',
    'GenerationResult',
    'GenerationSpec',
    'BackendFactory',
]

235
src/generation/base.py Normal file

@@ -0,0 +1,235 @@
"""
Base interface for video generation backends.
All backends (WAN, SVD, etc.) must implement this interface.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
@dataclass
class GenerationResult:
"""Result of a video generation."""
success: bool
output_path: Optional[Path] = None
metadata: Optional[Dict[str, Any]] = None
error_message: Optional[str] = None
vram_usage_gb: Optional[float] = None
generation_time_s: Optional[float] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
@dataclass
class GenerationSpec:
"""Specification for video generation."""
prompt: str
negative_prompt: str
width: int
height: int
num_frames: int
fps: int
seed: int
steps: int
cfg_scale: float
output_path: Path
init_frame_path: Optional[Path] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for metadata storage."""
return {
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"width": self.width,
"height": self.height,
"num_frames": self.num_frames,
"fps": self.fps,
"seed": self.seed,
"steps": self.steps,
"cfg_scale": self.cfg_scale,
"output_path": str(self.output_path),
"init_frame_path": str(self.init_frame_path) if self.init_frame_path else None
}
class BaseVideoBackend(ABC):
"""
Abstract base class for video generation backends.
All backends must implement this interface to be used in the pipeline.
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize the backend with configuration.
Args:
config: Backend-specific configuration dictionary
"""
self.config = config
self.model = None
self._is_loaded = False
@property
@abstractmethod
def name(self) -> str:
"""Return the backend name."""
pass
@property
@abstractmethod
def supports_chunking(self) -> bool:
"""Whether this backend supports chunked generation."""
pass
@property
@abstractmethod
def supports_init_frame(self) -> bool:
"""Whether this backend supports init frame conditioning."""
pass
@abstractmethod
def load(self) -> None:
"""
Load the model into memory.
This should be called before generate().
Implementations should set self._is_loaded = True when complete.
"""
pass
@abstractmethod
def unload(self) -> None:
"""
Unload the model from memory.
Implementations should set self._is_loaded = False when complete.
"""
pass
@abstractmethod
def generate(self, spec: GenerationSpec) -> GenerationResult:
"""
Generate a video clip.
Args:
spec: Generation specification
Returns:
GenerationResult with output path and metadata
"""
pass
def generate_chunked(
self,
spec: GenerationSpec,
chunk_duration_s: int,
overlap_s: int = 0,
mode: str = "sequential"
) -> GenerationResult:
"""
Generate a video in chunks.
Default implementation for backends that don't natively support chunking.
Backends with native support should override this.
Args:
spec: Generation specification
chunk_duration_s: Duration of each chunk in seconds
overlap_s: Overlap between chunks in seconds (for blending)
mode: "sequential" or "overlapping"
Returns:
GenerationResult with concatenated video
"""
if not self.supports_chunking:
# Fall back to single generation
return self.generate(spec)
raise NotImplementedError(
"Chunked generation not implemented for this backend"
)
def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
"""
Estimate VRAM usage for a generation.
Args:
width: Video width
height: Video height
num_frames: Number of frames
Returns:
Estimated VRAM usage in GB
"""
# Default implementation - backends should override with better estimates
return self.config.get("vram_gb", 12.0)
def check_vram_available(self, width: int, height: int, num_frames: int) -> bool:
"""
Check if sufficient VRAM is available.
Args:
width: Video width
height: Video height
num_frames: Number of frames
Returns:
True if generation should fit in VRAM
"""
try:
import torch
if torch.cuda.is_available():
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
estimated = self.estimate_vram_usage(width, height, num_frames)
# Leave 10% headroom
return estimated < total_vram * 0.9
except:
pass
# If we can't check, assume it's OK
return True
@property
def is_loaded(self) -> bool:
"""Check if the model is loaded."""
return self._is_loaded
def __enter__(self):
"""Context manager entry."""
self.load()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.unload()
class BackendFactory:
"""Factory for creating backend instances."""
_backends: Dict[str, type] = {}
@classmethod
def register(cls, name: str, backend_class: type):
"""Register a backend class."""
if not issubclass(backend_class, BaseVideoBackend):
raise ValueError(f"Backend {name!r} must inherit from BaseVideoBackend")
cls._backends[name] = backend_class
@classmethod
def create(cls, name: str, config: Dict[str, Any]) -> BaseVideoBackend:
"""Create a backend instance."""
if name not in cls._backends:
raise ValueError(f"Unknown backend: {name}. "
f"Available: {list(cls._backends.keys())}")
return cls._backends[name](config)
@classmethod
def list_backends(cls) -> List[str]:
"""List available backend names."""
return list(cls._backends.keys())
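A standalone sketch of the registry pattern behind `BackendFactory` (a plain module-level dict stands in for the class attribute, and `DummyBackend` is a hypothetical stand-in for a real `BaseVideoBackend` subclass):

```python
# Illustrative re-creation of the register/create flow; names here are
# stand-ins and not part of the codebase.
_backends: dict = {}

def register(name: str, backend_class: type) -> None:
    """Map a backend name to its class, mirroring BackendFactory.register."""
    _backends[name] = backend_class

def create(name: str, config: dict):
    """Instantiate a registered backend, mirroring BackendFactory.create."""
    if name not in _backends:
        raise ValueError(f"Unknown backend: {name}. Available: {list(_backends)}")
    return _backends[name](config)

class DummyBackend:
    """Hypothetical backend that only stores its config."""
    def __init__(self, config: dict) -> None:
        self.config = config

register("dummy", DummyBackend)
backend = create("dummy", {"vram_gb": 8.0})
```

Concrete backends would call `register` at import time so that `create` can resolve the `ACTIVE_BACKEND` name from configuration.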


@@ -0,0 +1,39 @@
"""
Storyboard module for validation, loading, and prompt compilation.
"""
from .schema import (
Storyboard,
ProjectSettings,
Shot,
Character,
Location,
GlobalStyle,
CameraSettings,
GenerationSettings,
OutputSettings,
Resolution
)
from .loader import StoryboardValidator, StoryboardLoadError
from .prompt_compiler import PromptCompiler, CompiledPrompt
from .shot_planner import ShotPlanner, ShotPlan, Chunk
__all__ = [
'Storyboard',
'ProjectSettings',
'Shot',
'Character',
'Location',
'GlobalStyle',
'CameraSettings',
'GenerationSettings',
'OutputSettings',
'Resolution',
'StoryboardValidator',
'StoryboardLoadError',
'PromptCompiler',
'CompiledPrompt',
'ShotPlanner',
'ShotPlan',
'Chunk'
]

84
src/storyboard/loader.py Normal file

@@ -0,0 +1,84 @@
"""
Storyboard loader and validator.
Handles loading JSON files and validating against schema.
"""
import json
from pathlib import Path
from typing import Union
from .schema import Storyboard
class StoryboardLoadError(Exception):
"""Raised when storyboard loading fails."""
pass
class StoryboardValidator:
"""Validates and loads storyboard JSON files."""
@staticmethod
def load(path: Union[str, Path]) -> Storyboard:
"""
Load and validate a storyboard JSON file.
Args:
path: Path to the JSON file
Returns:
Validated Storyboard object
Raises:
StoryboardLoadError: If file doesn't exist or is invalid
"""
path = Path(path)
if not path.exists():
raise StoryboardLoadError(f"Storyboard file not found: {path}")
if path.suffix.lower() != '.json':
raise StoryboardLoadError(f"Storyboard file must be JSON: {path}")
try:
with open(path, 'r', encoding='utf-8') as f:
data = json.load(f)
except json.JSONDecodeError as e:
raise StoryboardLoadError(f"Invalid JSON in {path}: {e}") from e
except Exception as e:
raise StoryboardLoadError(f"Failed to read {path}: {e}") from e
try:
storyboard = Storyboard(**data)
except Exception as e:
raise StoryboardLoadError(f"Storyboard validation failed: {e}") from e
return storyboard
@staticmethod
def validate_references(storyboard: Storyboard) -> list[str]:
"""
Validate that all references (location_id, characters) exist.
Args:
storyboard: Loaded storyboard
Returns:
List of validation warnings/errors
"""
issues = []
# Validate character references
char_ids = {char.id for char in storyboard.characters}
for shot in storyboard.shots:
for char_id in shot.characters:
if char_id not in char_ids:
issues.append(f"Shot {shot.id}: Character '{char_id}' not defined")
# Validate location references
loc_ids = {loc.id for loc in storyboard.locations}
for shot in storyboard.shots:
if shot.location_id and shot.location_id not in loc_ids:
issues.append(f"Shot {shot.id}: Location '{shot.location_id}' not defined")
return issues
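The reference check in `validate_references` is a set-membership scan; a standalone sketch with plain dicts in place of the pydantic models (IDs here are illustrative):

```python
# Defined character IDs vs. the IDs each shot references.
char_ids = {"C01"}
shots = [{"id": "S01", "characters": ["C01", "C99"]}]

# Collect one issue string per dangling reference, same wording as the loader.
issues = [
    f"Shot {shot['id']}: Character '{cid}' not defined"
    for shot in shots
    for cid in shot["characters"]
    if cid not in char_ids
]
```

Location references are checked the same way against the set of defined location IDs.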


@@ -0,0 +1,154 @@
"""
Prompt compiler for merging global style with shot-specific prompts.
"""
from typing import List, Optional
from dataclasses import dataclass
from ..storyboard.schema import Storyboard, Shot, GlobalStyle
@dataclass
class CompiledPrompt:
"""Compiled prompt ready for generation."""
positive: str
negative: str
camera_notes: str
full_prompt: str # Combined positive + style + camera
def to_dict(self) -> dict:
"""Convert to dictionary for metadata storage."""
return {
"positive": self.positive,
"negative": self.negative,
"camera_notes": self.camera_notes,
"full_prompt": self.full_prompt
}
class PromptCompiler:
"""
Compiles prompts by merging global style with shot-specific content.
"""
def __init__(self, storyboard: Storyboard):
"""
Initialize with storyboard.
Args:
storyboard: Loaded and validated storyboard
"""
self.storyboard = storyboard
self.global_style = storyboard.project.global_style
def compile_shot(self, shot: Shot) -> CompiledPrompt:
"""
Compile prompt for a specific shot.
Args:
shot: Shot to compile prompt for
Returns:
CompiledPrompt with merged content
"""
# Build character descriptions
char_descriptions = []
for char_id in shot.characters:
char = self.storyboard.get_character(char_id)
if char:
char_descriptions.append(f"{char.name}: {char.description}")
# Build location description
location_desc = ""
if shot.location_id:
loc = self.storyboard.get_location(shot.location_id)
if loc:
location_desc = loc.description
# Compile positive prompt
positive_parts = []
# Add global visual style first (sets the aesthetic)
if self.global_style.visual_style:
positive_parts.append(self.global_style.visual_style)
# Add location context
if location_desc:
positive_parts.append(f"Setting: {location_desc}")
# Add character descriptions
if char_descriptions:
positive_parts.append("Characters: " + "; ".join(char_descriptions))
# Add the main shot prompt
positive_parts.append(shot.prompt)
# Add camera/lens info
camera_parts = []
if shot.camera.framing:
camera_parts.append(shot.camera.framing)
if self.global_style.lens:
camera_parts.append(self.global_style.lens)
if shot.camera.movement:
camera_parts.append(shot.camera.movement)
if self.global_style.motion_style:
camera_parts.append(self.global_style.motion_style)
if shot.camera.notes:
camera_parts.append(shot.camera.notes)
if camera_parts:
positive_parts.append(f"Camera: {', '.join(camera_parts)}")
# Add lighting
if self.global_style.lighting:
positive_parts.append(f"Lighting: {self.global_style.lighting}")
# Add color grade
if self.global_style.color_grade:
positive_parts.append(f"Color: {self.global_style.color_grade}")
positive = ". ".join(positive_parts)
# Compile negative prompt
negative_parts = []
if self.global_style.negative_prompt:
negative_parts.append(self.global_style.negative_prompt)
negative = ", ".join(negative_parts) if negative_parts else ""
# Camera notes for metadata
camera_notes = "; ".join(camera_parts) if camera_parts else ""
# Full combined prompt (for reference)
full = positive
return CompiledPrompt(
positive=positive,
negative=negative,
camera_notes=camera_notes,
full_prompt=full
)
def compile_all(self) -> List[CompiledPrompt]:
"""
Compile prompts for all shots.
Returns:
List of CompiledPrompt objects in shot order
"""
return [self.compile_shot(shot) for shot in self.storyboard.shots]
def get_shot_prompt(self, shot_id: str) -> Optional[CompiledPrompt]:
"""
Get compiled prompt for a specific shot by ID.
Args:
shot_id: Shot identifier
Returns:
CompiledPrompt or None if shot not found
"""
for shot in self.storyboard.shots:
if shot.id == shot_id:
return self.compile_shot(shot)
return None
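A standalone sketch of the assembly order used by `compile_shot`: global style first, then setting, characters, the shot's own prompt, and finally camera/lighting notes, joined with `". "` (all values below are illustrative, not from a real storyboard):

```python
# Each entry mirrors one section appended by compile_shot.
positive_parts = [
    "cinematic, 35mm film grain",                     # global visual_style
    "Setting: a rain-soaked neon street",             # location description
    "Characters: Hero: a weary detective in a coat",  # character descriptions
    "The detective walks toward the camera",          # shot prompt
    "Camera: wide shot, slow dolly-in",               # camera parts
    "Lighting: low-key with neon practicals",         # global lighting
]
positive = ". ".join(positive_parts)
```

Putting the global style first lets it set the aesthetic before shot-specific content, which is why the compiler appends in this order.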

159
src/storyboard/schema.py Normal file

@@ -0,0 +1,159 @@
"""
Storyboard schema validation using Pydantic.
Defines the data models for storyboard JSON files.
"""
from typing import List, Optional, Literal
from pydantic import BaseModel, Field, validator
class Resolution(BaseModel):
"""Video resolution dimensions."""
width: int = Field(..., ge=256, le=4096, description="Width in pixels")
height: int = Field(..., ge=256, le=4096, description="Height in pixels")
class GlobalStyle(BaseModel):
"""Global visual style settings applied to all shots."""
visual_style: str = Field(default="", description="Overall visual aesthetic")
color_grade: str = Field(default="", description="Color grading description")
lens: str = Field(default="", description="Lens type/characteristics")
lighting: str = Field(default="", description="Lighting setup description")
motion_style: str = Field(default="", description="Camera motion characteristics")
negative_prompt: str = Field(default="", description="What to avoid in generation")
class AudioSettings(BaseModel):
"""Audio configuration for the final video."""
add_music: bool = Field(default=False)
music_path: str = Field(default="")
add_voiceover: bool = Field(default=False)
voiceover_path: str = Field(default="")
class ProjectSettings(BaseModel):
"""Top-level project configuration."""
title: str = Field(..., min_length=1, description="Project title")
fps: int = Field(default=24, ge=1, le=120, description="Frames per second")
target_duration_s: int = Field(default=20, ge=1, le=300, description="Target duration in seconds")
resolution: Resolution
aspect_ratio: str = Field(default="16:9", description="Aspect ratio (e.g., 16:9, 4:3)")
global_style: GlobalStyle = Field(default_factory=GlobalStyle)
audio: AudioSettings = Field(default_factory=AudioSettings)
class Character(BaseModel):
"""Character definition for consistency across shots."""
id: str = Field(..., min_length=1, description="Unique character identifier")
name: str = Field(..., min_length=1, description="Character name")
description: str = Field(default="", description="Visual description")
consistency_notes: str = Field(default="", description="Notes for maintaining consistency")
class Location(BaseModel):
"""Location/setting definition."""
id: str = Field(..., min_length=1, description="Unique location identifier")
name: str = Field(..., min_length=1, description="Location name")
description: str = Field(default="", description="Visual description of the setting")
class CameraSettings(BaseModel):
"""Camera configuration for a shot."""
framing: str = Field(default="", description="Shot framing (wide, medium, close-up, etc.)")
movement: str = Field(default="", description="Camera movement description")
notes: str = Field(default="", description="Additional camera notes")
class GenerationSettings(BaseModel):
"""AI generation parameters for a shot."""
seed: int = Field(default=-1, description="Random seed (-1 for random)")
steps: int = Field(default=30, ge=1, le=100, description="Inference steps")
cfg_scale: float = Field(default=6.0, ge=1.0, le=20.0, description="Classifier-free guidance scale")
sampler: str = Field(default="default", description="Sampler type")
chunk_seconds: int = Field(default=4, ge=1, le=10, description="Duration per chunk in seconds")
use_init_frame_from_prev: bool = Field(default=False, description="Use last frame of previous shot as init")
class Shot(BaseModel):
"""Individual shot definition."""
id: str = Field(..., min_length=1, description="Unique shot identifier")
duration_s: int = Field(..., ge=1, le=60, description="Shot duration in seconds")
location_id: Optional[str] = Field(default=None, description="Reference to location")
characters: List[str] = Field(default_factory=list, description="List of character IDs")
prompt: str = Field(..., min_length=1, description="Generation prompt for this shot")
camera: CameraSettings = Field(default_factory=CameraSettings)
generation: GenerationSettings = Field(default_factory=GenerationSettings)
class UpscaleSettings(BaseModel):
"""Post-processing upscaling configuration."""
enabled: bool = Field(default=False)
target_height: int = Field(default=2160, ge=720, le=4320, description="Target height in pixels")
method: str = Field(default="default", description="Upscaling method")
class OutputSettings(BaseModel):
"""Final video output configuration."""
container: Literal["mp4", "mov", "mkv"] = Field(default="mp4")
codec: Literal["h264", "h265", "vp9"] = Field(default="h264")
crf: int = Field(default=18, ge=0, le=51, description="Constant rate factor (lower = higher quality)")
preset: str = Field(default="medium", description="Encoding preset")
upscale: UpscaleSettings = Field(default_factory=UpscaleSettings)
class Storyboard(BaseModel):
"""Complete storyboard document."""
schema_version: str = Field(default="1.0", description="Schema version for compatibility")
project: ProjectSettings
characters: List[Character] = Field(default_factory=list)
locations: List[Location] = Field(default_factory=list)
shots: List[Shot] = Field(..., description="List of shots to generate")
output: OutputSettings = Field(default_factory=OutputSettings)
@validator('shots')
def validate_shot_ids_unique(cls, v):
"""Ensure all shot IDs are unique and at least one shot exists."""
if len(v) < 1:
raise ValueError("At least one shot is required")
ids = [shot.id for shot in v]
if len(ids) != len(set(ids)):
raise ValueError("Shot IDs must be unique")
return v
@validator('characters')
def validate_character_ids_unique(cls, v):
"""Ensure all character IDs are unique."""
ids = [char.id for char in v]
if len(ids) != len(set(ids)):
raise ValueError("Character IDs must be unique")
return v
@validator('locations')
def validate_location_ids_unique(cls, v):
"""Ensure all location IDs are unique."""
ids = [loc.id for loc in v]
if len(ids) != len(set(ids)):
raise ValueError("Location IDs must be unique")
return v
def get_character(self, char_id: str) -> Optional[Character]:
"""Get character by ID."""
for char in self.characters:
if char.id == char_id:
return char
return None
def get_location(self, loc_id: str) -> Optional[Location]:
"""Get location by ID."""
for loc in self.locations:
if loc.id == loc_id:
return loc
return None
def get_total_duration(self) -> int:
"""Calculate total duration in seconds."""
return sum(shot.duration_s for shot in self.shots)
def get_total_frames(self) -> int:
"""Calculate total frame count."""
return self.get_total_duration() * self.project.fps
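The duration and frame-count helpers reduce to simple arithmetic; a standalone sketch with two illustrative shots at the default 24 fps:

```python
# Mirrors get_total_duration / get_total_frames on plain values.
fps = 24
shot_durations_s = [4, 6]                  # two shots totalling 10 s
total_duration_s = sum(shot_durations_s)
total_frames = total_duration_s * fps      # frames to generate overall
```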


@@ -0,0 +1,252 @@
"""
Shot planner for converting duration to frames and chunking strategy.
"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from pathlib import Path
from ..storyboard.schema import Storyboard, Shot
@dataclass
class Chunk:
"""A single chunk of a shot."""
chunk_index: int
start_frame: int
end_frame: int
num_frames: int
duration_s: float
use_init_frame: bool
init_frame_source: Optional[Path] # Path to frame file or previous chunk
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for metadata."""
return {
"chunk_index": self.chunk_index,
"start_frame": self.start_frame,
"end_frame": self.end_frame,
"num_frames": self.num_frames,
"duration_s": self.duration_s,
"use_init_frame": self.use_init_frame,
"init_frame_source": str(self.init_frame_source) if self.init_frame_source else None
}
@dataclass
class ShotPlan:
"""Complete plan for generating a shot."""
shot_id: str
total_frames: int
total_duration_s: int
fps: int
chunks: List[Chunk]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for metadata."""
return {
"shot_id": self.shot_id,
"total_frames": self.total_frames,
"total_duration_s": self.total_duration_s,
"fps": self.fps,
"chunks": [c.to_dict() for c in self.chunks]
}
class ShotPlanner:
"""
Plans shot generation including frame counts and chunking strategy.
"""
def __init__(
self,
fps: int = 24,
max_chunk_seconds: int = 4,
overlap_seconds: int = 0,
chunking_mode: str = "sequential"
):
"""
Initialize shot planner.
Args:
fps: Frames per second
max_chunk_seconds: Maximum duration per chunk
overlap_seconds: Overlap between chunks (for blending mode)
chunking_mode: "sequential" or "overlapping"
"""
self.fps = fps
self.max_chunk_seconds = max_chunk_seconds
self.overlap_seconds = overlap_seconds
self.chunking_mode = chunking_mode
def duration_to_frames(self, duration_s: int) -> int:
"""
Convert duration in seconds to frame count.
Args:
duration_s: Duration in seconds
Returns:
Number of frames
"""
return duration_s * self.fps
def frames_to_duration(self, num_frames: int) -> float:
"""
Convert frame count to duration in seconds.
Args:
num_frames: Number of frames
Returns:
Duration in seconds
"""
return num_frames / self.fps
def plan_shot(
self,
shot: Shot,
fps: Optional[int] = None,
use_init_from_prev: bool = False
) -> ShotPlan:
"""
Create a generation plan for a shot.
Args:
shot: Shot to plan
fps: Override FPS (uses storyboard default if None)
use_init_from_prev: Whether to use previous shot's last frame
Returns:
ShotPlan with chunk breakdown
"""
fps = fps or self.fps
total_frames = shot.duration_s * fps
# Determine chunk size
chunk_frames = self.max_chunk_seconds * fps
overlap_frames = self.overlap_seconds * fps if self.chunking_mode == "overlapping" else 0
chunks = []
if total_frames <= chunk_frames:
# Single chunk - no need to split
chunk = Chunk(
chunk_index=0,
start_frame=0,
end_frame=total_frames - 1,
num_frames=total_frames,
duration_s=shot.duration_s,
use_init_frame=use_init_from_prev and shot.generation.use_init_frame_from_prev,
init_frame_source=None # Will be set by pipeline
)
chunks.append(chunk)
else:
# Multiple chunks needed
remaining_frames = total_frames
current_frame = 0
chunk_index = 0
while remaining_frames > 0:
# Calculate frames for this chunk
if self.chunking_mode == "sequential":
# Sequential: Each chunk starts where previous ended
this_chunk_frames = min(chunk_frames, remaining_frames)
chunk = Chunk(
chunk_index=chunk_index,
start_frame=current_frame,
end_frame=current_frame + this_chunk_frames - 1,
num_frames=this_chunk_frames,
duration_s=this_chunk_frames / fps,
use_init_frame=chunk_index > 0 or (use_init_from_prev and shot.generation.use_init_frame_from_prev),
init_frame_source=None
)
chunks.append(chunk)
current_frame += this_chunk_frames
remaining_frames -= this_chunk_frames
else:  # overlapping
# Overlapping: chunks overlap for blending
if remaining_frames <= chunk_frames:
# Last chunk - take all remaining frames
this_chunk_frames = remaining_frames
else:
# Not the last chunk - extend by the overlap for blending
this_chunk_frames = min(chunk_frames + overlap_frames, remaining_frames)
chunk = Chunk(
chunk_index=chunk_index,
start_frame=current_frame,
end_frame=current_frame + this_chunk_frames - 1,
num_frames=this_chunk_frames,
duration_s=this_chunk_frames / fps,
use_init_frame=chunk_index > 0 or (use_init_from_prev and shot.generation.use_init_frame_from_prev),
init_frame_source=None
)
chunks.append(chunk)
if this_chunk_frames == remaining_frames:
# Final chunk reached the end of the shot; stop (advancing by the
# stride here would leave a positive remainder and emit a spurious chunk)
remaining_frames = 0
else:
# Advance by the stride (chunk size minus overlap)
current_frame += chunk_frames - overlap_frames
remaining_frames -= chunk_frames - overlap_frames
chunk_index += 1
return ShotPlan(
shot_id=shot.id,
total_frames=total_frames,
total_duration_s=shot.duration_s,
fps=fps,
chunks=chunks
)
def plan_storyboard(
self,
storyboard: Storyboard,
override_chunk_config: Optional[Dict[str, Any]] = None
) -> List[ShotPlan]:
"""
Create generation plans for all shots in a storyboard.
Args:
storyboard: Storyboard to plan
override_chunk_config: Optional config to override defaults
Returns:
List of ShotPlan objects
"""
# Apply overrides if provided
cfg = override_chunk_config or {}
fps = cfg.get("fps", storyboard.project.fps)
max_chunk = cfg.get("max_chunk_seconds", self.max_chunk_seconds)
overlap = cfg.get("overlap_seconds", self.overlap_seconds)
mode = cfg.get("mode", self.chunking_mode)
planner = ShotPlanner(
fps=fps,
max_chunk_seconds=max_chunk,
overlap_seconds=overlap,
chunking_mode=mode
)
plans = []
for i, shot in enumerate(storyboard.shots):
# Only shots after the first can seed from the previous shot's last frame
use_init = i > 0 and shot.generation.use_init_frame_from_prev
plan = planner.plan_shot(shot, fps=fps, use_init_from_prev=use_init)
plans.append(plan)
return plans
def get_total_frames(self, plans: List[ShotPlan]) -> int:
"""Calculate total frames across all plans."""
return sum(plan.total_frames for plan in plans)
def get_total_duration(self, plans: List[ShotPlan]) -> float:
"""Calculate total duration in seconds across all plans."""
return sum(plan.total_duration_s for plan in plans)
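The sequential chunking loop in `plan_shot` splits a shot's frame count into chunks of at most `max_chunk_seconds` worth of frames; a standalone sketch of that arithmetic (function name and numbers are illustrative):

```python
def sequential_chunk_sizes(total_frames: int, chunk_frames: int) -> list:
    """Return per-chunk frame counts, greedily filling each chunk."""
    sizes = []
    remaining = total_frames
    while remaining > 0:
        take = min(chunk_frames, remaining)
        sizes.append(take)
        remaining -= take
    return sizes

# A 10 s shot at 24 fps with 4 s chunks: 240 frames split as 96 + 96 + 48.
sizes = sequential_chunk_sizes(240, 96)
```

Only the final chunk can be short; every chunk after the first seeds from the previous chunk's last frame to keep motion continuous.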


@@ -0,0 +1,3 @@
"""
Upscaling module.
"""


@@ -0,0 +1,189 @@
"""
Tests for checkpoint system.
"""
import pytest
import tempfile
import os
from pathlib import Path
from src.core.checkpoint import (
CheckpointManager, ShotCheckpoint, ProjectCheckpoint,
ShotStatus, ProjectStatus
)
class TestCheckpointManager:
"""Test checkpoint management."""
@pytest.fixture
def temp_db(self):
"""Create a temporary database for testing."""
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
db_path = f.name
yield db_path
# Cleanup
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.fixture
def manager(self, temp_db):
"""Create a checkpoint manager with temp database."""
return CheckpointManager(db_path=temp_db)
def test_create_project(self, manager):
"""Test creating a project checkpoint."""
checkpoint = manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
assert checkpoint.project_id == "test_001"
assert checkpoint.status == ProjectStatus.INITIALIZED
assert checkpoint.backend_name == "wan_t2v_14b"
def test_get_project(self, manager):
"""Test retrieving a project checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
retrieved = manager.get_project("test_001")
assert retrieved is not None
assert retrieved.project_id == "test_001"
missing = manager.get_project("nonexistent")
assert missing is None
def test_update_project_status(self, manager):
"""Test updating project status."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/to/storyboard.json",
output_dir="/path/to/output",
backend_name="wan_t2v_14b"
)
manager.update_project_status("test_001", ProjectStatus.RUNNING)
retrieved = manager.get_project("test_001")
assert retrieved.status == ProjectStatus.RUNNING
def test_create_shot(self, manager):
"""Test creating a shot checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)
assert shot.shot_id == "S01"
assert shot.status == ShotStatus.PENDING
def test_update_shot(self, manager):
"""Test updating shot checkpoint."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01")
manager.update_shot(
"test_001",
"S01",
status=ShotStatus.IN_PROGRESS,
output_path="/path/output.mp4",
metadata={"seed": 12345}
)
shot = manager.get_shot("test_001", "S01")
assert shot.status == ShotStatus.IN_PROGRESS
assert shot.output_path == "/path/output.mp4"
assert shot.metadata["seed"] == 12345
def test_get_pending_shots(self, manager):
"""Test retrieving pending shots."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
manager.create_shot("test_001", "S03", ShotStatus.PENDING)
pending = manager.get_pending_shots("test_001")
assert len(pending) == 2
assert "S01" in pending
assert "S03" in pending
def test_can_resume(self, manager):
"""Test checking if project can be resumed."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
# No shots yet - can't resume
assert not manager.can_resume("test_001")
# Add pending shot - can resume
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
assert manager.can_resume("test_001")
# Complete all shots - can't resume
manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
assert not manager.can_resume("test_001")
def test_delete_project(self, manager):
"""Test deleting a project."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/storyboard.json",
output_dir="/path/output",
backend_name="wan"
)
manager.create_shot("test_001", "S01")
manager.delete_project("test_001")
assert manager.get_project("test_001") is None
assert manager.get_shot("test_001", "S01") is None
def test_list_projects(self, manager):
"""Test listing all projects."""
manager.create_project(
project_id="test_001",
storyboard_path="/path/1.json",
output_dir="/out/1",
backend_name="wan"
)
manager.create_project(
project_id="test_002",
storyboard_path="/path/2.json",
output_dir="/out/2",
backend_name="svd"
)
projects = manager.list_projects()
assert len(projects) == 2
project_ids = [p.project_id for p in projects]
assert "test_001" in project_ids
assert "test_002" in project_ids

163
tests/unit/test_config.py Normal file

@@ -0,0 +1,163 @@
"""
Tests for configuration system.
"""
import pytest
import tempfile
import os
from pathlib import Path
from src.core.config import Config, BackendConfig, ConfigLoader, get_config, reload_config
class TestConfigLoader:
"""Test configuration loading."""
def test_load_default_config(self):
"""Test loading with no config files (uses defaults)."""
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path("nonexistent.env")
)
assert isinstance(config, Config)
assert config.active_backend == "wan_t2v_14b"
assert len(config.backends) == 0 # No YAML loaded
def test_load_yaml_config(self):
"""Test loading from YAML file."""
yaml_content = """
backends:
test_backend:
name: "Test Backend"
class: "test.TestBackend"
model_id: "test/model"
vram_gb: 8
dtype: "fp16"
enable_vae_slicing: true
enable_vae_tiling: false
chunking:
enabled: true
mode: "sequential"
max_chunk_seconds: 4
overlap_seconds: 1
active_backend: "test_backend"
defaults:
fps: 30
checkpoint_db: "test.db"
model_cache_dir: "~/test_cache"
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write(yaml_content)
yaml_path = f.name
try:
config = ConfigLoader.load(config_path=Path(yaml_path))
assert config.active_backend == "test_backend"
assert "test_backend" in config.backends
backend = config.get_backend("test_backend")
assert backend.name == "Test Backend"
assert backend.vram_gb == 8
assert backend.chunking_mode == "sequential"
assert config.defaults["fps"] == 30
finally:
os.unlink(yaml_path)
def test_load_env_file(self, monkeypatch):
"""Test loading from .env file."""
# Clear any existing env vars first
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
monkeypatch.delenv(key, raising=False)
env_content = """
ACTIVE_BACKEND=env_backend
MODEL_CACHE_DIR=/env/cache
LOG_LEVEL=DEBUG
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.env', delete=False) as f:
f.write(env_content)
env_path = f.name
try:
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path(env_path)
)
assert config.active_backend == "env_backend"
assert "/env/cache" in config.model_cache_dir or "\\env\\cache" in config.model_cache_dir
assert config.log_level == "DEBUG"
finally:
os.unlink(env_path)
# Clean up env vars
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
monkeypatch.delenv(key, raising=False)
def test_env_variable_override(self, monkeypatch):
"""Test that environment variables override config files."""
# Clear env var first
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
yaml_content = """
active_backend: yaml_backend
model_cache_dir: ~/yaml_cache
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write(yaml_content)
yaml_path = f.name
try:
monkeypatch.setenv("ACTIVE_BACKEND", "env_override")
config = ConfigLoader.load(config_path=Path(yaml_path))
# Environment variable should override
assert config.active_backend == "env_override"
# But other settings from YAML should remain
assert "yaml_cache" in config.model_cache_dir
finally:
os.unlink(yaml_path)
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
def test_path_expansion(self):
"""Test that paths are properly expanded."""
config = ConfigLoader.load(
config_path=Path("nonexistent.yaml"),
env_file=Path("nonexistent.env")
)
# Default paths should be expanded
assert not config.model_cache_dir.startswith("~")
assert not config.checkpoint_db.startswith("~")
def test_get_backend_not_found(self):
"""Test getting a non-existent backend."""
config = Config()
with pytest.raises(ValueError, match="not found"):
config.get_backend("nonexistent")
class TestBackendConfig:
"""Test backend configuration."""
def test_backend_config_defaults(self):
"""Test backend config with defaults."""
config = BackendConfig(
name="Test",
model_class="test.Test",
model_id="test/model",
vram_gb=12,
dtype="fp16"
)
assert config.enable_vae_slicing is True
assert config.enable_vae_tiling is False
assert config.chunking_enabled is True
assert config.chunking_mode == "sequential"


@@ -0,0 +1,268 @@
"""
Tests for storyboard schema validation.
"""
import pytest
from pathlib import Path
import json
import tempfile
import os
from src.storyboard.schema import (
Storyboard, ProjectSettings, Shot, Character, Location,
GlobalStyle, CameraSettings, GenerationSettings, OutputSettings,
Resolution
)
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError
class TestStoryboardSchema:
"""Test storyboard schema validation."""
def test_create_minimal_storyboard(self):
"""Test creating a minimal valid storyboard."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test Video",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(
id="S01",
duration_s=5,
prompt="A test shot"
)
]
)
assert storyboard.project.title == "Test Video"
assert len(storyboard.shots) == 1
assert storyboard.shots[0].id == "S01"
def test_shot_validation_unique_ids(self):
"""Test that duplicate shot IDs are rejected."""
with pytest.raises(ValueError, match="Shot IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=5, prompt="Shot 1"),
Shot(id="S01", duration_s=5, prompt="Shot 2")
]
)
def test_character_validation_unique_ids(self):
"""Test that duplicate character IDs are rejected."""
with pytest.raises(ValueError, match="Character IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Character 1"),
Character(id="C01", name="Character 2")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
def test_location_validation_unique_ids(self):
"""Test that duplicate location IDs are rejected."""
with pytest.raises(ValueError, match="Location IDs must be unique"):
Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
locations=[
Location(id="L01", name="Location 1"),
Location(id="L01", name="Location 2")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
def test_get_character_by_id(self):
"""Test retrieving character by ID."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Hero", description="The protagonist")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
char = storyboard.get_character("C01")
assert char is not None
assert char.name == "Hero"
missing = storyboard.get_character("C99")
assert missing is None
def test_get_location_by_id(self):
"""Test retrieving location by ID."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
locations=[
Location(id="L01", name="Street", description="A city street")
],
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
)
loc = storyboard.get_location("L01")
assert loc is not None
assert loc.name == "Street"
missing = storyboard.get_location("L99")
assert missing is None

    def test_total_duration_calculation(self):
"""Test total duration calculation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=4, prompt="Shot 1"),
Shot(id="S02", duration_s=6, prompt="Shot 2")
]
)
assert storyboard.get_total_duration() == 10

    def test_total_frames_calculation(self):
"""Test total frames calculation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
fps=24,
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
shots=[
Shot(id="S01", duration_s=4, prompt="Shot 1"),
Shot(id="S02", duration_s=6, prompt="Shot 2")
]
)
assert storyboard.get_total_frames() == 240 # 10 seconds * 24 fps


class TestStoryboardLoader:
"""Test storyboard loading functionality."""

    def test_load_valid_storyboard(self):
"""Test loading a valid storyboard JSON file."""
data = {
"schema_version": "1.0",
"project": {
"title": "Test Video",
"fps": 24,
"target_duration_s": 10,
"resolution": {"width": 1920, "height": 1080},
"aspect_ratio": "16:9",
"global_style": {
"visual_style": "cinematic",
"negative_prompt": "blurry"
},
"audio": {"add_music": False}
},
"characters": [],
"locations": [],
"shots": [
{
"id": "S01",
"duration_s": 5,
"prompt": "A test shot",
"camera": {"framing": "wide"},
"generation": {"seed": 12345, "steps": 30}
}
],
"output": {"container": "mp4", "codec": "h264"}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(data, f)
temp_path = f.name
try:
storyboard = StoryboardValidator.load(temp_path)
assert storyboard.project.title == "Test Video"
assert len(storyboard.shots) == 1
assert storyboard.shots[0].id == "S01"
finally:
os.unlink(temp_path)

    def test_load_nonexistent_file(self):
"""Test loading a file that doesn't exist."""
with pytest.raises(StoryboardLoadError, match="not found"):
StoryboardValidator.load("/nonexistent/path/storyboard.json")

    def test_load_invalid_json(self):
"""Test loading invalid JSON."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write("{invalid json}")
temp_path = f.name
try:
with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
StoryboardValidator.load(temp_path)
finally:
os.unlink(temp_path)

    def test_load_wrong_extension(self):
"""Test loading a file with wrong extension."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
f.write('{}')
temp_path = f.name
try:
with pytest.raises(StoryboardLoadError, match="must be JSON"):
StoryboardValidator.load(temp_path)
finally:
os.unlink(temp_path)

    def test_validate_references(self):
"""Test reference validation."""
storyboard = Storyboard(
project=ProjectSettings(
title="Test",
resolution=Resolution(width=1920, height=1080),
target_duration_s=10
),
characters=[
Character(id="C01", name="Hero")
],
locations=[
Location(id="L01", name="Street")
],
shots=[
Shot(
id="S01",
duration_s=5,
prompt="Test",
characters=["C01", "C99"], # C99 doesn't exist
location_id="L99" # Doesn't exist
)
]
)
issues = StoryboardValidator.validate_references(storyboard)
assert len(issues) == 2
assert any("C99" in issue for issue in issues)
assert any("L99" in issue for issue in issues)
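
    # Illustrative happy-path companion to the test above (a sketch, assuming
    # StoryboardValidator.validate_references returns an empty issues list when
    # every character and location reference resolves, which matches how the
    # failure case above counts issues).
    def test_validate_references_all_valid(self):
        """Test that fully resolvable references produce no issues."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            characters=[Character(id="C01", name="Hero")],
            locations=[Location(id="L01", name="Street")],
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="Test",
                    characters=["C01"],
                    location_id="L01"
                )
            ]
        )
        issues = StoryboardValidator.validate_references(storyboard)
        assert len(issues) == 0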