Initial Commit

23  .env.example  Normal file
@@ -0,0 +1,23 @@
# Environment configuration for storyboard-video
# Copy this file to .env and customize as needed

# Active backend selection (must match a key in config/models.yaml)
ACTIVE_BACKEND=wan_t2v_14b

# Model cache directory (where downloaded models are stored)
MODEL_CACHE_DIR=~/.cache/storyboard-video/models

# HuggingFace Hub cache (if using HF models)
HF_HOME=~/.cache/huggingface

# CUDA settings
# CUDA_VISIBLE_DEVICES=0

# Logging level (DEBUG, INFO, WARNING, ERROR)
LOG_LEVEL=INFO

# Output directory base path
OUTPUT_BASE_DIR=./outputs

# FFmpeg path (leave empty to use system PATH)
FFMPEG_PATH=
138  .gitignore  vendored  Normal file
@@ -0,0 +1,138 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# Virtual environments
venv/
env/
ENV/
env.bak/
venv.bak/
.venv

# Conda
conda-env/
*.conda

# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Jupyter Notebook
.ipynb_checkpoints

# Pytest
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/

# Environment variables
.env
.env.local
.env.*.local

# Model caches and outputs
models/
*.ckpt
*.safetensors
*.bin
*.pth
*.onnx
outputs/
checkpoints.db

# Temporary files
tmp/
temp/
*.tmp
*.temp

# Logs
*.log
logs/

# Video files (generated)
*.mp4
*.mov
*.avi
*.mkv
*.webm

# Image files (generated)
*.png
*.jpg
*.jpeg
*.gif
*.bmp
*.tiff

# Audio files (generated)
*.wav
*.mp3
*.aac
*.flac
*.ogg

# HuggingFace cache
.cache/
transformers_cache/
diffusers_cache/

# Windows
Thumbs.db
ehthumbs.db
Desktop.ini
$RECYCLE.BIN/

# macOS
.AppleDouble
.LSOverride
Icon
._*
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Project specific
# Keep templates but ignore generated outputs
!/templates/*.json
!/templates/*.yaml
!/templates/*.yml

# Allow example files
!.env.example
!config/*.yaml
!config/*.yml

# Keep docs
!docs/*.md
210  AGENTS.MD  Normal file
@@ -0,0 +1,210 @@
# agents.md
## Project: Local AI Video Generation from Text Storyboards (Windows + RTX 5070 12GB)

### 0) Who this is for
The owner (user) is not an ML expert. The system must:
- be reproducible (conda + requirements)
- have guardrails (configs, logs, validation)
- be test-driven (pytest)
- maintain docs (developer + user)

---

## 1) High-Level Goal
Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by:
1) converting storyboard -> shot plan
2) generating shot clips (T2V or I2V when possible)
3) assembling clips into a final MP4
4) upscaling to 2K/4K if desired

This is a **shot-based** system, not “one prompt makes a whole movie”.

---

## 2) Hard Constraints (Hardware & OS)
Target system:
- Windows 11
- NVIDIA RTX 5070 (12GB VRAM) - must use GPU
- 32GB RAM
- 2TB SSD
- Anaconda available

Design must be stable under 12GB VRAM using:
- fp16/bf16
- attention slicing
- xFormers / SDPA where supported
- optional CPU offload

---

## 3) Output Targets (Realistic)
- Native generation: 720p–1080p (preferred)
- Final delivery: 1080p required; 2K/4K via upscaling
- Duration: 15–30s per video (may be segmented)
- FPS: 24 default
- Output: MP4 (H.264/H.265)

---

## 4) CUDA 13.1 Reality & PyTorch Plan (Critical)
The user has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
We must NOT assume PyTorch will build/run against the local CUDA 13.1 toolkit.

**Plan:**
- Use **PyTorch prebuilt binaries that bundle the CUDA runtime** (e.g., cu121 / cu124).
- Rely on NVIDIA driver compatibility rather than the local CUDA toolkit version.
- Avoid compiling custom CUDA extensions unless necessary.

Implementation notes:
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds.
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers.

---

## 5) Approved Stack (Do Not Deviate)
### Core
- Python 3.10 or 3.11 (conda env)
- PyTorch (CUDA 12.x build: cu121 or cu124)
- diffusers + transformers + accelerate + safetensors
- ffmpeg for assembly
- opencv-python for frame IO (if needed)
- pydantic for config/schema validation
- rich / loguru for logs

### Testing
- pytest
- pytest-cov
- snapshot-ish tests where feasible (metadata + shapes, not visual perfection)

### Docs
- /docs/developer.md (developer documentation)
- /docs/user.md (user manual)
- Keep docs updated alongside code changes.

---

## 6) Video Models (Pragmatic Choices)
### Primary (target)
- WAN 2.x family (T2V; optional I2V if supported in chosen pipeline)

Goal: best possible quality on consumer VRAM with chunking.

### Secondary / fallback
- Stable Video Diffusion (SVD) if WAN is unstable
- LTX-Video (only if it fits and is stable in our stack)

All model backends must implement the same interface:
- generate_shot(shot_spec) -> video_file + metadata

---

## 7) Canonical Input: Storyboard JSON
The storyboard source is text-only (often AI-generated). We will store and validate it as JSON.

A template exists at: `templates/storyboard.template.json`

We will later build a utility script:
- input: plain text fields or a simple text format
- output: valid storyboard JSON

---
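To make the shot-based input concrete, here is a minimal sketch of what a storyboard document might look like. The field names (`defaults`, `global_style`, `shots`, `duration_seconds`, etc.) are illustrative assumptions only; the authoritative schema lives in `templates/storyboard.template.json` and will be enforced with pydantic.

```python
import json

# Hypothetical storyboard structure -- field names are assumptions,
# not the real template schema.
storyboard = {
    "project": "demo",
    "defaults": {"fps": 24, "resolution": {"width": 1920, "height": 1080}},
    "global_style": "cinematic, soft natural light",
    "shots": [
        {
            "id": "shot_001",
            "prompt": "A lighthouse at dawn, waves rolling in",
            "camera": "slow dolly-in",
            "duration_seconds": 5,
            "seed": 42,
        }
    ],
}

# Stored on disk as JSON, so it must round-trip cleanly
print(json.dumps(storyboard, indent=2).splitlines()[1])
```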
## 8) Pipeline Modules (Required)
### A) Storyboard parsing & validation
- Load storyboard JSON
- Validate schema
- Expand defaults (fps, resolution, global style)
- Produce normalized shot list

### B) Prompt compilation
- Merge global style + shot prompt + camera notes
- Produce positive + negative prompts
- Keep deterministic via seeds
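The merge step above can be sketched as a pure function, which keeps it trivially deterministic and testable. The merge order (shot prompt, then camera notes, then global style) and the function name are assumptions; the real compiler will live in `src/`.

```python
# Minimal prompt-compiler sketch: a pure function of its inputs,
# so the same storyboard always yields the same prompt string.
def compile_prompt(global_style: str, shot_prompt: str, camera: str = "") -> str:
    parts = [shot_prompt.strip(), camera.strip(), global_style.strip()]
    return ", ".join(p for p in parts if p)  # drop empty fragments

positive = compile_prompt("cinematic, 35mm film", "a fox in snow", "slow pan left")
print(positive)  # a fox in snow, slow pan left, cinematic, 35mm film
```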
### C) Generation runner (per shot)
- For each shot: generate clip
- Support:
  - seed control
  - chunking (e.g., generate 4–6 seconds then continue)
  - optional init frame handoff between shots
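The duration-to-frames expansion and sequential chunk plan described above can be sketched as follows; `max_chunk_s` mirrors the `max_chunk_seconds` key in config/models.yaml, and the function name is an assumption.

```python
# Sketch: expand a shot duration into a per-chunk frame plan.
# Sequential mode only; overlapping chunks would add shared frames.
def plan_chunks(duration_s: float, fps: int = 24, max_chunk_s: int = 4) -> list[int]:
    total_frames = round(duration_s * fps)
    chunk_frames = max_chunk_s * fps
    chunks = []
    remaining = total_frames
    while remaining > 0:
        take = min(chunk_frames, remaining)  # last chunk may be shorter
        chunks.append(take)
        remaining -= take
    return chunks

print(plan_chunks(10, fps=24, max_chunk_s=4))  # [96, 96, 48]
```

This is exactly the kind of logic the testing rules target: frame counts and chunk boundaries are cheap to assert without rendering anything.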
### D) Assembly
- Use ffmpeg concat to build final video
- Optionally add:
  - transitions
  - temp audio
  - burn-in shot IDs for debugging mode
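A sketch of assembly command construction, in the spirit of the testing rule that we verify command lines rather than render. The listing format follows ffmpeg's concat demuxer (`file '<path>'` lines); the function name is an assumption, and the caller is assumed to write `listing` to clips.txt before invoking the command.

```python
# Build (but do not run) the ffmpeg concat invocation.
def build_concat_command(clip_paths, output):
    # concat demuxer input: one "file '<path>'" line per clip
    listing = "\n".join(f"file '{p}'" for p in clip_paths)
    cmd = ["ffmpeg", "-y", "-f", "concat", "-safe", "0",
           "-i", "clips.txt", "-c", "copy", output]
    return listing, cmd

listing, cmd = build_concat_command(["shots/s1.mp4", "shots/s2.mp4"], "final.mp4")
print(cmd[0], cmd[-1])  # ffmpeg final.mp4
```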
### E) Upscaling (optional)
- Upscale final to 2K/4K (post step)
- Keep this modular so the user can skip it.

---

## 9) Determinism & Logging (Must Have)
For each shot and final render, save:
- prompts (positive/negative)
- seed(s)
- model + revision/hash info if available
- inference params (steps, cfg, sampler, resolution, fps, frames)
- timing + VRAM notes if possible

Every run produces a folder:
- outputs/<project>/<timestamp>/
  - shots/
  - assembled/
  - metadata/

---
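The per-shot record listed under §9 can be sketched as a small JSON writer; the key names mirror the bullet list but are assumptions rather than a fixed schema, and `write_shot_metadata` is a hypothetical helper name.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

# Sketch: persist one shot's reproducibility record as JSON.
def write_shot_metadata(out_dir: Path, shot_id: str, **params) -> Path:
    record = {
        "shot_id": shot_id,
        "written_at": datetime.now().isoformat(),
        **params,  # seed, steps, cfg, fps, model info, ...
    }
    path = out_dir / f"{shot_id}.json"
    path.write_text(json.dumps(record, indent=2))
    return path

with tempfile.TemporaryDirectory() as d:
    p = write_shot_metadata(Path(d), "shot_001", seed=42, steps=30, fps=24)
    record = json.loads(p.read_text())
print(record["seed"])  # 42
```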
## 10) Testing Rules (Hard Requirement)
- Tests must be written alongside features.
- Whenever a file/function is modified, the corresponding tests MUST be updated.
- Prefer tests that verify:
  - schema validation works
  - prompt compiler output is stable
  - shot planner expands durations -> frame counts
  - assembly command lines are correct
  - metadata is generated correctly

Do not require “visual quality” assertions. Test structure and determinism.

---

## 11) Documentation Rules (Hard Requirement)
Maintain these continuously:
- docs/developer.md
  - architecture
  - install steps
  - how to run tests
  - how to add a new model backend
- docs/user.md
  - quickstart
  - how to create storyboard JSON
  - how to run generation
  - where outputs go
  - troubleshooting (VRAM, drivers, ffmpeg)

Docs must be updated whenever CLI flags, file formats, or workflows change.

---

## 12) Project Files to Maintain
Required:
- requirements.txt (pip deps)
- environment.yml (conda env)
- templates/storyboard.template.json
- docs/developer.md
- docs/user.md
- src/ (implementation)
- tests/ (pytest)

---

## 13) Definition of Done
A feature is “done” only if:
- implemented
- tests added/updated
- docs updated
- reproducible install instructions remain valid

End of file.
35  TODO.MD  Normal file
@@ -0,0 +1,35 @@
# TODO

## Repo bootstrap
- [x] Define project direction and constraints in `agents.md`
- [x] Add `requirements.txt` (pip dependencies)
- [x] Add `environment.yml` (conda environment; PyTorch CUDA 12.x runtime strategy)
- [x] Add storyboard JSON template at `templates/storyboard.template.json`

## Core implementation (next)
- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
- [ ] Implement storyboard schema validator (pydantic) + loader
- [ ] Implement prompt compiler (global style + shot + camera)
- [ ] Implement shot planning (duration -> frame count, chunk plan)
- [ ] Implement model backend interface (`BaseVideoBackend`)
- [ ] Implement WAN backend (primary) with VRAM-safe defaults
- [ ] Implement fallback backend (SVD) for reliability testing
- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
- [ ] Implement optional upscaling module (post-process)

## Utilities
- [ ] Write storyboard “plain text → JSON” utility script (fills `storyboard.template.json`)
- [ ] Add config file support (YAML/JSON) for global defaults

## Testing (parallel work; required)
- [ ] Add `pytest` scaffolding
- [ ] Add tests for schema validation
- [ ] Add tests for prompt compilation determinism
- [ ] Add tests for shot planning (frames/chunks)
- [ ] Add tests for ffmpeg command generation (no actual render needed)
- [ ] Ensure every code change includes a corresponding test update

## Documentation (maintained continuously)
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
- [ ] Keep docs updated whenever CLI/config/schema changes
61  config/models.yaml  Normal file
@@ -0,0 +1,61 @@
# Model configurations for video generation backends
# Backend selection is controlled via ACTIVE_BACKEND environment variable
# or --backend CLI flag

backends:
  wan_t2v_14b:
    name: "Wan 2.1 T2V 14B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-14B"
    vram_gb: 12
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: true
    chunking:
      enabled: true
      mode: "sequential"  # Options: sequential, overlapping
      max_chunk_seconds: 4
      overlap_seconds: 1  # Only used when mode is overlapping

  wan_t2v_1_3b:
    name: "Wan 2.1 T2V 1.3B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-1.3B"
    vram_gb: 8
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: false
    chunking:
      enabled: true
      mode: "sequential"
      max_chunk_seconds: 6
      overlap_seconds: 1

  svd:
    name: "Stable Video Diffusion"
    class: "generation.backends.svd.SVDBackend"
    model_id: "stabilityai/stable-video-diffusion-img2vid-xt"
    vram_gb: 10
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: true
    chunking:
      enabled: false

# Default backend selection
# Override with ACTIVE_BACKEND environment variable
active_backend: "wan_t2v_14b"

# Global generation defaults
defaults:
  fps: 24
  resolution:
    width: 1920
    height: 1080

# Checkpoint database settings
checkpoint_db: "outputs/checkpoints.db"

# Model cache directory
# Override with MODEL_CACHE_DIR environment variable
model_cache_dir: "~/.cache/storyboard-video/models"
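The selection precedence the comments describe (environment variable wins over the `active_backend` key) reduces to a one-line lookup; a minimal sketch, with a plain dict standing in for the parsed YAML:

```python
import os

# Stand-in for yaml.safe_load(open("config/models.yaml"))
config = {
    "backends": {"wan_t2v_14b": {"vram_gb": 12}, "svd": {"vram_gb": 10}},
    "active_backend": "wan_t2v_14b",
}

# ACTIVE_BACKEND env var overrides the file's default
name = os.environ.get("ACTIVE_BACKEND", config["active_backend"])
backend = config["backends"][name]
print(name)
```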
24  environment.yml  Normal file
@@ -0,0 +1,24 @@
# environment.yml
name: storyboard-video
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - pip=24.3.1

  # FFmpeg in env is easiest on Windows if you rely on conda-forge
  - ffmpeg=7.1

  # PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
  # Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
  - pytorch=2.5.1
  - pytorch-cuda=12.4
  - torchvision=0.20.1
  - torchaudio=2.5.1

  # pip deps (mirrors requirements.txt)
  - pip:
      - -r requirements.txt
7  pytest.ini  Normal file
@@ -0,0 +1,7 @@
# pytest configuration
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short
26  requirements.txt  Normal file
@@ -0,0 +1,26 @@
# requirements.txt
# Install PyTorch separately (see environment.yml and docs).

diffusers==0.32.2
transformers==4.48.3
accelerate==1.3.0
safetensors==0.5.2

pydantic==2.10.6
pyyaml==6.0.2
numpy==2.1.3
pillow==11.1.0
opencv-python==4.11.0.86

rich==13.9.4
loguru==0.7.3

tqdm==4.67.1
requests==2.32.3

# CLI & tooling
typer==0.15.1

# Testing
pytest==8.3.4
pytest-cov==6.0.0
5  src/__init__.py  Normal file
@@ -0,0 +1,5 @@
"""
Storyboard video generation pipeline.
"""

__version__ = "0.1.0"
3  src/assembly/__init__.py  Normal file
@@ -0,0 +1,3 @@
"""
Video assembly and post-processing.
"""
3  src/cli/__init__.py  Normal file
@@ -0,0 +1,3 @@
"""
CLI entry points.
"""
25  src/core/__init__.py  Normal file
@@ -0,0 +1,25 @@
"""
Core utilities and shared components.
"""

from .config import Config, BackendConfig, ConfigLoader, get_config, reload_config
from .checkpoint import (
    CheckpointManager,
    ShotCheckpoint,
    ProjectCheckpoint,
    ShotStatus,
    ProjectStatus
)

__all__ = [
    'Config',
    'BackendConfig',
    'ConfigLoader',
    'get_config',
    'reload_config',
    'CheckpointManager',
    'ShotCheckpoint',
    'ProjectCheckpoint',
    'ShotStatus',
    'ProjectStatus'
]
435  src/core/checkpoint.py  Normal file
@@ -0,0 +1,435 @@
"""
Checkpoint and resume system using SQLite.
Tracks generation progress and enables resuming from failures.
"""

import sqlite3
import json
from pathlib import Path
from datetime import datetime
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from enum import Enum


class ShotStatus(str, Enum):
    """Status of a shot generation."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"


class ProjectStatus(str, Enum):
    """Status of a project."""
    INITIALIZED = "initialized"
    RUNNING = "running"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class ShotCheckpoint:
    """Checkpoint data for a single shot."""
    shot_id: str
    status: ShotStatus
    output_path: Optional[str] = None
    error_message: Optional[str] = None
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


@dataclass
class ProjectCheckpoint:
    """Checkpoint data for a project."""
    project_id: str
    storyboard_path: str
    output_dir: str
    status: ProjectStatus
    backend_name: str
    started_at: str
    completed_at: Optional[str] = None
    error_message: Optional[str] = None


class CheckpointManager:
    """Manages checkpoints using an SQLite database."""

    def __init__(self, db_path: str = "outputs/checkpoints.db"):
        """
        Initialize checkpoint manager.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Initialize database schema."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS projects (
                    project_id TEXT PRIMARY KEY,
                    storyboard_path TEXT NOT NULL,
                    output_dir TEXT NOT NULL,
                    status TEXT NOT NULL,
                    backend_name TEXT NOT NULL,
                    started_at TEXT NOT NULL,
                    completed_at TEXT,
                    error_message TEXT
                )
            """)

            conn.execute("""
                CREATE TABLE IF NOT EXISTS shots (
                    project_id TEXT NOT NULL,
                    shot_id TEXT NOT NULL,
                    status TEXT NOT NULL,
                    output_path TEXT,
                    error_message TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    metadata TEXT,
                    PRIMARY KEY (project_id, shot_id),
                    FOREIGN KEY (project_id) REFERENCES projects(project_id)
                )
            """)

            conn.commit()
        finally:
            conn.close()

    def _get_connection(self):
        """Get a new database connection."""
        return sqlite3.connect(self.db_path)

    def create_project(
        self,
        project_id: str,
        storyboard_path: str,
        output_dir: str,
        backend_name: str
    ) -> ProjectCheckpoint:
        """
        Create a new project checkpoint.

        Args:
            project_id: Unique project identifier
            storyboard_path: Path to storyboard JSON file
            output_dir: Output directory for generated files
            backend_name: Name of the generation backend

        Returns:
            ProjectCheckpoint object
        """
        checkpoint = ProjectCheckpoint(
            project_id=project_id,
            storyboard_path=storyboard_path,
            output_dir=output_dir,
            status=ProjectStatus.INITIALIZED,
            backend_name=backend_name,
            started_at=datetime.now().isoformat()
        )

        conn = self._get_connection()
        try:
            conn.execute(
                """
                INSERT INTO projects
                (project_id, storyboard_path, output_dir, status, backend_name, started_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    checkpoint.project_id,
                    checkpoint.storyboard_path,
                    checkpoint.output_dir,
                    checkpoint.status.value,
                    checkpoint.backend_name,
                    checkpoint.started_at
                )
            )
            conn.commit()
        finally:
            conn.close()

        return checkpoint

    def get_project(self, project_id: str) -> Optional[ProjectCheckpoint]:
        """Get project checkpoint by ID."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM projects WHERE project_id = ?",
                (project_id,)
            )
            row = cursor.fetchone()

            if row is None:
                return None

            return ProjectCheckpoint(
                project_id=row[0],
                storyboard_path=row[1],
                output_dir=row[2],
                status=ProjectStatus(row[3]),
                backend_name=row[4],
                started_at=row[5],
                completed_at=row[6],
                error_message=row[7]
            )
        finally:
            conn.close()

    def update_project_status(
        self,
        project_id: str,
        status: ProjectStatus,
        error_message: Optional[str] = None
    ):
        """Update project status."""
        completed_at = None
        if status in [ProjectStatus.COMPLETED, ProjectStatus.FAILED]:
            completed_at = datetime.now().isoformat()

        conn = self._get_connection()
        try:
            conn.execute(
                """
                UPDATE projects
                SET status = ?, error_message = ?, completed_at = ?
                WHERE project_id = ?
                """,
                (status.value, error_message, completed_at, project_id)
            )
            conn.commit()
        finally:
            conn.close()

    def create_shot(
        self,
        project_id: str,
        shot_id: str,
        status: ShotStatus = ShotStatus.PENDING
    ) -> ShotCheckpoint:
        """Create a shot checkpoint."""
        checkpoint = ShotCheckpoint(
            shot_id=shot_id,
            status=status
        )

        conn = self._get_connection()
        try:
            conn.execute(
                """
                INSERT INTO shots
                (project_id, shot_id, status, metadata)
                VALUES (?, ?, ?, ?)
                """,
                (
                    project_id,
                    shot_id,
                    status.value,
                    json.dumps(checkpoint.metadata)
                )
            )
            conn.commit()
        finally:
            conn.close()

        return checkpoint

    def get_shot(self, project_id: str, shot_id: str) -> Optional[ShotCheckpoint]:
        """Get shot checkpoint."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM shots WHERE project_id = ? AND shot_id = ?",
                (project_id, shot_id)
            )
            row = cursor.fetchone()

            if row is None:
                return None

            return ShotCheckpoint(
                shot_id=row[1],
                status=ShotStatus(row[2]),
                output_path=row[3],
                error_message=row[4],
                started_at=row[5],
                completed_at=row[6],
                metadata=json.loads(row[7]) if row[7] else {}
            )
        finally:
            conn.close()

    def get_project_shots(self, project_id: str) -> List[ShotCheckpoint]:
        """Get all shots for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM shots WHERE project_id = ? ORDER BY shot_id",
                (project_id,)
            )
            rows = cursor.fetchall()

            return [
                ShotCheckpoint(
                    shot_id=row[1],
                    status=ShotStatus(row[2]),
                    output_path=row[3],
                    error_message=row[4],
                    started_at=row[5],
                    completed_at=row[6],
                    metadata=json.loads(row[7]) if row[7] else {}
                )
                for row in rows
            ]
        finally:
            conn.close()

    def update_shot(
        self,
        project_id: str,
        shot_id: str,
        status: Optional[ShotStatus] = None,
        output_path: Optional[str] = None,
        error_message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Update shot checkpoint."""
        updates = []
        params = []

        if status is not None:
            updates.append("status = ?")
            params.append(status.value)

            if status == ShotStatus.IN_PROGRESS:
                updates.append("started_at = ?")
                params.append(datetime.now().isoformat())
            elif status in [ShotStatus.COMPLETED, ShotStatus.FAILED]:
                updates.append("completed_at = ?")
                params.append(datetime.now().isoformat())

        if output_path is not None:
            updates.append("output_path = ?")
            params.append(output_path)

        if error_message is not None:
            updates.append("error_message = ?")
            params.append(error_message)

        if metadata is not None:
            updates.append("metadata = ?")
            params.append(json.dumps(metadata))

        if not updates:
            return

        params.extend([project_id, shot_id])

        conn = self._get_connection()
        try:
            conn.execute(
                f"""
                UPDATE shots
                SET {', '.join(updates)}
                WHERE project_id = ? AND shot_id = ?
                """,
                params
            )
            conn.commit()
        finally:
            conn.close()

    def get_pending_shots(self, project_id: str) -> List[str]:
        """Get list of pending shot IDs for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT shot_id FROM shots
                WHERE project_id = ? AND status = ?
                ORDER BY shot_id
                """,
                (project_id, ShotStatus.PENDING.value)
            )
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def get_failed_shots(self, project_id: str) -> List[str]:
        """Get list of failed shot IDs for a project."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT shot_id FROM shots
                WHERE project_id = ? AND status = ?
                ORDER BY shot_id
                """,
                (project_id, ShotStatus.FAILED.value)
            )
            return [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

    def can_resume(self, project_id: str) -> bool:
        """Check if a project can be resumed."""
        project = self.get_project(project_id)
        if project is None:
            return False

        if project.status == ProjectStatus.COMPLETED:
            return False

        pending = self.get_pending_shots(project_id)
        failed = self.get_failed_shots(project_id)

        return len(pending) > 0 or len(failed) > 0

    def delete_project(self, project_id: str):
        """Delete a project and all its shots."""
        conn = self._get_connection()
        try:
            conn.execute("DELETE FROM shots WHERE project_id = ?", (project_id,))
            conn.execute("DELETE FROM projects WHERE project_id = ?", (project_id,))
            conn.commit()
        finally:
            conn.close()

    def list_projects(self) -> List[ProjectCheckpoint]:
        """List all projects."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                "SELECT * FROM projects ORDER BY started_at DESC"
            )
            rows = cursor.fetchall()

            return [
                ProjectCheckpoint(
                    project_id=row[0],
                    storyboard_path=row[1],
                    output_dir=row[2],
                    status=ProjectStatus(row[3]),
                    backend_name=row[4],
                    started_at=row[5],
                    completed_at=row[6],
                    error_message=row[7]
                )
                for row in rows
            ]
        finally:
            conn.close()
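The resumability check above (`can_resume`) boils down to one SQL pattern: a project is resumable while any shot is still pending or failed. A standalone sketch of that query, using an in-memory SQLite table in place of outputs/checkpoints.db:

```python
import sqlite3

# In-memory stand-in for the shots table in outputs/checkpoints.db
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE shots (project_id TEXT, shot_id TEXT, status TEXT)")
db.executemany(
    "INSERT INTO shots VALUES (?, ?, ?)",
    [("p1", "s1", "completed"), ("p1", "s2", "failed"), ("p1", "s3", "pending")],
)

# Shots that still need work make the project resumable
rows = db.execute(
    "SELECT shot_id FROM shots"
    " WHERE project_id = ? AND status IN ('pending', 'failed')"
    " ORDER BY shot_id",
    ("p1",),
).fetchall()
can_resume = len(rows) > 0
print([r[0] for r in rows], can_resume)  # ['s2', 's3'] True
```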
212
src/core/config.py
Normal file
212
src/core/config.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Configuration management system.
|
||||
Handles YAML configs and environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
"""Configuration for a generation backend."""
|
||||
name: str
|
||||
model_class: str
|
||||
model_id: str
|
||||
vram_gb: int
|
||||
dtype: str
|
||||
enable_vae_slicing: bool = True
|
||||
enable_vae_tiling: bool = False
|
||||
chunking_enabled: bool = True
|
||||
chunking_mode: str = "sequential"
|
||||
max_chunk_seconds: int = 4
|
||||
overlap_seconds: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""Application configuration."""
|
||||
active_backend: str = "wan_t2v_14b"
|
||||
backends: Dict[str, BackendConfig] = field(default_factory=dict)
|
||||
defaults: Dict[str, Any] = field(default_factory=dict)
|
||||
checkpoint_db: str = "outputs/checkpoints.db"
|
||||
model_cache_dir: str = "~/.cache/storyboard-video/models"
|
||||
log_level: str = "INFO"
|
||||
output_base_dir: str = "./outputs"
|
||||
|
||||
def get_backend(self, name: Optional[str] = None) -> BackendConfig:
|
||||
"""Get backend configuration by name."""
|
||||
name = name or self.active_backend
|
||||
if name not in self.backends:
|
||||
raise ValueError(f"Backend '{name}' not found in configuration")
|
||||
return self.backends[name]
|
||||
|
||||
|
||||
class ConfigLoader:
    """Loads configuration from YAML files and environment variables."""

    @staticmethod
    def load(config_path: Optional[Path] = None, env_file: Optional[Path] = None) -> Config:
        """
        Load configuration from files and environment.

        Priority (highest to lowest):
        1. Environment variables
        2. .env file
        3. YAML config file
        4. Defaults

        Args:
            config_path: Path to YAML config file (default: config/models.yaml)
            env_file: Path to .env file (default: .env if exists)

        Returns:
            Config object
        """
        # Start with defaults
        config = Config()

        # Load YAML config
        if config_path is None:
            config_path = Path("config/models.yaml")

        if config_path.exists():
            config = ConfigLoader._load_yaml(config_path, config)

        # Load .env file
        if env_file is None:
            env_file = Path(".env")

        if env_file.exists():
            config = ConfigLoader._load_env_file(env_file, config)

        # Override with environment variables
        config = ConfigLoader._load_env_vars(config)

        # Expand paths
        config.model_cache_dir = str(Path(config.model_cache_dir).expanduser())
        config.checkpoint_db = str(Path(config.checkpoint_db).expanduser())
        config.output_base_dir = str(Path(config.output_base_dir).expanduser())

        return config

    @staticmethod
    def _load_yaml(path: Path, config: Config) -> Config:
        """Load configuration from YAML file."""
        with open(path, 'r') as f:
            data = yaml.safe_load(f)

        if not data:
            return config

        # Load backends
        if 'backends' in data:
            for backend_id, backend_data in data['backends'].items():
                chunking = backend_data.get('chunking', {})
                config.backends[backend_id] = BackendConfig(
                    name=backend_data.get('name', backend_id),
                    model_class=backend_data.get('class', ''),
                    model_id=backend_data.get('model_id', ''),
                    vram_gb=backend_data.get('vram_gb', 12),
                    dtype=backend_data.get('dtype', 'fp16'),
                    enable_vae_slicing=backend_data.get('enable_vae_slicing', True),
                    enable_vae_tiling=backend_data.get('enable_vae_tiling', False),
                    chunking_enabled=chunking.get('enabled', True),
                    chunking_mode=chunking.get('mode', 'sequential'),
                    max_chunk_seconds=chunking.get('max_chunk_seconds', 4),
                    overlap_seconds=chunking.get('overlap_seconds', 1)
                )

        # Load active backend
        if 'active_backend' in data:
            config.active_backend = data['active_backend']

        # Load defaults
        if 'defaults' in data:
            config.defaults = data['defaults']

        # Load checkpoint DB path
        if 'checkpoint_db' in data:
            config.checkpoint_db = data['checkpoint_db']

        # Load model cache dir
        if 'model_cache_dir' in data:
            config.model_cache_dir = data['model_cache_dir']

        return config

    @staticmethod
    def _load_env_file(path: Path, config: Config) -> Config:
        """Load configuration from .env file."""
        env_vars = {}
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                if '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    value = value.strip().strip('"\'')
                    env_vars[key] = value

        # Apply env vars directly to config (don't set in os.environ)
        if 'ACTIVE_BACKEND' in env_vars:
            config.active_backend = env_vars['ACTIVE_BACKEND']

        if 'MODEL_CACHE_DIR' in env_vars:
            config.model_cache_dir = env_vars['MODEL_CACHE_DIR']

        if 'LOG_LEVEL' in env_vars:
            config.log_level = env_vars['LOG_LEVEL']

        if 'OUTPUT_BASE_DIR' in env_vars:
            config.output_base_dir = env_vars['OUTPUT_BASE_DIR']

        if 'CHECKPOINT_DB' in env_vars:
            config.checkpoint_db = env_vars['CHECKPOINT_DB']

        return config

    @staticmethod
    def _load_env_vars(config: Config) -> Config:
        """Override config with environment variables."""
        if 'ACTIVE_BACKEND' in os.environ:
            config.active_backend = os.environ['ACTIVE_BACKEND']

        if 'MODEL_CACHE_DIR' in os.environ:
            config.model_cache_dir = os.environ['MODEL_CACHE_DIR']

        if 'LOG_LEVEL' in os.environ:
            config.log_level = os.environ['LOG_LEVEL']

        if 'OUTPUT_BASE_DIR' in os.environ:
            config.output_base_dir = os.environ['OUTPUT_BASE_DIR']

        if 'CHECKPOINT_DB' in os.environ:
            config.checkpoint_db = os.environ['CHECKPOINT_DB']

        return config


# Global config instance (lazy-loaded)
_config: Optional[Config] = None


def get_config() -> Config:
    """Get the global configuration instance."""
    global _config
    if _config is None:
        _config = ConfigLoader.load()
    return _config


def reload_config() -> Config:
    """Reload configuration from files."""
    global _config
    _config = ConfigLoader.load()
    return _config
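The `_load_env_file` parsing rules (skip blanks and `#` comments, split on the first `=`, strip surrounding quotes) can be exercised in isolation. A minimal sketch, with `parse_env_lines` as a hypothetical stand-alone helper mirroring that logic:

```python
def parse_env_lines(lines):
    """Parse KEY=VALUE lines the way _load_env_file does: skip blanks
    and '#' comments, split on the first '=', and strip surrounding
    single/double quotes from the value."""
    env_vars = {}
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        if '=' in line:
            key, value = line.split('=', 1)
            env_vars[key.strip()] = value.strip().strip('"\'')
    return env_vars


print(parse_env_lines([
    "# comment",
    "ACTIVE_BACKEND=wan_t2v_14b",
    'LOG_LEVEL="DEBUG"',
    "FFMPEG_PATH=",
]))
```

Note that an empty value such as `FFMPEG_PATH=` still produces a key with an empty string, matching the `.env.example` entry.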
17
src/generation/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""
Video generation backends.
"""

from .base import (
    BaseVideoBackend,
    GenerationResult,
    GenerationSpec,
    BackendFactory
)

__all__ = [
    'BaseVideoBackend',
    'GenerationResult',
    'GenerationSpec',
    'BackendFactory'
]
235
src/generation/base.py
Normal file
@@ -0,0 +1,235 @@
"""
Base interface for video generation backends.
All backends (WAN, SVD, etc.) must implement this interface.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Dict, Any, List
from dataclasses import dataclass


@dataclass
class GenerationResult:
    """Result of a video generation."""
    success: bool
    output_path: Optional[Path] = None
    metadata: Optional[Dict[str, Any]] = None
    error_message: Optional[str] = None
    vram_usage_gb: Optional[float] = None
    generation_time_s: Optional[float] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


@dataclass
class GenerationSpec:
    """Specification for video generation."""
    prompt: str
    negative_prompt: str
    width: int
    height: int
    num_frames: int
    fps: int
    seed: int
    steps: int
    cfg_scale: float
    output_path: Path
    init_frame_path: Optional[Path] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for metadata storage."""
        return {
            "prompt": self.prompt,
            "negative_prompt": self.negative_prompt,
            "width": self.width,
            "height": self.height,
            "num_frames": self.num_frames,
            "fps": self.fps,
            "seed": self.seed,
            "steps": self.steps,
            "cfg_scale": self.cfg_scale,
            "output_path": str(self.output_path),
            "init_frame_path": str(self.init_frame_path) if self.init_frame_path else None
        }


class BaseVideoBackend(ABC):
    """
    Abstract base class for video generation backends.

    All backends must implement this interface to be used in the pipeline.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the backend with configuration.

        Args:
            config: Backend-specific configuration dictionary
        """
        self.config = config
        self.model = None
        self._is_loaded = False

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the backend name."""
        pass

    @property
    @abstractmethod
    def supports_chunking(self) -> bool:
        """Whether this backend supports chunked generation."""
        pass

    @property
    @abstractmethod
    def supports_init_frame(self) -> bool:
        """Whether this backend supports init frame conditioning."""
        pass

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        This should be called before generate().
        Implementations should set self._is_loaded = True when complete.
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Implementations should set self._is_loaded = False when complete.
        """
        pass

    @abstractmethod
    def generate(self, spec: GenerationSpec) -> GenerationResult:
        """
        Generate a video clip.

        Args:
            spec: Generation specification

        Returns:
            GenerationResult with output path and metadata
        """
        pass

    def generate_chunked(
        self,
        spec: GenerationSpec,
        chunk_duration_s: int,
        overlap_s: int = 0,
        mode: str = "sequential"
    ) -> GenerationResult:
        """
        Generate a video in chunks.

        Default implementation for backends that don't natively support chunking.
        Backends with native support should override this.

        Args:
            spec: Generation specification
            chunk_duration_s: Duration of each chunk in seconds
            overlap_s: Overlap between chunks in seconds (for blending)
            mode: "sequential" or "overlapping"

        Returns:
            GenerationResult with concatenated video
        """
        if not self.supports_chunking:
            # Fall back to single generation
            return self.generate(spec)

        raise NotImplementedError(
            "Chunked generation not implemented for this backend"
        )

    def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
        """
        Estimate VRAM usage for a generation.

        Args:
            width: Video width
            height: Video height
            num_frames: Number of frames

        Returns:
            Estimated VRAM usage in GB
        """
        # Default implementation - backends should override with better estimates
        return self.config.get("vram_gb", 12.0)

    def check_vram_available(self, width: int, height: int, num_frames: int) -> bool:
        """
        Check if sufficient VRAM is available.

        Args:
            width: Video width
            height: Video height
            num_frames: Number of frames

        Returns:
            True if generation should fit in VRAM
        """
        try:
            import torch
            if torch.cuda.is_available():
                total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                estimated = self.estimate_vram_usage(width, height, num_frames)
                # Leave 10% headroom
                return estimated < total_vram * 0.9
        except Exception:
            # torch missing or CUDA query failed; fall through to optimistic default
            pass

        # If we can't check, assume it's OK
        return True

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded."""
        return self._is_loaded

    def __enter__(self):
        """Context manager entry."""
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.unload()


class BackendFactory:
    """Factory for creating backend instances."""

    _backends: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str, backend_class: type):
        """Register a backend class."""
        if not issubclass(backend_class, BaseVideoBackend):
            raise ValueError("Backend must inherit from BaseVideoBackend")
        cls._backends[name] = backend_class

    @classmethod
    def create(cls, name: str, config: Dict[str, Any]) -> BaseVideoBackend:
        """Create a backend instance."""
        if name not in cls._backends:
            raise ValueError(f"Unknown backend: {name}. "
                             f"Available: {list(cls._backends.keys())}")
        return cls._backends[name](config)

    @classmethod
    def list_backends(cls) -> List[str]:
        """List available backend names."""
        return list(cls._backends.keys())
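The register/create flow above follows a standard class-registry pattern. A minimal self-contained sketch, using hypothetical stand-ins (`Backend`, `Factory`, `DummyBackend`) rather than the real module, since a concrete backend is not part of this commit:

```python
from abc import ABC, abstractmethod


class Backend(ABC):
    """Stand-in for BaseVideoBackend."""
    def __init__(self, config):
        self.config = config

    @abstractmethod
    def name(self):
        ...


class Factory:
    """Stand-in for BackendFactory: a class-level name -> class registry."""
    _backends = {}

    @classmethod
    def register(cls, name, backend_class):
        if not issubclass(backend_class, Backend):
            raise ValueError("Backend must inherit from Backend")
        cls._backends[name] = backend_class

    @classmethod
    def create(cls, name, config):
        if name not in cls._backends:
            raise ValueError(f"Unknown backend: {name}")
        return cls._backends[name](config)


class DummyBackend(Backend):
    def name(self):
        return "dummy"


Factory.register("dummy", DummyBackend)
backend = Factory.create("dummy", {"vram_gb": 8})
print(backend.name(), backend.config["vram_gb"])
```

Registering classes rather than instances lets each `create()` call produce a fresh backend with its own config, which matters when backends hold loaded model state.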
39
src/storyboard/__init__.py
Normal file
@@ -0,0 +1,39 @@
"""
Storyboard module for validation, loading, and prompt compilation.
"""

from .schema import (
    Storyboard,
    ProjectSettings,
    Shot,
    Character,
    Location,
    GlobalStyle,
    CameraSettings,
    GenerationSettings,
    OutputSettings,
    Resolution
)
from .loader import StoryboardValidator, StoryboardLoadError
from .prompt_compiler import PromptCompiler, CompiledPrompt
from .shot_planner import ShotPlanner, ShotPlan, Chunk

__all__ = [
    'Storyboard',
    'ProjectSettings',
    'Shot',
    'Character',
    'Location',
    'GlobalStyle',
    'CameraSettings',
    'GenerationSettings',
    'OutputSettings',
    'Resolution',
    'StoryboardValidator',
    'StoryboardLoadError',
    'PromptCompiler',
    'CompiledPrompt',
    'ShotPlanner',
    'ShotPlan',
    'Chunk'
]
84
src/storyboard/loader.py
Normal file
@@ -0,0 +1,84 @@
"""
Storyboard loader and validator.
Handles loading JSON files and validating against schema.
"""

import json
from pathlib import Path
from typing import Union

from .schema import Storyboard


class StoryboardLoadError(Exception):
    """Raised when storyboard loading fails."""
    pass


class StoryboardValidator:
    """Validates and loads storyboard JSON files."""

    @staticmethod
    def load(path: Union[str, Path]) -> Storyboard:
        """
        Load and validate a storyboard JSON file.

        Args:
            path: Path to the JSON file

        Returns:
            Validated Storyboard object

        Raises:
            StoryboardLoadError: If file doesn't exist or is invalid
        """
        path = Path(path)

        if not path.exists():
            raise StoryboardLoadError(f"Storyboard file not found: {path}")

        if path.suffix.lower() != '.json':
            raise StoryboardLoadError(f"Storyboard file must be JSON: {path}")

        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            raise StoryboardLoadError(f"Invalid JSON in {path}: {e}") from e
        except Exception as e:
            raise StoryboardLoadError(f"Failed to read {path}: {e}") from e

        try:
            storyboard = Storyboard(**data)
        except Exception as e:
            raise StoryboardLoadError(f"Storyboard validation failed: {e}") from e

        return storyboard

    @staticmethod
    def validate_references(storyboard: Storyboard) -> list[str]:
        """
        Validate that all references (location_id, characters) exist.

        Args:
            storyboard: Loaded storyboard

        Returns:
            List of validation warnings/errors
        """
        issues = []

        # Validate character references
        char_ids = {char.id for char in storyboard.characters}
        for shot in storyboard.shots:
            for char_id in shot.characters:
                if char_id not in char_ids:
                    issues.append(f"Shot {shot.id}: Character '{char_id}' not defined")

        # Validate location references
        loc_ids = {loc.id for loc in storyboard.locations}
        for shot in storyboard.shots:
            if shot.location_id and shot.location_id not in loc_ids:
                issues.append(f"Shot {shot.id}: Location '{shot.location_id}' not defined")

        return issues
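The reference check above is a set-membership walk: build the defined ID sets once, then scan every shot. A minimal sketch of the same logic, with `check_references` as a hypothetical helper operating on plain tuples instead of schema objects:

```python
def check_references(shots, character_ids, location_ids):
    """Collect dangling character/location references.

    shots: list of (shot_id, character_id_list, location_id_or_None).
    IDs are compared against sets built once, so the scan is O(total refs).
    """
    issues = []
    for shot_id, chars, loc in shots:
        for c in chars:
            if c not in character_ids:
                issues.append(f"Shot {shot_id}: Character '{c}' not defined")
        if loc and loc not in location_ids:
            issues.append(f"Shot {shot_id}: Location '{loc}' not defined")
    return issues


issues = check_references(
    [("s1", ["hero"], "city"), ("s2", ["ghost"], None)],
    character_ids={"hero"},
    location_ids={"city"},
)
print(issues)
```

Returning a list of messages instead of raising keeps this usable as a lint pass: callers can decide whether dangling references are fatal or just warnings.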
154
src/storyboard/prompt_compiler.py
Normal file
@@ -0,0 +1,154 @@
"""
Prompt compiler for merging global style with shot-specific prompts.
"""

from typing import List, Optional
from dataclasses import dataclass

from ..storyboard.schema import Storyboard, Shot, GlobalStyle


@dataclass
class CompiledPrompt:
    """Compiled prompt ready for generation."""
    positive: str
    negative: str
    camera_notes: str
    full_prompt: str  # Combined positive + style + camera

    def to_dict(self) -> dict:
        """Convert to dictionary for metadata storage."""
        return {
            "positive": self.positive,
            "negative": self.negative,
            "camera_notes": self.camera_notes,
            "full_prompt": self.full_prompt
        }


class PromptCompiler:
    """
    Compiles prompts by merging global style with shot-specific content.
    """

    def __init__(self, storyboard: Storyboard):
        """
        Initialize with storyboard.

        Args:
            storyboard: Loaded and validated storyboard
        """
        self.storyboard = storyboard
        self.global_style = storyboard.project.global_style

    def compile_shot(self, shot: Shot) -> CompiledPrompt:
        """
        Compile prompt for a specific shot.

        Args:
            shot: Shot to compile prompt for

        Returns:
            CompiledPrompt with merged content
        """
        # Build character descriptions
        char_descriptions = []
        for char_id in shot.characters:
            char = self.storyboard.get_character(char_id)
            if char:
                char_descriptions.append(f"{char.name}: {char.description}")

        # Build location description
        location_desc = ""
        if shot.location_id:
            loc = self.storyboard.get_location(shot.location_id)
            if loc:
                location_desc = loc.description

        # Compile positive prompt
        positive_parts = []

        # Add global visual style first (sets the aesthetic)
        if self.global_style.visual_style:
            positive_parts.append(self.global_style.visual_style)

        # Add location context
        if location_desc:
            positive_parts.append(f"Setting: {location_desc}")

        # Add character descriptions
        if char_descriptions:
            positive_parts.append("Characters: " + "; ".join(char_descriptions))

        # Add the main shot prompt
        positive_parts.append(shot.prompt)

        # Add camera/lens info
        camera_parts = []
        if shot.camera.framing:
            camera_parts.append(shot.camera.framing)
        if self.global_style.lens:
            camera_parts.append(self.global_style.lens)
        if shot.camera.movement:
            camera_parts.append(shot.camera.movement)
        if self.global_style.motion_style:
            camera_parts.append(self.global_style.motion_style)
        if shot.camera.notes:
            camera_parts.append(shot.camera.notes)

        if camera_parts:
            positive_parts.append(f"Camera: {', '.join(camera_parts)}")

        # Add lighting
        if self.global_style.lighting:
            positive_parts.append(f"Lighting: {self.global_style.lighting}")

        # Add color grade
        if self.global_style.color_grade:
            positive_parts.append(f"Color: {self.global_style.color_grade}")

        positive = ". ".join(positive_parts)

        # Compile negative prompt
        negative_parts = []
        if self.global_style.negative_prompt:
            negative_parts.append(self.global_style.negative_prompt)

        negative = ", ".join(negative_parts) if negative_parts else ""

        # Camera notes for metadata
        camera_notes = "; ".join(camera_parts) if camera_parts else ""

        # Full combined prompt (for reference)
        full = positive

        return CompiledPrompt(
            positive=positive,
            negative=negative,
            camera_notes=camera_notes,
            full_prompt=full
        )

    def compile_all(self) -> List[CompiledPrompt]:
        """
        Compile prompts for all shots.

        Returns:
            List of CompiledPrompt objects in shot order
        """
        return [self.compile_shot(shot) for shot in self.storyboard.shots]

    def get_shot_prompt(self, shot_id: str) -> Optional[CompiledPrompt]:
        """
        Get compiled prompt for a specific shot by ID.

        Args:
            shot_id: Shot identifier

        Returns:
            CompiledPrompt or None if shot not found
        """
        for shot in self.storyboard.shots:
            if shot.id == shot_id:
                return self.compile_shot(shot)
        return None
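The compiler's section ordering (style, then setting, then characters, then the shot prompt, then camera) can be shown without the schema objects. A minimal sketch, with `assemble_prompt` as a hypothetical helper over plain strings:

```python
def assemble_prompt(style, location, characters, shot_prompt, camera_parts):
    """Join prompt sections in the compiler's order:
    style -> Setting -> Characters -> shot prompt -> Camera.
    Empty sections are skipped; sections are joined with '. '."""
    parts = []
    if style:
        parts.append(style)
    if location:
        parts.append(f"Setting: {location}")
    if characters:
        parts.append("Characters: " + "; ".join(characters))
    parts.append(shot_prompt)
    if camera_parts:
        parts.append(f"Camera: {', '.join(camera_parts)}")
    return ". ".join(parts)


print(assemble_prompt(
    "35mm film look",
    "rain-soaked street",
    ["Mara: red coat"],
    "she crosses the street",
    ["wide shot", "slow dolly"],
))
```

Putting the global style first biases the aesthetic for text-to-video models that weight earlier tokens more heavily; only the shot prompt is unconditional, so a shot with no style or camera data still compiles to something usable.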
159
src/storyboard/schema.py
Normal file
@@ -0,0 +1,159 @@
"""
Storyboard schema validation using Pydantic.
Defines the data models for storyboard JSON files.
"""

from typing import List, Optional, Literal
from pydantic import BaseModel, Field, validator


class Resolution(BaseModel):
    """Video resolution dimensions."""
    width: int = Field(..., ge=256, le=4096, description="Width in pixels")
    height: int = Field(..., ge=256, le=4096, description="Height in pixels")


class GlobalStyle(BaseModel):
    """Global visual style settings applied to all shots."""
    visual_style: str = Field(default="", description="Overall visual aesthetic")
    color_grade: str = Field(default="", description="Color grading description")
    lens: str = Field(default="", description="Lens type/characteristics")
    lighting: str = Field(default="", description="Lighting setup description")
    motion_style: str = Field(default="", description="Camera motion characteristics")
    negative_prompt: str = Field(default="", description="What to avoid in generation")


class AudioSettings(BaseModel):
    """Audio configuration for the final video."""
    add_music: bool = Field(default=False)
    music_path: str = Field(default="")
    add_voiceover: bool = Field(default=False)
    voiceover_path: str = Field(default="")


class ProjectSettings(BaseModel):
    """Top-level project configuration."""
    title: str = Field(..., min_length=1, description="Project title")
    fps: int = Field(default=24, ge=1, le=120, description="Frames per second")
    target_duration_s: int = Field(default=20, ge=1, le=300, description="Target duration in seconds")
    resolution: Resolution
    aspect_ratio: str = Field(default="16:9", description="Aspect ratio (e.g., 16:9, 4:3)")
    global_style: GlobalStyle = Field(default_factory=GlobalStyle)
    audio: AudioSettings = Field(default_factory=AudioSettings)


class Character(BaseModel):
    """Character definition for consistency across shots."""
    id: str = Field(..., min_length=1, description="Unique character identifier")
    name: str = Field(..., min_length=1, description="Character name")
    description: str = Field(default="", description="Visual description")
    consistency_notes: str = Field(default="", description="Notes for maintaining consistency")


class Location(BaseModel):
    """Location/setting definition."""
    id: str = Field(..., min_length=1, description="Unique location identifier")
    name: str = Field(..., min_length=1, description="Location name")
    description: str = Field(default="", description="Visual description of the setting")


class CameraSettings(BaseModel):
    """Camera configuration for a shot."""
    framing: str = Field(default="", description="Shot framing (wide, medium, close-up, etc.)")
    movement: str = Field(default="", description="Camera movement description")
    notes: str = Field(default="", description="Additional camera notes")


class GenerationSettings(BaseModel):
    """AI generation parameters for a shot."""
    seed: int = Field(default=-1, description="Random seed (-1 for random)")
    steps: int = Field(default=30, ge=1, le=100, description="Inference steps")
    cfg_scale: float = Field(default=6.0, ge=1.0, le=20.0, description="Classifier-free guidance scale")
    sampler: str = Field(default="default", description="Sampler type")
    chunk_seconds: int = Field(default=4, ge=1, le=10, description="Duration per chunk in seconds")
    use_init_frame_from_prev: bool = Field(default=False, description="Use last frame of previous shot as init")


class Shot(BaseModel):
    """Individual shot definition."""
    id: str = Field(..., min_length=1, description="Unique shot identifier")
    duration_s: int = Field(..., ge=1, le=60, description="Shot duration in seconds")
    location_id: Optional[str] = Field(default=None, description="Reference to location")
    characters: List[str] = Field(default_factory=list, description="List of character IDs")
    prompt: str = Field(..., min_length=1, description="Generation prompt for this shot")
    camera: CameraSettings = Field(default_factory=CameraSettings)
    generation: GenerationSettings = Field(default_factory=GenerationSettings)


class UpscaleSettings(BaseModel):
    """Post-processing upscaling configuration."""
    enabled: bool = Field(default=False)
    target_height: int = Field(default=2160, ge=720, le=4320, description="Target height in pixels")
    method: str = Field(default="default", description="Upscaling method")


class OutputSettings(BaseModel):
    """Final video output configuration."""
    container: Literal["mp4", "mov", "mkv"] = Field(default="mp4")
    codec: Literal["h264", "h265", "vp9"] = Field(default="h264")
    crf: int = Field(default=18, ge=0, le=51, description="Constant rate factor (lower = higher quality)")
    preset: str = Field(default="medium", description="Encoding preset")
    upscale: UpscaleSettings = Field(default_factory=UpscaleSettings)


class Storyboard(BaseModel):
    """Complete storyboard document."""
    schema_version: str = Field(default="1.0", description="Schema version for compatibility")
    project: ProjectSettings
    characters: List[Character] = Field(default_factory=list)
    locations: List[Location] = Field(default_factory=list)
    shots: List[Shot] = Field(..., description="List of shots to generate")
    output: OutputSettings = Field(default_factory=OutputSettings)

    @validator('shots')
    def validate_shot_ids_unique(cls, v):
        """Ensure all shot IDs are unique and at least one shot exists."""
        if len(v) < 1:
            raise ValueError("At least one shot is required")
        ids = [shot.id for shot in v]
        if len(ids) != len(set(ids)):
            raise ValueError("Shot IDs must be unique")
        return v

    @validator('characters')
    def validate_character_ids_unique(cls, v):
        """Ensure all character IDs are unique."""
        ids = [char.id for char in v]
        if len(ids) != len(set(ids)):
            raise ValueError("Character IDs must be unique")
        return v

    @validator('locations')
    def validate_location_ids_unique(cls, v):
        """Ensure all location IDs are unique."""
        ids = [loc.id for loc in v]
        if len(ids) != len(set(ids)):
            raise ValueError("Location IDs must be unique")
        return v

    def get_character(self, char_id: str) -> Optional[Character]:
        """Get character by ID."""
        for char in self.characters:
            if char.id == char_id:
                return char
        return None

    def get_location(self, loc_id: str) -> Optional[Location]:
        """Get location by ID."""
        for loc in self.locations:
            if loc.id == loc_id:
                return loc
        return None

    def get_total_duration(self) -> int:
        """Calculate total duration in seconds."""
        return sum(shot.duration_s for shot in self.shots)

    def get_total_frames(self) -> int:
        """Calculate total frame count."""
        return self.get_total_duration() * self.project.fps
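The totals at the end of the schema are plain arithmetic: total frames = (sum of per-shot durations) * fps. A tiny sketch, with `total_frames` as a hypothetical helper mirroring `get_total_frames()` without the Pydantic models:

```python
def total_frames(shot_durations_s, fps):
    """Total frame count = sum of per-shot durations (seconds) * fps,
    matching Storyboard.get_total_frames()."""
    return sum(shot_durations_s) * fps


# Three shots of 4 s, 6 s, and 10 s at 24 fps:
print(total_frames([4, 6, 10], fps=24))  # -> 480
```

Because `duration_s` is constrained to whole seconds in the schema, this product is always an exact integer frame count; fractional durations would need rounding rules.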
252
src/storyboard/shot_planner.py
Normal file
@@ -0,0 +1,252 @@
"""
Shot planner for converting duration to frames and chunking strategy.
"""

from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from pathlib import Path

from ..storyboard.schema import Storyboard, Shot


@dataclass
class Chunk:
    """A single chunk of a shot."""
    chunk_index: int
    start_frame: int
    end_frame: int
    num_frames: int
    duration_s: float
    use_init_frame: bool
    init_frame_source: Optional[Path]  # Path to frame file or previous chunk

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for metadata."""
        return {
            "chunk_index": self.chunk_index,
            "start_frame": self.start_frame,
            "end_frame": self.end_frame,
            "num_frames": self.num_frames,
            "duration_s": self.duration_s,
            "use_init_frame": self.use_init_frame,
            "init_frame_source": str(self.init_frame_source) if self.init_frame_source else None
        }


@dataclass
class ShotPlan:
    """Complete plan for generating a shot."""
    shot_id: str
    total_frames: int
    total_duration_s: int
    fps: int
    chunks: List[Chunk]

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for metadata."""
        return {
            "shot_id": self.shot_id,
            "total_frames": self.total_frames,
            "total_duration_s": self.total_duration_s,
            "fps": self.fps,
            "chunks": [c.to_dict() for c in self.chunks]
        }


class ShotPlanner:
    """
    Plans shot generation including frame counts and chunking strategy.
    """

    def __init__(
        self,
        fps: int = 24,
        max_chunk_seconds: int = 4,
        overlap_seconds: int = 0,
        chunking_mode: str = "sequential"
    ):
        """
        Initialize shot planner.

        Args:
            fps: Frames per second
            max_chunk_seconds: Maximum duration per chunk
            overlap_seconds: Overlap between chunks (for blending mode)
            chunking_mode: "sequential" or "overlapping"
        """
        self.fps = fps
        self.max_chunk_seconds = max_chunk_seconds
        self.overlap_seconds = overlap_seconds
        self.chunking_mode = chunking_mode

    def duration_to_frames(self, duration_s: int) -> int:
        """
        Convert duration in seconds to frame count.

        Args:
            duration_s: Duration in seconds

        Returns:
            Number of frames
        """
        return duration_s * self.fps

    def frames_to_duration(self, num_frames: int) -> float:
        """
        Convert frame count to duration in seconds.

        Args:
            num_frames: Number of frames

        Returns:
            Duration in seconds
        """
        return num_frames / self.fps

    def plan_shot(
        self,
        shot: Shot,
        fps: Optional[int] = None,
        use_init_from_prev: bool = False
    ) -> ShotPlan:
        """
        Create a generation plan for a shot.

        Args:
            shot: Shot to plan
            fps: Override FPS (uses storyboard default if None)
            use_init_from_prev: Whether to use previous shot's last frame

        Returns:
            ShotPlan with chunk breakdown
        """
        fps = fps or self.fps
        total_frames = shot.duration_s * fps

        # Determine chunk size
        chunk_frames = self.max_chunk_seconds * fps
        overlap_frames = self.overlap_seconds * fps if self.chunking_mode == "overlapping" else 0

        chunks = []

        if total_frames <= chunk_frames:
            # Single chunk - no need to split
            chunk = Chunk(
                chunk_index=0,
                start_frame=0,
                end_frame=total_frames - 1,
                num_frames=total_frames,
                duration_s=shot.duration_s,
                use_init_frame=use_init_from_prev and shot.generation.use_init_frame_from_prev,
                init_frame_source=None  # Will be set by pipeline
            )
            chunks.append(chunk)
        else:
            # Multiple chunks needed
            remaining_frames = total_frames
            current_frame = 0
            chunk_index = 0

            while remaining_frames > 0:
                # Calculate frames for this chunk
                if self.chunking_mode == "sequential":
                    # Sequential: Each chunk starts where previous ended
                    this_chunk_frames = min(chunk_frames, remaining_frames)

                    chunk = Chunk(
                        chunk_index=chunk_index,
                        start_frame=current_frame,
                        end_frame=current_frame + this_chunk_frames - 1,
                        num_frames=this_chunk_frames,
                        duration_s=this_chunk_frames / fps,
                        use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0,
                        init_frame_source=None
                    )
                    chunks.append(chunk)

                    current_frame += this_chunk_frames
                    remaining_frames -= this_chunk_frames

                else:  # overlapping
                    # Overlapping: Chunks overlap for blending
                    if remaining_frames <= chunk_frames:
                        # Last chunk - take all remaining frames
                        this_chunk_frames = remaining_frames
                    else:
                        # Not last chunk - include overlap
                        this_chunk_frames = min(chunk_frames + overlap_frames, remaining_frames)

                    chunk = Chunk(
                        chunk_index=chunk_index,
                        start_frame=current_frame,
                        end_frame=current_frame + this_chunk_frames - 1,
                        num_frames=this_chunk_frames,
                        duration_s=this_chunk_frames / fps,
                        use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0,
                        init_frame_source=None
                    )
                    chunks.append(chunk)

                    # Move forward by chunk size (minus overlap)
                    current_frame += chunk_frames - overlap_frames
                    remaining_frames -= chunk_frames - overlap_frames

                chunk_index += 1

        return ShotPlan(
            shot_id=shot.id,
            total_frames=total_frames,
            total_duration_s=shot.duration_s,
            fps=fps,
            chunks=chunks
        )

def plan_storyboard(
|
||||
self,
|
||||
storyboard: Storyboard,
|
||||
override_chunk_config: Optional[Dict[str, Any]] = None
|
||||
) -> List[ShotPlan]:
|
||||
"""
|
||||
Create generation plans for all shots in a storyboard.
|
||||
|
||||
Args:
|
||||
storyboard: Storyboard to plan
|
||||
override_chunk_config: Optional config to override defaults
|
||||
|
||||
Returns:
|
||||
List of ShotPlan objects
|
||||
"""
|
||||
# Apply overrides if provided
|
||||
fps = override_chunk_config.get("fps", storyboard.project.fps) if override_chunk_config else storyboard.project.fps
|
||||
max_chunk = override_chunk_config.get("max_chunk_seconds", self.max_chunk_seconds) if override_chunk_config else self.max_chunk_seconds
|
||||
overlap = override_chunk_config.get("overlap_seconds", self.overlap_seconds) if override_chunk_config else self.overlap_seconds
|
||||
mode = override_chunk_config.get("mode", self.chunking_mode) if override_chunk_config else self.chunking_mode
|
||||
|
||||
planner = ShotPlanner(
|
||||
fps=fps,
|
||||
max_chunk_seconds=max_chunk,
|
||||
overlap_seconds=overlap,
|
||||
chunking_mode=mode
|
||||
)
|
||||
|
||||
plans = []
|
||||
prev_shot_id = None
|
||||
|
||||
for i, shot in enumerate(storyboard.shots):
|
||||
# Determine if we should use init frame from previous shot
|
||||
use_init = i > 0 and shot.generation.use_init_frame_from_prev
|
||||
|
||||
plan = planner.plan_shot(shot, fps=fps, use_init_from_prev=use_init)
|
||||
plans.append(plan)
|
||||
|
||||
prev_shot_id = shot.id
|
||||
|
||||
return plans
|
||||
|
||||
def get_total_frames(self, plans: List[ShotPlan]) -> int:
|
||||
"""Calculate total frames across all plans."""
|
||||
return sum(plan.total_frames for plan in plans)
|
||||
|
||||
def get_total_duration(self, plans: List[ShotPlan]) -> float:
|
||||
"""Calculate total duration in seconds across all plans."""
|
||||
return sum(plan.total_duration_s for plan in plans)
|
||||
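The chunk-splitting arithmetic in `plan_shot` can be sketched standalone. This is a simplified illustration, not the repo's code: `split_frames` is a hypothetical helper that returns inclusive `(start, end)` frame pairs, with sequential mode when `overlap_frames` is 0 and the planner's overlapping semantics (non-final chunks extended by the overlap, stride of `chunk_frames - overlap_frames`) otherwise.

```python
def split_frames(total_frames: int, chunk_frames: int, overlap_frames: int = 0) -> list:
    """Return inclusive (start, end) frame pairs covering total_frames.

    Sequential mode when overlap_frames == 0; overlapping mode otherwise,
    where each non-final chunk is extended by overlap_frames for blending.
    """
    chunks = []
    start = 0
    while start < total_frames:
        end = min(start + chunk_frames + overlap_frames, total_frames)
        chunks.append((start, end - 1))  # inclusive end frame
        if end >= total_frames:
            break  # final chunk reached the end of the shot
        start += chunk_frames - overlap_frames  # stride between chunk starts
    return chunks

# 10 s at 24 fps (240 frames), 4 s (96-frame) chunks, sequential
print(split_frames(240, 96))      # [(0, 95), (96, 191), (192, 239)]
# Same shot, overlapping with a 1 s (24-frame) overlap: stride is 72
print(split_frames(240, 96, 24))  # [(0, 119), (72, 191), (144, 239)]
```

Note the explicit `break` once a chunk reaches the last frame: with a fixed stride, the remaining-frames counter alone can stay positive after the final chunk and emit a spurious extra chunk.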
3
src/upscaling/__init__.py
Normal file
@@ -0,0 +1,3 @@
"""
Upscaling module.
"""
189
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,189 @@
"""
Tests for checkpoint system.
"""

import pytest
import tempfile
import os
from pathlib import Path

from src.core.checkpoint import (
    CheckpointManager, ShotCheckpoint, ProjectCheckpoint,
    ShotStatus, ProjectStatus
)


class TestCheckpointManager:
    """Test checkpoint management."""

    @pytest.fixture
    def temp_db(self):
        """Create a temporary database for testing."""
        with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
            db_path = f.name

        yield db_path

        # Cleanup
        if os.path.exists(db_path):
            os.unlink(db_path)

    @pytest.fixture
    def manager(self, temp_db):
        """Create a checkpoint manager with temp database."""
        return CheckpointManager(db_path=temp_db)

    def test_create_project(self, manager):
        """Test creating a project checkpoint."""
        checkpoint = manager.create_project(
            project_id="test_001",
            storyboard_path="/path/to/storyboard.json",
            output_dir="/path/to/output",
            backend_name="wan_t2v_14b"
        )

        assert checkpoint.project_id == "test_001"
        assert checkpoint.status == ProjectStatus.INITIALIZED
        assert checkpoint.backend_name == "wan_t2v_14b"

    def test_get_project(self, manager):
        """Test retrieving a project checkpoint."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/to/storyboard.json",
            output_dir="/path/to/output",
            backend_name="wan_t2v_14b"
        )

        retrieved = manager.get_project("test_001")
        assert retrieved is not None
        assert retrieved.project_id == "test_001"

        missing = manager.get_project("nonexistent")
        assert missing is None

    def test_update_project_status(self, manager):
        """Test updating project status."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/to/storyboard.json",
            output_dir="/path/to/output",
            backend_name="wan_t2v_14b"
        )

        manager.update_project_status("test_001", ProjectStatus.RUNNING)

        retrieved = manager.get_project("test_001")
        assert retrieved.status == ProjectStatus.RUNNING

    def test_create_shot(self, manager):
        """Test creating a shot checkpoint."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name="wan"
        )

        shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)

        assert shot.shot_id == "S01"
        assert shot.status == ShotStatus.PENDING

    def test_update_shot(self, manager):
        """Test updating shot checkpoint."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name="wan"
        )
        manager.create_shot("test_001", "S01")

        manager.update_shot(
            "test_001",
            "S01",
            status=ShotStatus.IN_PROGRESS,
            output_path="/path/output.mp4",
            metadata={"seed": 12345}
        )

        shot = manager.get_shot("test_001", "S01")
        assert shot.status == ShotStatus.IN_PROGRESS
        assert shot.output_path == "/path/output.mp4"
        assert shot.metadata["seed"] == 12345

    def test_get_pending_shots(self, manager):
        """Test retrieving pending shots."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name="wan"
        )

        manager.create_shot("test_001", "S01", ShotStatus.PENDING)
        manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
        manager.create_shot("test_001", "S03", ShotStatus.PENDING)

        pending = manager.get_pending_shots("test_001")
        assert len(pending) == 2
        assert "S01" in pending
        assert "S03" in pending

    def test_can_resume(self, manager):
        """Test checking if project can be resumed."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name="wan"
        )

        # No shots yet - can't resume
        assert not manager.can_resume("test_001")

        # Add pending shot - can resume
        manager.create_shot("test_001", "S01", ShotStatus.PENDING)
        assert manager.can_resume("test_001")

        # Complete all shots - can't resume
        manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
        assert not manager.can_resume("test_001")

    def test_delete_project(self, manager):
        """Test deleting a project."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/storyboard.json",
            output_dir="/path/output",
            backend_name="wan"
        )
        manager.create_shot("test_001", "S01")

        manager.delete_project("test_001")

        assert manager.get_project("test_001") is None
        assert manager.get_shot("test_001", "S01") is None

    def test_list_projects(self, manager):
        """Test listing all projects."""
        manager.create_project(
            project_id="test_001",
            storyboard_path="/path/1.json",
            output_dir="/out/1",
            backend_name="wan"
        )
        manager.create_project(
            project_id="test_002",
            storyboard_path="/path/2.json",
            output_dir="/out/2",
            backend_name="svd"
        )

        projects = manager.list_projects()
        assert len(projects) == 2

        project_ids = [p.project_id for p in projects]
        assert "test_001" in project_ids
        assert "test_002" in project_ids
163
tests/unit/test_config.py
Normal file
@@ -0,0 +1,163 @@
"""
Tests for configuration system.
"""

import pytest
import tempfile
import os
from pathlib import Path

from src.core.config import Config, BackendConfig, ConfigLoader, get_config, reload_config


class TestConfigLoader:
    """Test configuration loading."""

    def test_load_default_config(self):
        """Test loading with no config files (uses defaults)."""
        config = ConfigLoader.load(
            config_path=Path("nonexistent.yaml"),
            env_file=Path("nonexistent.env")
        )

        assert isinstance(config, Config)
        assert config.active_backend == "wan_t2v_14b"
        assert len(config.backends) == 0  # No YAML loaded

    def test_load_yaml_config(self):
        """Test loading from YAML file."""
        yaml_content = """
backends:
  test_backend:
    name: "Test Backend"
    class: "test.TestBackend"
    model_id: "test/model"
    vram_gb: 8
    dtype: "fp16"
    enable_vae_slicing: true
    enable_vae_tiling: false
    chunking:
      enabled: true
      mode: "sequential"
      max_chunk_seconds: 4
      overlap_seconds: 1

active_backend: "test_backend"
defaults:
  fps: 30
checkpoint_db: "test.db"
model_cache_dir: "~/test_cache"
"""

        with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
            f.write(yaml_content)
            yaml_path = f.name

        try:
            config = ConfigLoader.load(config_path=Path(yaml_path))

            assert config.active_backend == "test_backend"
            assert "test_backend" in config.backends

            backend = config.get_backend("test_backend")
            assert backend.name == "Test Backend"
            assert backend.vram_gb == 8
            assert backend.chunking_mode == "sequential"
            assert config.defaults["fps"] == 30
        finally:
            os.unlink(yaml_path)

    def test_load_env_file(self, monkeypatch):
        """Test loading from .env file."""
        # Clear any existing env vars first
        for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
            monkeypatch.delenv(key, raising=False)

        env_content = """
ACTIVE_BACKEND=env_backend
MODEL_CACHE_DIR=/env/cache
LOG_LEVEL=DEBUG
"""

        with tempfile.NamedTemporaryFile(mode='w', suffix='.env', delete=False) as f:
            f.write(env_content)
            env_path = f.name

        try:
            config = ConfigLoader.load(
                config_path=Path("nonexistent.yaml"),
                env_file=Path(env_path)
            )

            assert config.active_backend == "env_backend"
            assert "/env/cache" in config.model_cache_dir or "\\env\\cache" in config.model_cache_dir
            assert config.log_level == "DEBUG"
        finally:
            os.unlink(env_path)
            # Clean up env vars
            for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
                monkeypatch.delenv(key, raising=False)

    def test_env_variable_override(self, monkeypatch):
        """Test that environment variables override config files."""
        # Clear env var first
        monkeypatch.delenv("ACTIVE_BACKEND", raising=False)

        yaml_content = """
active_backend: yaml_backend
model_cache_dir: ~/yaml_cache
"""

        with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
            f.write(yaml_content)
            yaml_path = f.name

        try:
            monkeypatch.setenv("ACTIVE_BACKEND", "env_override")

            config = ConfigLoader.load(config_path=Path(yaml_path))

            # Environment variable should override
            assert config.active_backend == "env_override"
            # But other settings from YAML should remain
            assert "yaml_cache" in config.model_cache_dir
        finally:
            os.unlink(yaml_path)
            monkeypatch.delenv("ACTIVE_BACKEND", raising=False)

    def test_path_expansion(self):
        """Test that paths are properly expanded."""
        config = ConfigLoader.load(
            config_path=Path("nonexistent.yaml"),
            env_file=Path("nonexistent.env")
        )

        # Default paths should be expanded
        assert not config.model_cache_dir.startswith("~")
        assert not config.checkpoint_db.startswith("~")

    def test_get_backend_not_found(self):
        """Test getting a non-existent backend."""
        config = Config()

        with pytest.raises(ValueError, match="not found"):
            config.get_backend("nonexistent")


class TestBackendConfig:
    """Test backend configuration."""

    def test_backend_config_defaults(self):
        """Test backend config with defaults."""
        config = BackendConfig(
            name="Test",
            model_class="test.Test",
            model_id="test/model",
            vram_gb=12,
            dtype="fp16"
        )

        assert config.enable_vae_slicing is True
        assert config.enable_vae_tiling is False
        assert config.chunking_enabled is True
        assert config.chunking_mode == "sequential"
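The precedence these tests exercise (environment variable > YAML value > built-in default) can be sketched with plain dicts. This is an illustration of the ordering only; `resolve_setting` and `DEFAULTS` are hypothetical names, not the repo's `ConfigLoader` API.

```python
import os

# Built-in fallbacks, used only when neither env nor YAML supplies a value
DEFAULTS = {"active_backend": "wan_t2v_14b", "log_level": "INFO"}

def resolve_setting(key: str, yaml_values: dict, env_key: str) -> str:
    """Environment variable wins, then YAML, then the built-in default."""
    env_value = os.environ.get(env_key)
    if env_value is not None:
        return env_value
    if key in yaml_values:
        return yaml_values[key]
    return DEFAULTS[key]

yaml_values = {"active_backend": "yaml_backend"}
os.environ.pop("ACTIVE_BACKEND", None)
print(resolve_setting("active_backend", yaml_values, "ACTIVE_BACKEND"))  # yaml_backend
os.environ["ACTIVE_BACKEND"] = "env_override"
print(resolve_setting("active_backend", yaml_values, "ACTIVE_BACKEND"))  # env_override
```

This is why `test_env_variable_override` must clear `ACTIVE_BACKEND` before loading: any value already in the process environment would shadow both the YAML file and the default.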
268
tests/unit/test_storyboard.py
Normal file
@@ -0,0 +1,268 @@
"""
Tests for storyboard schema validation.
"""

import pytest
from pathlib import Path
import json
import tempfile
import os

from src.storyboard.schema import (
    Storyboard, ProjectSettings, Shot, Character, Location,
    GlobalStyle, CameraSettings, GenerationSettings, OutputSettings,
    Resolution
)
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError


class TestStoryboardSchema:
    """Test storyboard schema validation."""

    def test_create_minimal_storyboard(self):
        """Test creating a minimal valid storyboard."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test Video",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="A test shot"
                )
            ]
        )

        assert storyboard.project.title == "Test Video"
        assert len(storyboard.shots) == 1
        assert storyboard.shots[0].id == "S01"

    def test_shot_validation_unique_ids(self):
        """Test that duplicate shot IDs are rejected."""
        with pytest.raises(ValueError, match="Shot IDs must be unique"):
            Storyboard(
                project=ProjectSettings(
                    title="Test",
                    resolution=Resolution(width=1920, height=1080),
                    target_duration_s=10
                ),
                shots=[
                    Shot(id="S01", duration_s=5, prompt="Shot 1"),
                    Shot(id="S01", duration_s=5, prompt="Shot 2")
                ]
            )

    def test_character_validation_unique_ids(self):
        """Test that duplicate character IDs are rejected."""
        with pytest.raises(ValueError, match="Character IDs must be unique"):
            Storyboard(
                project=ProjectSettings(
                    title="Test",
                    resolution=Resolution(width=1920, height=1080),
                    target_duration_s=10
                ),
                characters=[
                    Character(id="C01", name="Character 1"),
                    Character(id="C01", name="Character 2")
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")]
            )

    def test_location_validation_unique_ids(self):
        """Test that duplicate location IDs are rejected."""
        with pytest.raises(ValueError, match="Location IDs must be unique"):
            Storyboard(
                project=ProjectSettings(
                    title="Test",
                    resolution=Resolution(width=1920, height=1080),
                    target_duration_s=10
                ),
                locations=[
                    Location(id="L01", name="Location 1"),
                    Location(id="L01", name="Location 2")
                ],
                shots=[Shot(id="S01", duration_s=5, prompt="Test")]
            )

    def test_get_character_by_id(self):
        """Test retrieving character by ID."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            characters=[
                Character(id="C01", name="Hero", description="The protagonist")
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")]
        )

        char = storyboard.get_character("C01")
        assert char is not None
        assert char.name == "Hero"

        missing = storyboard.get_character("C99")
        assert missing is None

    def test_get_location_by_id(self):
        """Test retrieving location by ID."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            locations=[
                Location(id="L01", name="Street", description="A city street")
            ],
            shots=[Shot(id="S01", duration_s=5, prompt="Test")]
        )

        loc = storyboard.get_location("L01")
        assert loc is not None
        assert loc.name == "Street"

        missing = storyboard.get_location("L99")
        assert missing is None

    def test_total_duration_calculation(self):
        """Test total duration calculation."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2")
            ]
        )

        assert storyboard.get_total_duration() == 10

    def test_total_frames_calculation(self):
        """Test total frames calculation."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                fps=24,
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            shots=[
                Shot(id="S01", duration_s=4, prompt="Shot 1"),
                Shot(id="S02", duration_s=6, prompt="Shot 2")
            ]
        )

        assert storyboard.get_total_frames() == 240  # 10 seconds * 24 fps


class TestStoryboardLoader:
    """Test storyboard loading functionality."""

    def test_load_valid_storyboard(self):
        """Test loading a valid storyboard JSON file."""
        data = {
            "schema_version": "1.0",
            "project": {
                "title": "Test Video",
                "fps": 24,
                "target_duration_s": 10,
                "resolution": {"width": 1920, "height": 1080},
                "aspect_ratio": "16:9",
                "global_style": {
                    "visual_style": "cinematic",
                    "negative_prompt": "blurry"
                },
                "audio": {"add_music": False}
            },
            "characters": [],
            "locations": [],
            "shots": [
                {
                    "id": "S01",
                    "duration_s": 5,
                    "prompt": "A test shot",
                    "camera": {"framing": "wide"},
                    "generation": {"seed": 12345, "steps": 30}
                }
            ],
            "output": {"container": "mp4", "codec": "h264"}
        }

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(data, f)
            temp_path = f.name

        try:
            storyboard = StoryboardValidator.load(temp_path)
            assert storyboard.project.title == "Test Video"
            assert len(storyboard.shots) == 1
            assert storyboard.shots[0].id == "S01"
        finally:
            os.unlink(temp_path)

    def test_load_nonexistent_file(self):
        """Test loading a file that doesn't exist."""
        with pytest.raises(StoryboardLoadError, match="not found"):
            StoryboardValidator.load("/nonexistent/path/storyboard.json")

    def test_load_invalid_json(self):
        """Test loading invalid JSON."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            f.write("{invalid json}")
            temp_path = f.name

        try:
            with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
                StoryboardValidator.load(temp_path)
        finally:
            os.unlink(temp_path)

    def test_load_wrong_extension(self):
        """Test loading a file with wrong extension."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
            f.write('{}')
            temp_path = f.name

        try:
            with pytest.raises(StoryboardLoadError, match="must be JSON"):
                StoryboardValidator.load(temp_path)
        finally:
            os.unlink(temp_path)

    def test_validate_references(self):
        """Test reference validation."""
        storyboard = Storyboard(
            project=ProjectSettings(
                title="Test",
                resolution=Resolution(width=1920, height=1080),
                target_duration_s=10
            ),
            characters=[
                Character(id="C01", name="Hero")
            ],
            locations=[
                Location(id="L01", name="Street")
            ],
            shots=[
                Shot(
                    id="S01",
                    duration_s=5,
                    prompt="Test",
                    characters=["C01", "C99"],  # C99 doesn't exist
                    location_id="L99"  # Doesn't exist
                )
            ]
        )

        issues = StoryboardValidator.validate_references(storyboard)
        assert len(issues) == 2
        assert any("C99" in issue for issue in issues)
        assert any("L99" in issue for issue in issues)
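The uniqueness and cross-reference checks these tests exercise reduce to a couple of set operations. A minimal sketch with hypothetical helpers (`find_duplicate_ids`, `find_dangling_refs` are illustrative names, not the repo's validators), operating on plain dicts instead of the pydantic models:

```python
def find_duplicate_ids(items: list) -> set:
    """Return IDs that occur more than once (each item is a dict with an 'id')."""
    seen, dupes = set(), set()
    for item in items:
        if item["id"] in seen:
            dupes.add(item["id"])
        seen.add(item["id"])
    return dupes

def find_dangling_refs(shots: list, known_ids: set) -> list:
    """Return character IDs referenced by shots but never defined."""
    return [ref for shot in shots for ref in shot.get("characters", [])
            if ref not in known_ids]

shots = [{"id": "S01", "characters": ["C01", "C99"]}, {"id": "S01"}]
print(find_duplicate_ids(shots))           # {'S01'}
print(find_dangling_refs(shots, {"C01"}))  # ['C99']
```

Duplicate IDs are a hard schema error (the model raises `ValueError`), while dangling references are collected as a list of issues by `validate_references`, so a caller can choose to warn rather than abort.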