Initial Commit
This commit is contained in:
23
.env.example
Normal file
23
.env.example
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Environment configuration for storyboard-video
|
||||||
|
# Copy this file to .env and customize as needed
|
||||||
|
|
||||||
|
# Active backend selection (must match a key in config/models.yaml)
|
||||||
|
ACTIVE_BACKEND=wan_t2v_14b
|
||||||
|
|
||||||
|
# Model cache directory (where downloaded models are stored)
|
||||||
|
MODEL_CACHE_DIR=~/.cache/storyboard-video/models
|
||||||
|
|
||||||
|
# HuggingFace Hub cache (if using HF models)
|
||||||
|
HF_HOME=~/.cache/huggingface
|
||||||
|
|
||||||
|
# CUDA settings
|
||||||
|
# CUDA_VISIBLE_DEVICES=0
|
||||||
|
|
||||||
|
# Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# Output directory base path
|
||||||
|
OUTPUT_BASE_DIR=./outputs
|
||||||
|
|
||||||
|
# FFmpeg path (leave empty to use system PATH)
|
||||||
|
FFMPEG_PATH=
|
||||||
138
.gitignore
vendored
Normal file
138
.gitignore
vendored
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
.venv
|
||||||
|
|
||||||
|
# Conda
|
||||||
|
conda-env/
|
||||||
|
*.conda
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# Pytest
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Model caches and outputs
|
||||||
|
models/
|
||||||
|
*.ckpt
|
||||||
|
*.safetensors
|
||||||
|
*.bin
|
||||||
|
*.pth
|
||||||
|
*.onnx
|
||||||
|
outputs/
|
||||||
|
checkpoints.db
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
*.tmp
|
||||||
|
*.temp
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Video files (generated)
|
||||||
|
*.mp4
|
||||||
|
*.mov
|
||||||
|
*.avi
|
||||||
|
*.mkv
|
||||||
|
*.webm
|
||||||
|
|
||||||
|
# Image files (generated)
|
||||||
|
*.png
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.gif
|
||||||
|
*.bmp
|
||||||
|
*.tiff
|
||||||
|
|
||||||
|
# Audio files (generated)
|
||||||
|
*.wav
|
||||||
|
*.mp3
|
||||||
|
*.aac
|
||||||
|
*.flac
|
||||||
|
*.ogg
|
||||||
|
|
||||||
|
# HuggingFace cache
|
||||||
|
.cache/
|
||||||
|
transformers_cache/
|
||||||
|
diffusers_cache/
|
||||||
|
|
||||||
|
# Windows
|
||||||
|
Thumbs.db
|
||||||
|
ehthumbs.db
|
||||||
|
Desktop.ini
|
||||||
|
$RECYCLE.BIN/
|
||||||
|
|
||||||
|
# macOS
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
Icon
|
||||||
|
._*
|
||||||
|
.DocumentRevisions-V100
|
||||||
|
.fseventsd
|
||||||
|
.Spotlight-V100
|
||||||
|
.TemporaryItems
|
||||||
|
.Trashes
|
||||||
|
.VolumeIcon.icns
|
||||||
|
.com.apple.timemachine.donotpresent
|
||||||
|
|
||||||
|
# Project specific
|
||||||
|
# Keep templates but ignore generated outputs
|
||||||
|
!/templates/*.json
|
||||||
|
!/templates/*.yaml
|
||||||
|
!/templates/*.yml
|
||||||
|
|
||||||
|
# Allow example files
|
||||||
|
!.env.example
|
||||||
|
!config/*.yaml
|
||||||
|
!config/*.yml
|
||||||
|
|
||||||
|
# Keep docs
|
||||||
|
docs/*.md
|
||||||
210
AGENTS.MD
Normal file
210
AGENTS.MD
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
# agents.md
|
||||||
|
## Project: Local AI Video Generation from Text Storyboards (Windows + RTX 5070 12GB)
|
||||||
|
|
||||||
|
### 0) Who this is for
|
||||||
|
The owner (user) is not an ML expert. The system must:
|
||||||
|
- be reproducible (conda + requirements)
|
||||||
|
- have guardrails (configs, logs, validation)
|
||||||
|
- be test-driven (pytest)
|
||||||
|
- maintain docs (developer + user)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1) High-Level Goal
|
||||||
|
Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by:
|
||||||
|
1) converting storyboard -> shot plan
|
||||||
|
2) generating shot clips (T2V or I2V when possible)
|
||||||
|
3) assembling clips into a final MP4
|
||||||
|
4) upscaling to 2K/4K if desired
|
||||||
|
|
||||||
|
This is a **shot-based** system, not “one prompt makes a whole movie”.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2) Hard Constraints (Hardware & OS)
|
||||||
|
Target system:
|
||||||
|
- Windows 11
|
||||||
|
- NVIDIA RTX 5070 (12GB VRAM) - Must use GPU.
|
||||||
|
- 32GB RAM
|
||||||
|
- 2TB SSD
|
||||||
|
- Anaconda available
|
||||||
|
|
||||||
|
Design must be stable under 12GB VRAM using:
|
||||||
|
- fp16/bf16
|
||||||
|
- attention slicing
|
||||||
|
- xFormers / SDPA where supported
|
||||||
|
- optional CPU offload
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3) Output Targets (Realistic)
|
||||||
|
- Native generation: 720p–1080p (preferred)
|
||||||
|
- Final delivery: 1080p required; 2K/4K via upscaling
|
||||||
|
- Duration: 15–30s per video (may be segmented)
|
||||||
|
- FPS: 24 default
|
||||||
|
- Output: MP4 (H.264/H.265)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4) CUDA 13.1 Reality & PyTorch Plan (Critical)
|
||||||
|
User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
|
||||||
|
We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit.
|
||||||
|
|
||||||
|
**Plan:**
|
||||||
|
- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124).
|
||||||
|
- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version.
|
||||||
|
- Avoid compiling custom CUDA extensions unless necessary.
|
||||||
|
|
||||||
|
Implementation notes:
|
||||||
|
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds.
|
||||||
|
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5) Approved Stack (Do Not Deviate)
|
||||||
|
### Core
|
||||||
|
- Python 3.10 or 3.11 (conda env)
|
||||||
|
- PyTorch (CUDA 12.x build: cu121 or cu124)
|
||||||
|
- diffusers + transformers + accelerate + safetensors
|
||||||
|
- ffmpeg for assembly
|
||||||
|
- opencv-python for frame IO (if needed)
|
||||||
|
- pydantic for config/schema validation
|
||||||
|
- rich / loguru for logs
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
- pytest
|
||||||
|
- pytest-cov
|
||||||
|
- snapshot-ish tests where feasible (metadata + shapes, not visual perfection)
|
||||||
|
|
||||||
|
### Docs
|
||||||
|
- /docs/developer.md (developer documentation)
|
||||||
|
- /docs/user.md (user manual)
|
||||||
|
- Keep docs updated alongside code changes.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6) Video Models (Pragmatic Choices)
|
||||||
|
### Primary (target)
|
||||||
|
- WAN 2.x family (T2V; optional I2V if supported in chosen pipeline)
|
||||||
|
Goal: best possible quality on consumer VRAM with chunking.
|
||||||
|
|
||||||
|
### Secondary / fallback
|
||||||
|
- Stable Video Diffusion (SVD) if WAN is unstable
|
||||||
|
- LTX-Video (only if it fits and is stable in our stack)
|
||||||
|
|
||||||
|
All model backends must implement the same interface:
|
||||||
|
- generate_shot(shot_spec) -> video_file + metadata
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7) Canonical Input: Storyboard JSON
|
||||||
|
Storyboard source is text-only (often AI-generated). We will store and validate it as JSON.
|
||||||
|
|
||||||
|
A template exists at: `templates/storyboard.template.json`
|
||||||
|
|
||||||
|
We will later build a utility script:
|
||||||
|
- input: plain text fields or a simple text format
|
||||||
|
- output: valid storyboard JSON
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8) Pipeline Modules (Required)
|
||||||
|
### A) Storyboard parsing & validation
|
||||||
|
- Load storyboard JSON
|
||||||
|
- Validate schema
|
||||||
|
- Expand defaults (fps, resolution, global style)
|
||||||
|
- Produce normalized shot list
|
||||||
|
|
||||||
|
### B) Prompt compilation
|
||||||
|
- Merge global style + shot prompt + camera notes
|
||||||
|
- Produce positive + negative prompts
|
||||||
|
- Keep deterministic via seeds
|
||||||
|
|
||||||
|
### C) Generation runner (per shot)
|
||||||
|
- For each shot: generate clip
|
||||||
|
- Support:
|
||||||
|
- seed control
|
||||||
|
- chunking (e.g., generate 4–6 seconds then continue)
|
||||||
|
- optional init frame handoff between shots
|
||||||
|
|
||||||
|
### D) Assembly
|
||||||
|
- Use ffmpeg concat to build final video
|
||||||
|
- Optionally add:
|
||||||
|
- transitions
|
||||||
|
- temp audio
|
||||||
|
- burn-in shot IDs for debugging mode
|
||||||
|
|
||||||
|
### E) Upscaling (optional)
|
||||||
|
- Upscale final to 2K/4K (post step)
|
||||||
|
- Keep this modular so user can skip.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9) Determinism & Logging (Must Have)
|
||||||
|
For each shot and final render, save:
|
||||||
|
- prompts (positive/negative)
|
||||||
|
- seed(s)
|
||||||
|
- model + revision/hash info if available
|
||||||
|
- inference params (steps, cfg, sampler, resolution, fps, frames)
|
||||||
|
- timing + VRAM notes if possible
|
||||||
|
|
||||||
|
Every run produces a folder:
|
||||||
|
- outputs/<project>/<timestamp>/
|
||||||
|
- shots/
|
||||||
|
- assembled/
|
||||||
|
- metadata/
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10) Testing Rules (Hard Requirement)
|
||||||
|
- Tests must be written alongside features.
|
||||||
|
- Whenever a file/function is modified, corresponding tests MUST be updated.
|
||||||
|
- Prefer tests that verify:
|
||||||
|
- schema validation works
|
||||||
|
- prompt compiler output is stable
|
||||||
|
- shot planner expands durations -> frame counts
|
||||||
|
- assembly command lines are correct
|
||||||
|
- metadata is generated correctly
|
||||||
|
|
||||||
|
Do not require “visual quality” assertions. Test structure and determinism.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11) Documentation Rules (Hard Requirement)
|
||||||
|
Maintain these continuously:
|
||||||
|
- docs/developer.md
|
||||||
|
- architecture
|
||||||
|
- install steps
|
||||||
|
- how to run tests
|
||||||
|
- how to add a new model backend
|
||||||
|
- docs/user.md
|
||||||
|
- quickstart
|
||||||
|
- how to create storyboard JSON
|
||||||
|
- how to run generation
|
||||||
|
- where outputs go
|
||||||
|
- troubleshooting (VRAM, drivers, ffmpeg)
|
||||||
|
|
||||||
|
Docs must be updated whenever CLI flags, file formats, or workflows change.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12) Project Files to Maintain
|
||||||
|
Required:
|
||||||
|
- requirements.txt (pip deps)
|
||||||
|
- environment.yml (conda env)
|
||||||
|
- templates/storyboard.template.json
|
||||||
|
- docs/developer.md
|
||||||
|
- docs/user.md
|
||||||
|
- src/ (implementation)
|
||||||
|
- tests/ (pytest)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 13) Definition of Done
|
||||||
|
A feature is “done” only if:
|
||||||
|
- implemented
|
||||||
|
- tests added/updated
|
||||||
|
- docs updated
|
||||||
|
- reproducible install instructions remain valid
|
||||||
|
|
||||||
|
End of file.
|
||||||
35
TODO.MD
Normal file
35
TODO.MD
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# TODO
|
||||||
|
|
||||||
|
## Repo bootstrap
|
||||||
|
- [x] Define project direction and constraints in `agents.md`
|
||||||
|
- [x] Add `requirements.txt` (pip dependencies)
|
||||||
|
- [x] Add `environment.yml` (conda environment; PyTorch CUDA 12.x runtime strategy)
|
||||||
|
- [x] Add storyboard JSON template at `templates/storyboard.template.json`
|
||||||
|
|
||||||
|
## Core implementation (next)
|
||||||
|
- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
|
||||||
|
- [ ] Implement storyboard schema validator (pydantic) + loader
|
||||||
|
- [ ] Implement prompt compiler (global style + shot + camera)
|
||||||
|
- [ ] Implement shot planning (duration -> frame count, chunk plan)
|
||||||
|
- [ ] Implement model backend interface (`BaseVideoBackend`)
|
||||||
|
- [ ] Implement WAN backend (primary) with VRAM-safe defaults
|
||||||
|
- [ ] Implement fallback backend (SVD) for reliability testing
|
||||||
|
- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
|
||||||
|
- [ ] Implement optional upscaling module (post-process)
|
||||||
|
|
||||||
|
## Utilities
|
||||||
|
- [ ] Write storyboard “plain text → JSON” utility script (fills `storyboard.template.json`)
|
||||||
|
- [ ] Add config file support (YAML/JSON) for global defaults
|
||||||
|
|
||||||
|
## Testing (parallel work; required)
|
||||||
|
- [ ] Add `pytest` scaffolding
|
||||||
|
- [ ] Add tests for schema validation
|
||||||
|
- [ ] Add tests for prompt compilation determinism
|
||||||
|
- [ ] Add tests for shot planning (frames/chunks)
|
||||||
|
- [ ] Add tests for ffmpeg command generation (no actual render needed)
|
||||||
|
- [ ] Ensure every code change includes a corresponding test update
|
||||||
|
|
||||||
|
## Documentation (maintained continuously)
|
||||||
|
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
|
||||||
|
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
|
||||||
|
- [ ] Keep docs updated whenever CLI/config/schema changes
|
||||||
61
config/models.yaml
Normal file
61
config/models.yaml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Model configurations for video generation backends
|
||||||
|
# Backend selection is controlled via ACTIVE_BACKEND environment variable
|
||||||
|
# or --backend CLI flag
|
||||||
|
|
||||||
|
backends:
|
||||||
|
wan_t2v_14b:
|
||||||
|
name: "Wan 2.1 T2V 14B"
|
||||||
|
class: "generation.backends.wan.WanBackend"
|
||||||
|
model_id: "Wan-AI/Wan2.1-T2V-14B"
|
||||||
|
vram_gb: 12
|
||||||
|
dtype: "fp16"
|
||||||
|
enable_vae_slicing: true
|
||||||
|
enable_vae_tiling: true
|
||||||
|
chunking:
|
||||||
|
enabled: true
|
||||||
|
mode: "sequential" # Options: sequential, overlapping
|
||||||
|
max_chunk_seconds: 4
|
||||||
|
overlap_seconds: 1 # Only used when mode is overlapping
|
||||||
|
|
||||||
|
wan_t2v_1_3b:
|
||||||
|
name: "Wan 2.1 T2V 1.3B"
|
||||||
|
class: "generation.backends.wan.WanBackend"
|
||||||
|
model_id: "Wan-AI/Wan2.1-T2V-1.3B"
|
||||||
|
vram_gb: 8
|
||||||
|
dtype: "fp16"
|
||||||
|
enable_vae_slicing: true
|
||||||
|
enable_vae_tiling: false
|
||||||
|
chunking:
|
||||||
|
enabled: true
|
||||||
|
mode: "sequential"
|
||||||
|
max_chunk_seconds: 6
|
||||||
|
overlap_seconds: 1
|
||||||
|
|
||||||
|
svd:
|
||||||
|
name: "Stable Video Diffusion"
|
||||||
|
class: "generation.backends.svd.SVDBackend"
|
||||||
|
model_id: "stabilityai/stable-video-diffusion-img2vid-xt"
|
||||||
|
vram_gb: 10
|
||||||
|
dtype: "fp16"
|
||||||
|
enable_vae_slicing: true
|
||||||
|
enable_vae_tiling: true
|
||||||
|
chunking:
|
||||||
|
enabled: false
|
||||||
|
|
||||||
|
# Default backend selection
|
||||||
|
# Override with ACTIVE_BACKEND environment variable
|
||||||
|
active_backend: "wan_t2v_14b"
|
||||||
|
|
||||||
|
# Global generation defaults
|
||||||
|
defaults:
|
||||||
|
fps: 24
|
||||||
|
resolution:
|
||||||
|
width: 1920
|
||||||
|
height: 1080
|
||||||
|
|
||||||
|
# Checkpoint database settings
|
||||||
|
checkpoint_db: "outputs/checkpoints.db"
|
||||||
|
|
||||||
|
# Model cache directory
|
||||||
|
# Override with MODEL_CACHE_DIR environment variable
|
||||||
|
model_cache_dir: "~/.cache/storyboard-video/models"
|
||||||
24
environment.yml
Normal file
24
environment.yml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# environment.yml
|
||||||
|
name: storyboard-video
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- nvidia
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- python=3.11
|
||||||
|
- pip=24.3.1
|
||||||
|
|
||||||
|
# FFmpeg in env is easiest on Windows if you rely on conda-forge
|
||||||
|
- ffmpeg=7.1
|
||||||
|
|
||||||
|
# PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
|
||||||
|
# Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
|
||||||
|
- pytorch=2.5.1
|
||||||
|
- pytorch-cuda=12.4
|
||||||
|
- torchvision=0.20.1
|
||||||
|
- torchaudio=2.5.1
|
||||||
|
|
||||||
|
# pip deps (mirrors requirements.txt)
|
||||||
|
- pip:
|
||||||
|
- -r requirements.txt
|
||||||
7
pytest.ini
Normal file
7
pytest.ini
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# pytest configuration
|
||||||
|
[tool:pytest]
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
addopts = -v --tb=short
|
||||||
26
requirements.txt
Normal file
26
requirements.txt
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# requirements.txt
|
||||||
|
# Install PyTorch separately (see environment.yml and docs).
|
||||||
|
|
||||||
|
diffusers==0.32.2
|
||||||
|
transformers==4.48.3
|
||||||
|
accelerate==1.3.0
|
||||||
|
safetensors==0.5.2
|
||||||
|
|
||||||
|
pydantic==2.10.6
|
||||||
|
pyyaml==6.0.2
|
||||||
|
numpy==2.1.3
|
||||||
|
pillow==11.1.0
|
||||||
|
opencv-python==4.11.0.86
|
||||||
|
|
||||||
|
rich==13.9.4
|
||||||
|
loguru==0.7.3
|
||||||
|
|
||||||
|
tqdm==4.67.1
|
||||||
|
requests==2.32.3
|
||||||
|
|
||||||
|
# CLI & tooling
|
||||||
|
typer==0.15.1
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
pytest==8.3.4
|
||||||
|
pytest-cov==6.0.0
|
||||||
5
src/__init__.py
Normal file
5
src/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""
|
||||||
|
Storyboard video generation pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
3
src/assembly/__init__.py
Normal file
3
src/assembly/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
Video assembly and post-processing.
|
||||||
|
"""
|
||||||
3
src/cli/__init__.py
Normal file
3
src/cli/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
CLI entry points.
|
||||||
|
"""
|
||||||
25
src/core/__init__.py
Normal file
25
src/core/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
Core utilities and shared components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .config import Config, BackendConfig, ConfigLoader, get_config, reload_config
|
||||||
|
from .checkpoint import (
|
||||||
|
CheckpointManager,
|
||||||
|
ShotCheckpoint,
|
||||||
|
ProjectCheckpoint,
|
||||||
|
ShotStatus,
|
||||||
|
ProjectStatus
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Config',
|
||||||
|
'BackendConfig',
|
||||||
|
'ConfigLoader',
|
||||||
|
'get_config',
|
||||||
|
'reload_config',
|
||||||
|
'CheckpointManager',
|
||||||
|
'ShotCheckpoint',
|
||||||
|
'ProjectCheckpoint',
|
||||||
|
'ShotStatus',
|
||||||
|
'ProjectStatus'
|
||||||
|
]
|
||||||
435
src/core/checkpoint.py
Normal file
435
src/core/checkpoint.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
"""
|
||||||
|
Checkpoint and resume system using SQLite.
|
||||||
|
Tracks generation progress and enables resuming from failures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ShotStatus(str, Enum):
|
||||||
|
"""Status of a shot generation."""
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
SKIPPED = "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectStatus(str, Enum):
|
||||||
|
"""Status of a project."""
|
||||||
|
INITIALIZED = "initialized"
|
||||||
|
RUNNING = "running"
|
||||||
|
PAUSED = "paused"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShotCheckpoint:
|
||||||
|
"""Checkpoint data for a single shot."""
|
||||||
|
shot_id: str
|
||||||
|
status: ShotStatus
|
||||||
|
output_path: Optional[str] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
started_at: Optional[str] = None
|
||||||
|
completed_at: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.metadata is None:
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProjectCheckpoint:
|
||||||
|
"""Checkpoint data for a project."""
|
||||||
|
project_id: str
|
||||||
|
storyboard_path: str
|
||||||
|
output_dir: str
|
||||||
|
status: ProjectStatus
|
||||||
|
backend_name: str
|
||||||
|
started_at: str
|
||||||
|
completed_at: Optional[str] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CheckpointManager:
|
||||||
|
"""Manages checkpoints using SQLite database."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str = "outputs/checkpoints.db"):
|
||||||
|
"""
|
||||||
|
Initialize checkpoint manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database file
|
||||||
|
"""
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
"""Initialize database schema."""
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
try:
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS projects (
|
||||||
|
project_id TEXT PRIMARY KEY,
|
||||||
|
storyboard_path TEXT NOT NULL,
|
||||||
|
output_dir TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
backend_name TEXT NOT NULL,
|
||||||
|
started_at TEXT NOT NULL,
|
||||||
|
completed_at TEXT,
|
||||||
|
error_message TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS shots (
|
||||||
|
project_id TEXT NOT NULL,
|
||||||
|
shot_id TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
output_path TEXT,
|
||||||
|
error_message TEXT,
|
||||||
|
started_at TEXT,
|
||||||
|
completed_at TEXT,
|
||||||
|
metadata TEXT,
|
||||||
|
PRIMARY KEY (project_id, shot_id),
|
||||||
|
FOREIGN KEY (project_id) REFERENCES projects(project_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _get_connection(self):
|
||||||
|
"""Get a new database connection."""
|
||||||
|
return sqlite3.connect(self.db_path)
|
||||||
|
|
||||||
|
def create_project(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
storyboard_path: str,
|
||||||
|
output_dir: str,
|
||||||
|
backend_name: str
|
||||||
|
) -> ProjectCheckpoint:
|
||||||
|
"""
|
||||||
|
Create a new project checkpoint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Unique project identifier
|
||||||
|
storyboard_path: Path to storyboard JSON file
|
||||||
|
output_dir: Output directory for generated files
|
||||||
|
backend_name: Name of the generation backend
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ProjectCheckpoint object
|
||||||
|
"""
|
||||||
|
checkpoint = ProjectCheckpoint(
|
||||||
|
project_id=project_id,
|
||||||
|
storyboard_path=storyboard_path,
|
||||||
|
output_dir=output_dir,
|
||||||
|
status=ProjectStatus.INITIALIZED,
|
||||||
|
backend_name=backend_name,
|
||||||
|
started_at=datetime.now().isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO projects
|
||||||
|
(project_id, storyboard_path, output_dir, status, backend_name, started_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
checkpoint.project_id,
|
||||||
|
checkpoint.storyboard_path,
|
||||||
|
checkpoint.output_dir,
|
||||||
|
checkpoint.status.value,
|
||||||
|
checkpoint.backend_name,
|
||||||
|
checkpoint.started_at
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
def get_project(self, project_id: str) -> Optional[ProjectCheckpoint]:
|
||||||
|
"""Get project checkpoint by ID."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM projects WHERE project_id = ?",
|
||||||
|
(project_id,)
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ProjectCheckpoint(
|
||||||
|
project_id=row[0],
|
||||||
|
storyboard_path=row[1],
|
||||||
|
output_dir=row[2],
|
||||||
|
status=ProjectStatus(row[3]),
|
||||||
|
backend_name=row[4],
|
||||||
|
started_at=row[5],
|
||||||
|
completed_at=row[6],
|
||||||
|
error_message=row[7]
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def update_project_status(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
status: ProjectStatus,
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Update project status."""
|
||||||
|
completed_at = None
|
||||||
|
if status in [ProjectStatus.COMPLETED, ProjectStatus.FAILED]:
|
||||||
|
completed_at = datetime.now().isoformat()
|
||||||
|
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
UPDATE projects
|
||||||
|
SET status = ?, error_message = ?, completed_at = ?
|
||||||
|
WHERE project_id = ?
|
||||||
|
""",
|
||||||
|
(status.value, error_message, completed_at, project_id)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def create_shot(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
shot_id: str,
|
||||||
|
status: ShotStatus = ShotStatus.PENDING
|
||||||
|
) -> ShotCheckpoint:
|
||||||
|
"""Create a shot checkpoint."""
|
||||||
|
checkpoint = ShotCheckpoint(
|
||||||
|
shot_id=shot_id,
|
||||||
|
status=status
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO shots
|
||||||
|
(project_id, shot_id, status, metadata)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
project_id,
|
||||||
|
shot_id,
|
||||||
|
status.value,
|
||||||
|
json.dumps(checkpoint.metadata)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
def get_shot(self, project_id: str, shot_id: str) -> Optional[ShotCheckpoint]:
|
||||||
|
"""Get shot checkpoint."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM shots WHERE project_id = ? AND shot_id = ?",
|
||||||
|
(project_id, shot_id)
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return ShotCheckpoint(
|
||||||
|
shot_id=row[1],
|
||||||
|
status=ShotStatus(row[2]),
|
||||||
|
output_path=row[3],
|
||||||
|
error_message=row[4],
|
||||||
|
started_at=row[5],
|
||||||
|
completed_at=row[6],
|
||||||
|
metadata=json.loads(row[7]) if row[7] else {}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_project_shots(self, project_id: str) -> List[ShotCheckpoint]:
|
||||||
|
"""Get all shots for a project."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM shots WHERE project_id = ? ORDER BY shot_id",
|
||||||
|
(project_id,)
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
ShotCheckpoint(
|
||||||
|
shot_id=row[1],
|
||||||
|
status=ShotStatus(row[2]),
|
||||||
|
output_path=row[3],
|
||||||
|
error_message=row[4],
|
||||||
|
started_at=row[5],
|
||||||
|
completed_at=row[6],
|
||||||
|
metadata=json.loads(row[7]) if row[7] else {}
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def update_shot(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
shot_id: str,
|
||||||
|
status: Optional[ShotStatus] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""Update shot checkpoint."""
|
||||||
|
updates = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if status is not None:
|
||||||
|
updates.append("status = ?")
|
||||||
|
params.append(status.value)
|
||||||
|
|
||||||
|
if status == ShotStatus.IN_PROGRESS:
|
||||||
|
updates.append("started_at = ?")
|
||||||
|
params.append(datetime.now().isoformat())
|
||||||
|
elif status in [ShotStatus.COMPLETED, ShotStatus.FAILED]:
|
||||||
|
updates.append("completed_at = ?")
|
||||||
|
params.append(datetime.now().isoformat())
|
||||||
|
|
||||||
|
if output_path is not None:
|
||||||
|
updates.append("output_path = ?")
|
||||||
|
params.append(output_path)
|
||||||
|
|
||||||
|
if error_message is not None:
|
||||||
|
updates.append("error_message = ?")
|
||||||
|
params.append(error_message)
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
updates.append("metadata = ?")
|
||||||
|
params.append(json.dumps(metadata))
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
return
|
||||||
|
|
||||||
|
params.extend([project_id, shot_id])
|
||||||
|
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
f"""
|
||||||
|
UPDATE shots
|
||||||
|
SET {', '.join(updates)}
|
||||||
|
WHERE project_id = ? AND shot_id = ?
|
||||||
|
""",
|
||||||
|
params
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_pending_shots(self, project_id: str) -> List[str]:
|
||||||
|
"""Get list of pending shot IDs for a project."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT shot_id FROM shots
|
||||||
|
WHERE project_id = ? AND status = ?
|
||||||
|
ORDER BY shot_id
|
||||||
|
""",
|
||||||
|
(project_id, ShotStatus.PENDING.value)
|
||||||
|
)
|
||||||
|
return [row[0] for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def get_failed_shots(self, project_id: str) -> List[str]:
|
||||||
|
"""Get list of failed shot IDs for a project."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT shot_id FROM shots
|
||||||
|
WHERE project_id = ? AND status = ?
|
||||||
|
ORDER BY shot_id
|
||||||
|
""",
|
||||||
|
(project_id, ShotStatus.FAILED.value)
|
||||||
|
)
|
||||||
|
return [row[0] for row in cursor.fetchall()]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def can_resume(self, project_id: str) -> bool:
|
||||||
|
"""Check if a project can be resumed."""
|
||||||
|
project = self.get_project(project_id)
|
||||||
|
if project is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if project.status == ProjectStatus.COMPLETED:
|
||||||
|
return False
|
||||||
|
|
||||||
|
pending = self.get_pending_shots(project_id)
|
||||||
|
failed = self.get_failed_shots(project_id)
|
||||||
|
|
||||||
|
return len(pending) > 0 or len(failed) > 0
|
||||||
|
|
||||||
|
def delete_project(self, project_id: str):
|
||||||
|
"""Delete a project and all its shots."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
conn.execute("DELETE FROM shots WHERE project_id = ?", (project_id,))
|
||||||
|
conn.execute("DELETE FROM projects WHERE project_id = ?", (project_id,))
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def list_projects(self) -> List[ProjectCheckpoint]:
|
||||||
|
"""List all projects."""
|
||||||
|
conn = self._get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT * FROM projects ORDER BY started_at DESC"
|
||||||
|
)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
ProjectCheckpoint(
|
||||||
|
project_id=row[0],
|
||||||
|
storyboard_path=row[1],
|
||||||
|
output_dir=row[2],
|
||||||
|
status=ProjectStatus(row[3]),
|
||||||
|
backend_name=row[4],
|
||||||
|
started_at=row[5],
|
||||||
|
completed_at=row[6],
|
||||||
|
error_message=row[7]
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
212
src/core/config.py
Normal file
212
src/core/config.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Configuration management system.
|
||||||
|
Handles YAML configs and environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BackendConfig:
|
||||||
|
"""Configuration for a generation backend."""
|
||||||
|
name: str
|
||||||
|
model_class: str
|
||||||
|
model_id: str
|
||||||
|
vram_gb: int
|
||||||
|
dtype: str
|
||||||
|
enable_vae_slicing: bool = True
|
||||||
|
enable_vae_tiling: bool = False
|
||||||
|
chunking_enabled: bool = True
|
||||||
|
chunking_mode: str = "sequential"
|
||||||
|
max_chunk_seconds: int = 4
|
||||||
|
overlap_seconds: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
"""Application configuration."""
|
||||||
|
active_backend: str = "wan_t2v_14b"
|
||||||
|
backends: Dict[str, BackendConfig] = field(default_factory=dict)
|
||||||
|
defaults: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
checkpoint_db: str = "outputs/checkpoints.db"
|
||||||
|
model_cache_dir: str = "~/.cache/storyboard-video/models"
|
||||||
|
log_level: str = "INFO"
|
||||||
|
output_base_dir: str = "./outputs"
|
||||||
|
|
||||||
|
def get_backend(self, name: Optional[str] = None) -> BackendConfig:
|
||||||
|
"""Get backend configuration by name."""
|
||||||
|
name = name or self.active_backend
|
||||||
|
if name not in self.backends:
|
||||||
|
raise ValueError(f"Backend '{name}' not found in configuration")
|
||||||
|
return self.backends[name]
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigLoader:
|
||||||
|
"""Loads configuration from YAML files and environment variables."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(config_path: Optional[Path] = None, env_file: Optional[Path] = None) -> Config:
|
||||||
|
"""
|
||||||
|
Load configuration from files and environment.
|
||||||
|
|
||||||
|
Priority (highest to lowest):
|
||||||
|
1. Environment variables
|
||||||
|
2. .env file
|
||||||
|
3. YAML config file
|
||||||
|
4. Defaults
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to YAML config file (default: config/models.yaml)
|
||||||
|
env_file: Path to .env file (default: .env if exists)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Config object
|
||||||
|
"""
|
||||||
|
# Start with defaults
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
# Load YAML config
|
||||||
|
if config_path is None:
|
||||||
|
config_path = Path("config/models.yaml")
|
||||||
|
|
||||||
|
if config_path.exists():
|
||||||
|
config = ConfigLoader._load_yaml(config_path, config)
|
||||||
|
|
||||||
|
# Load .env file
|
||||||
|
if env_file is None:
|
||||||
|
env_file = Path(".env")
|
||||||
|
|
||||||
|
if env_file.exists():
|
||||||
|
config = ConfigLoader._load_env_file(env_file, config)
|
||||||
|
|
||||||
|
# Override with environment variables
|
||||||
|
config = ConfigLoader._load_env_vars(config)
|
||||||
|
|
||||||
|
# Expand paths
|
||||||
|
config.model_cache_dir = str(Path(config.model_cache_dir).expanduser())
|
||||||
|
config.checkpoint_db = str(Path(config.checkpoint_db).expanduser())
|
||||||
|
config.output_base_dir = str(Path(config.output_base_dir).expanduser())
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_yaml(path: Path, config: Config) -> Config:
|
||||||
|
"""Load configuration from YAML file."""
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
return config
|
||||||
|
|
||||||
|
# Load backends
|
||||||
|
if 'backends' in data:
|
||||||
|
for backend_id, backend_data in data['backends'].items():
|
||||||
|
chunking = backend_data.get('chunking', {})
|
||||||
|
config.backends[backend_id] = BackendConfig(
|
||||||
|
name=backend_data.get('name', backend_id),
|
||||||
|
model_class=backend_data.get('class', ''),
|
||||||
|
model_id=backend_data.get('model_id', ''),
|
||||||
|
vram_gb=backend_data.get('vram_gb', 12),
|
||||||
|
dtype=backend_data.get('dtype', 'fp16'),
|
||||||
|
enable_vae_slicing=backend_data.get('enable_vae_slicing', True),
|
||||||
|
enable_vae_tiling=backend_data.get('enable_vae_tiling', False),
|
||||||
|
chunking_enabled=chunking.get('enabled', True),
|
||||||
|
chunking_mode=chunking.get('mode', 'sequential'),
|
||||||
|
max_chunk_seconds=chunking.get('max_chunk_seconds', 4),
|
||||||
|
overlap_seconds=chunking.get('overlap_seconds', 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load active backend
|
||||||
|
if 'active_backend' in data:
|
||||||
|
config.active_backend = data['active_backend']
|
||||||
|
|
||||||
|
# Load defaults
|
||||||
|
if 'defaults' in data:
|
||||||
|
config.defaults = data['defaults']
|
||||||
|
|
||||||
|
# Load checkpoint DB path
|
||||||
|
if 'checkpoint_db' in data:
|
||||||
|
config.checkpoint_db = data['checkpoint_db']
|
||||||
|
|
||||||
|
# Load model cache dir
|
||||||
|
if 'model_cache_dir' in data:
|
||||||
|
config.model_cache_dir = data['model_cache_dir']
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_env_file(path: Path, config: Config) -> Config:
|
||||||
|
"""Load configuration from .env file."""
|
||||||
|
env_vars = {}
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith('#'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if '=' in line:
|
||||||
|
key, value = line.split('=', 1)
|
||||||
|
key = key.strip()
|
||||||
|
value = value.strip().strip('"\'')
|
||||||
|
env_vars[key] = value
|
||||||
|
|
||||||
|
# Apply env vars directly to config (don't set in os.environ)
|
||||||
|
if 'ACTIVE_BACKEND' in env_vars:
|
||||||
|
config.active_backend = env_vars['ACTIVE_BACKEND']
|
||||||
|
|
||||||
|
if 'MODEL_CACHE_DIR' in env_vars:
|
||||||
|
config.model_cache_dir = env_vars['MODEL_CACHE_DIR']
|
||||||
|
|
||||||
|
if 'LOG_LEVEL' in env_vars:
|
||||||
|
config.log_level = env_vars['LOG_LEVEL']
|
||||||
|
|
||||||
|
if 'OUTPUT_BASE_DIR' in env_vars:
|
||||||
|
config.output_base_dir = env_vars['OUTPUT_BASE_DIR']
|
||||||
|
|
||||||
|
if 'CHECKPOINT_DB' in env_vars:
|
||||||
|
config.checkpoint_db = env_vars['CHECKPOINT_DB']
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load_env_vars(config: Config) -> Config:
|
||||||
|
"""Override config with environment variables."""
|
||||||
|
if 'ACTIVE_BACKEND' in os.environ:
|
||||||
|
config.active_backend = os.environ['ACTIVE_BACKEND']
|
||||||
|
|
||||||
|
if 'MODEL_CACHE_DIR' in os.environ:
|
||||||
|
config.model_cache_dir = os.environ['MODEL_CACHE_DIR']
|
||||||
|
|
||||||
|
if 'LOG_LEVEL' in os.environ:
|
||||||
|
config.log_level = os.environ['LOG_LEVEL']
|
||||||
|
|
||||||
|
if 'OUTPUT_BASE_DIR' in os.environ:
|
||||||
|
config.output_base_dir = os.environ['OUTPUT_BASE_DIR']
|
||||||
|
|
||||||
|
if 'CHECKPOINT_DB' in os.environ:
|
||||||
|
config.checkpoint_db = os.environ['CHECKPOINT_DB']
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
# Global config instance (lazy-loaded)
|
||||||
|
_config: Optional[Config] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_config() -> Config:
|
||||||
|
"""Get the global configuration instance."""
|
||||||
|
global _config
|
||||||
|
if _config is None:
|
||||||
|
_config = ConfigLoader.load()
|
||||||
|
return _config
|
||||||
|
|
||||||
|
|
||||||
|
def reload_config() -> Config:
|
||||||
|
"""Reload configuration from files."""
|
||||||
|
global _config
|
||||||
|
_config = ConfigLoader.load()
|
||||||
|
return _config
|
||||||
17
src/generation/__init__.py
Normal file
17
src/generation/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""
|
||||||
|
Video generation backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
BaseVideoBackend,
|
||||||
|
GenerationResult,
|
||||||
|
GenerationSpec,
|
||||||
|
BackendFactory
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'BaseVideoBackend',
|
||||||
|
'GenerationResult',
|
||||||
|
'GenerationSpec',
|
||||||
|
'BackendFactory'
|
||||||
|
]
|
||||||
235
src/generation/base.py
Normal file
235
src/generation/base.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""
|
||||||
|
Base interface for video generation backends.
|
||||||
|
All backends (WAN, SVD, etc.) must implement this interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationResult:
|
||||||
|
"""Result of a video generation."""
|
||||||
|
success: bool
|
||||||
|
output_path: Optional[Path] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
vram_usage_gb: Optional[float] = None
|
||||||
|
generation_time_s: Optional[float] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.metadata is None:
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationSpec:
|
||||||
|
"""Specification for video generation."""
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: str
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
num_frames: int
|
||||||
|
fps: int
|
||||||
|
seed: int
|
||||||
|
steps: int
|
||||||
|
cfg_scale: float
|
||||||
|
output_path: Path
|
||||||
|
init_frame_path: Optional[Path] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for metadata storage."""
|
||||||
|
return {
|
||||||
|
"prompt": self.prompt,
|
||||||
|
"negative_prompt": self.negative_prompt,
|
||||||
|
"width": self.width,
|
||||||
|
"height": self.height,
|
||||||
|
"num_frames": self.num_frames,
|
||||||
|
"fps": self.fps,
|
||||||
|
"seed": self.seed,
|
||||||
|
"steps": self.steps,
|
||||||
|
"cfg_scale": self.cfg_scale,
|
||||||
|
"output_path": str(self.output_path),
|
||||||
|
"init_frame_path": str(self.init_frame_path) if self.init_frame_path else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseVideoBackend(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for video generation backends.
|
||||||
|
|
||||||
|
All backends must implement this interface to be used in the pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Initialize the backend with configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Backend-specific configuration dictionary
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.model = None
|
||||||
|
self._is_loaded = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the backend name."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supports_chunking(self) -> bool:
|
||||||
|
"""Whether this backend supports chunked generation."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supports_init_frame(self) -> bool:
|
||||||
|
"""Whether this backend supports init frame conditioning."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self) -> None:
|
||||||
|
"""
|
||||||
|
Load the model into memory.
|
||||||
|
|
||||||
|
This should be called before generate().
|
||||||
|
Implementations should set self._is_loaded = True when complete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def unload(self) -> None:
|
||||||
|
"""
|
||||||
|
Unload the model from memory.
|
||||||
|
|
||||||
|
Implementations should set self._is_loaded = False when complete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, spec: GenerationSpec) -> GenerationResult:
|
||||||
|
"""
|
||||||
|
Generate a video clip.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: Generation specification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GenerationResult with output path and metadata
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def generate_chunked(
|
||||||
|
self,
|
||||||
|
spec: GenerationSpec,
|
||||||
|
chunk_duration_s: int,
|
||||||
|
overlap_s: int = 0,
|
||||||
|
mode: str = "sequential"
|
||||||
|
) -> GenerationResult:
|
||||||
|
"""
|
||||||
|
Generate a video in chunks.
|
||||||
|
|
||||||
|
Default implementation for backends that don't natively support chunking.
|
||||||
|
Backends with native support should override this.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: Generation specification
|
||||||
|
chunk_duration_s: Duration of each chunk in seconds
|
||||||
|
overlap_s: Overlap between chunks in seconds (for blending)
|
||||||
|
mode: "sequential" or "overlapping"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GenerationResult with concatenated video
|
||||||
|
"""
|
||||||
|
if not self.supports_chunking:
|
||||||
|
# Fall back to single generation
|
||||||
|
return self.generate(spec)
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Chunked generation not implemented for this backend"
|
||||||
|
)
|
||||||
|
|
||||||
|
def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
|
||||||
|
"""
|
||||||
|
Estimate VRAM usage for a generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width: Video width
|
||||||
|
height: Video height
|
||||||
|
num_frames: Number of frames
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated VRAM usage in GB
|
||||||
|
"""
|
||||||
|
# Default implementation - backends should override with better estimates
|
||||||
|
return self.config.get("vram_gb", 12.0)
|
||||||
|
|
||||||
|
def check_vram_available(self, width: int, height: int, num_frames: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if sufficient VRAM is available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width: Video width
|
||||||
|
height: Video height
|
||||||
|
num_frames: Number of frames
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if generation should fit in VRAM
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
||||||
|
estimated = self.estimate_vram_usage(width, height, num_frames)
|
||||||
|
# Leave 10% headroom
|
||||||
|
return estimated < total_vram * 0.9
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If we can't check, assume it's OK
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_loaded(self) -> bool:
|
||||||
|
"""Check if the model is loaded."""
|
||||||
|
return self._is_loaded
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Context manager entry."""
|
||||||
|
self.load()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Context manager exit."""
|
||||||
|
self.unload()
|
||||||
|
|
||||||
|
|
||||||
|
class BackendFactory:
|
||||||
|
"""Factory for creating backend instances."""
|
||||||
|
|
||||||
|
_backends: Dict[str, type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, name: str, backend_class: type):
|
||||||
|
"""Register a backend class."""
|
||||||
|
if not issubclass(backend_class, BaseVideoBackend):
|
||||||
|
raise ValueError(f"Backend must inherit from BaseVideoBackend")
|
||||||
|
cls._backends[name] = backend_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, config: Dict[str, Any]) -> BaseVideoBackend:
|
||||||
|
"""Create a backend instance."""
|
||||||
|
if name not in cls._backends:
|
||||||
|
raise ValueError(f"Unknown backend: {name}. "
|
||||||
|
f"Available: {list(cls._backends.keys())}")
|
||||||
|
return cls._backends[name](config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_backends(cls) -> List[str]:
|
||||||
|
"""List available backend names."""
|
||||||
|
return list(cls._backends.keys())
|
||||||
39
src/storyboard/__init__.py
Normal file
39
src/storyboard/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
Storyboard module for validation, loading, and prompt compilation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .schema import (
|
||||||
|
Storyboard,
|
||||||
|
ProjectSettings,
|
||||||
|
Shot,
|
||||||
|
Character,
|
||||||
|
Location,
|
||||||
|
GlobalStyle,
|
||||||
|
CameraSettings,
|
||||||
|
GenerationSettings,
|
||||||
|
OutputSettings,
|
||||||
|
Resolution
|
||||||
|
)
|
||||||
|
from .loader import StoryboardValidator, StoryboardLoadError
|
||||||
|
from .prompt_compiler import PromptCompiler, CompiledPrompt
|
||||||
|
from .shot_planner import ShotPlanner, ShotPlan, Chunk
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Storyboard',
|
||||||
|
'ProjectSettings',
|
||||||
|
'Shot',
|
||||||
|
'Character',
|
||||||
|
'Location',
|
||||||
|
'GlobalStyle',
|
||||||
|
'CameraSettings',
|
||||||
|
'GenerationSettings',
|
||||||
|
'OutputSettings',
|
||||||
|
'Resolution',
|
||||||
|
'StoryboardValidator',
|
||||||
|
'StoryboardLoadError',
|
||||||
|
'PromptCompiler',
|
||||||
|
'CompiledPrompt',
|
||||||
|
'ShotPlanner',
|
||||||
|
'ShotPlan',
|
||||||
|
'Chunk'
|
||||||
|
]
|
||||||
84
src/storyboard/loader.py
Normal file
84
src/storyboard/loader.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
Storyboard loader and validator.
|
||||||
|
Handles loading JSON files and validating against schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from .schema import Storyboard
|
||||||
|
|
||||||
|
|
||||||
|
class StoryboardLoadError(Exception):
|
||||||
|
"""Raised when storyboard loading fails."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StoryboardValidator:
|
||||||
|
"""Validates and loads storyboard JSON files."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load(path: Union[str, Path]) -> Storyboard:
|
||||||
|
"""
|
||||||
|
Load and validate a storyboard JSON file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to the JSON file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated Storyboard object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
StoryboardLoadError: If file doesn't exist or is invalid
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
raise StoryboardLoadError(f"Storyboard file not found: {path}")
|
||||||
|
|
||||||
|
if not path.suffix.lower() == '.json':
|
||||||
|
raise StoryboardLoadError(f"Storyboard file must be JSON: {path}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise StoryboardLoadError(f"Invalid JSON in {path}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
raise StoryboardLoadError(f"Failed to read {path}: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
storyboard = Storyboard(**data)
|
||||||
|
except Exception as e:
|
||||||
|
raise StoryboardLoadError(f"Storyboard validation failed: {e}")
|
||||||
|
|
||||||
|
return storyboard
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_references(storyboard: Storyboard) -> list[str]:
|
||||||
|
"""
|
||||||
|
Validate that all references (location_id, characters) exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storyboard: Loaded storyboard
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of validation warnings/errors
|
||||||
|
"""
|
||||||
|
issues = []
|
||||||
|
|
||||||
|
# Validate character references
|
||||||
|
char_ids = {char.id for char in storyboard.characters}
|
||||||
|
for shot in storyboard.shots:
|
||||||
|
for char_id in shot.characters:
|
||||||
|
if char_id not in char_ids:
|
||||||
|
issues.append(f"Shot {shot.id}: Character '{char_id}' not defined")
|
||||||
|
|
||||||
|
# Validate location references
|
||||||
|
loc_ids = {loc.id for loc in storyboard.locations}
|
||||||
|
for shot in storyboard.shots:
|
||||||
|
if shot.location_id and shot.location_id not in loc_ids:
|
||||||
|
issues.append(f"Shot {shot.id}: Location '{shot.location_id}' not defined")
|
||||||
|
|
||||||
|
return issues
|
||||||
154
src/storyboard/prompt_compiler.py
Normal file
154
src/storyboard/prompt_compiler.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Prompt compiler for merging global style with shot-specific prompts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..storyboard.schema import Storyboard, Shot, GlobalStyle
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompiledPrompt:
|
||||||
|
"""Compiled prompt ready for generation."""
|
||||||
|
positive: str
|
||||||
|
negative: str
|
||||||
|
camera_notes: str
|
||||||
|
full_prompt: str # Combined positive + style + camera
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to dictionary for metadata storage."""
|
||||||
|
return {
|
||||||
|
"positive": self.positive,
|
||||||
|
"negative": self.negative,
|
||||||
|
"camera_notes": self.camera_notes,
|
||||||
|
"full_prompt": self.full_prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PromptCompiler:
|
||||||
|
"""
|
||||||
|
Compiles prompts by merging global style with shot-specific content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storyboard: Storyboard):
|
||||||
|
"""
|
||||||
|
Initialize with storyboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storyboard: Loaded and validated storyboard
|
||||||
|
"""
|
||||||
|
self.storyboard = storyboard
|
||||||
|
self.global_style = storyboard.project.global_style
|
||||||
|
|
||||||
|
def compile_shot(self, shot: Shot) -> CompiledPrompt:
|
||||||
|
"""
|
||||||
|
Compile prompt for a specific shot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot: Shot to compile prompt for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CompiledPrompt with merged content
|
||||||
|
"""
|
||||||
|
# Build character descriptions
|
||||||
|
char_descriptions = []
|
||||||
|
for char_id in shot.characters:
|
||||||
|
char = self.storyboard.get_character(char_id)
|
||||||
|
if char:
|
||||||
|
char_descriptions.append(f"{char.name}: {char.description}")
|
||||||
|
|
||||||
|
# Build location description
|
||||||
|
location_desc = ""
|
||||||
|
if shot.location_id:
|
||||||
|
loc = self.storyboard.get_location(shot.location_id)
|
||||||
|
if loc:
|
||||||
|
location_desc = loc.description
|
||||||
|
|
||||||
|
# Compile positive prompt
|
||||||
|
positive_parts = []
|
||||||
|
|
||||||
|
# Add global visual style first (sets the aesthetic)
|
||||||
|
if self.global_style.visual_style:
|
||||||
|
positive_parts.append(self.global_style.visual_style)
|
||||||
|
|
||||||
|
# Add location context
|
||||||
|
if location_desc:
|
||||||
|
positive_parts.append(f"Setting: {location_desc}")
|
||||||
|
|
||||||
|
# Add character descriptions
|
||||||
|
if char_descriptions:
|
||||||
|
positive_parts.append("Characters: " + "; ".join(char_descriptions))
|
||||||
|
|
||||||
|
# Add the main shot prompt
|
||||||
|
positive_parts.append(shot.prompt)
|
||||||
|
|
||||||
|
# Add camera/lens info
|
||||||
|
camera_parts = []
|
||||||
|
if shot.camera.framing:
|
||||||
|
camera_parts.append(shot.camera.framing)
|
||||||
|
if self.global_style.lens:
|
||||||
|
camera_parts.append(self.global_style.lens)
|
||||||
|
if shot.camera.movement:
|
||||||
|
camera_parts.append(shot.camera.movement)
|
||||||
|
if self.global_style.motion_style:
|
||||||
|
camera_parts.append(self.global_style.motion_style)
|
||||||
|
if shot.camera.notes:
|
||||||
|
camera_parts.append(shot.camera.notes)
|
||||||
|
|
||||||
|
if camera_parts:
|
||||||
|
positive_parts.append(f"Camera: {', '.join(camera_parts)}")
|
||||||
|
|
||||||
|
# Add lighting
|
||||||
|
if self.global_style.lighting:
|
||||||
|
positive_parts.append(f"Lighting: {self.global_style.lighting}")
|
||||||
|
|
||||||
|
# Add color grade
|
||||||
|
if self.global_style.color_grade:
|
||||||
|
positive_parts.append(f"Color: {self.global_style.color_grade}")
|
||||||
|
|
||||||
|
positive = ". ".join(positive_parts)
|
||||||
|
|
||||||
|
# Compile negative prompt
|
||||||
|
negative_parts = []
|
||||||
|
if self.global_style.negative_prompt:
|
||||||
|
negative_parts.append(self.global_style.negative_prompt)
|
||||||
|
|
||||||
|
negative = ", ".join(negative_parts) if negative_parts else ""
|
||||||
|
|
||||||
|
# Camera notes for metadata
|
||||||
|
camera_notes = "; ".join(camera_parts) if camera_parts else ""
|
||||||
|
|
||||||
|
# Full combined prompt (for reference)
|
||||||
|
full = positive
|
||||||
|
|
||||||
|
return CompiledPrompt(
|
||||||
|
positive=positive,
|
||||||
|
negative=negative,
|
||||||
|
camera_notes=camera_notes,
|
||||||
|
full_prompt=full
|
||||||
|
)
|
||||||
|
|
||||||
|
def compile_all(self) -> List[CompiledPrompt]:
|
||||||
|
"""
|
||||||
|
Compile prompts for all shots.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of CompiledPrompt objects in shot order
|
||||||
|
"""
|
||||||
|
return [self.compile_shot(shot) for shot in self.storyboard.shots]
|
||||||
|
|
||||||
|
def get_shot_prompt(self, shot_id: str) -> Optional[CompiledPrompt]:
|
||||||
|
"""
|
||||||
|
Get compiled prompt for a specific shot by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot_id: Shot identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CompiledPrompt or None if shot not found
|
||||||
|
"""
|
||||||
|
for shot in self.storyboard.shots:
|
||||||
|
if shot.id == shot_id:
|
||||||
|
return self.compile_shot(shot)
|
||||||
|
return None
|
||||||
159
src/storyboard/schema.py
Normal file
159
src/storyboard/schema.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
Storyboard schema validation using Pydantic.
|
||||||
|
Defines the data models for storyboard JSON files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Optional, Literal
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
|
||||||
|
class Resolution(BaseModel):
|
||||||
|
"""Video resolution dimensions."""
|
||||||
|
width: int = Field(..., ge=256, le=4096, description="Width in pixels")
|
||||||
|
height: int = Field(..., ge=256, le=4096, description="Height in pixels")
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalStyle(BaseModel):
|
||||||
|
"""Global visual style settings applied to all shots."""
|
||||||
|
visual_style: str = Field(default="", description="Overall visual aesthetic")
|
||||||
|
color_grade: str = Field(default="", description="Color grading description")
|
||||||
|
lens: str = Field(default="", description="Lens type/characteristics")
|
||||||
|
lighting: str = Field(default="", description="Lighting setup description")
|
||||||
|
motion_style: str = Field(default="", description="Camera motion characteristics")
|
||||||
|
negative_prompt: str = Field(default="", description="What to avoid in generation")
|
||||||
|
|
||||||
|
|
||||||
|
class AudioSettings(BaseModel):
|
||||||
|
"""Audio configuration for the final video."""
|
||||||
|
add_music: bool = Field(default=False)
|
||||||
|
music_path: str = Field(default="")
|
||||||
|
add_voiceover: bool = Field(default=False)
|
||||||
|
voiceover_path: str = Field(default="")
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectSettings(BaseModel):
|
||||||
|
"""Top-level project configuration."""
|
||||||
|
title: str = Field(..., min_length=1, description="Project title")
|
||||||
|
fps: int = Field(default=24, ge=1, le=120, description="Frames per second")
|
||||||
|
target_duration_s: int = Field(default=20, ge=1, le=300, description="Target duration in seconds")
|
||||||
|
resolution: Resolution
|
||||||
|
aspect_ratio: str = Field(default="16:9", description="Aspect ratio (e.g., 16:9, 4:3)")
|
||||||
|
global_style: GlobalStyle = Field(default_factory=GlobalStyle)
|
||||||
|
audio: AudioSettings = Field(default_factory=AudioSettings)
|
||||||
|
|
||||||
|
|
||||||
|
class Character(BaseModel):
|
||||||
|
"""Character definition for consistency across shots."""
|
||||||
|
id: str = Field(..., min_length=1, description="Unique character identifier")
|
||||||
|
name: str = Field(..., min_length=1, description="Character name")
|
||||||
|
description: str = Field(default="", description="Visual description")
|
||||||
|
consistency_notes: str = Field(default="", description="Notes for maintaining consistency")
|
||||||
|
|
||||||
|
|
||||||
|
class Location(BaseModel):
|
||||||
|
"""Location/setting definition."""
|
||||||
|
id: str = Field(..., min_length=1, description="Unique location identifier")
|
||||||
|
name: str = Field(..., min_length=1, description="Location name")
|
||||||
|
description: str = Field(default="", description="Visual description of the setting")
|
||||||
|
|
||||||
|
|
||||||
|
class CameraSettings(BaseModel):
|
||||||
|
"""Camera configuration for a shot."""
|
||||||
|
framing: str = Field(default="", description="Shot framing (wide, medium, close-up, etc.)")
|
||||||
|
movement: str = Field(default="", description="Camera movement description")
|
||||||
|
notes: str = Field(default="", description="Additional camera notes")
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationSettings(BaseModel):
|
||||||
|
"""AI generation parameters for a shot."""
|
||||||
|
seed: int = Field(default=-1, description="Random seed (-1 for random)")
|
||||||
|
steps: int = Field(default=30, ge=1, le=100, description="Inference steps")
|
||||||
|
cfg_scale: float = Field(default=6.0, ge=1.0, le=20.0, description="Classifier-free guidance scale")
|
||||||
|
sampler: str = Field(default="default", description="Sampler type")
|
||||||
|
chunk_seconds: int = Field(default=4, ge=1, le=10, description="Duration per chunk in seconds")
|
||||||
|
use_init_frame_from_prev: bool = Field(default=False, description="Use last frame of previous shot as init")
|
||||||
|
|
||||||
|
|
||||||
|
class Shot(BaseModel):
|
||||||
|
"""Individual shot definition."""
|
||||||
|
id: str = Field(..., min_length=1, description="Unique shot identifier")
|
||||||
|
duration_s: int = Field(..., ge=1, le=60, description="Shot duration in seconds")
|
||||||
|
location_id: Optional[str] = Field(default=None, description="Reference to location")
|
||||||
|
characters: List[str] = Field(default_factory=list, description="List of character IDs")
|
||||||
|
prompt: str = Field(..., min_length=1, description="Generation prompt for this shot")
|
||||||
|
camera: CameraSettings = Field(default_factory=CameraSettings)
|
||||||
|
generation: GenerationSettings = Field(default_factory=GenerationSettings)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleSettings(BaseModel):
|
||||||
|
"""Post-processing upscaling configuration."""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
target_height: int = Field(default=2160, ge=720, le=4320, description="Target height in pixels")
|
||||||
|
method: str = Field(default="default", description="Upscaling method")
|
||||||
|
|
||||||
|
|
||||||
|
class OutputSettings(BaseModel):
|
||||||
|
"""Final video output configuration."""
|
||||||
|
container: Literal["mp4", "mov", "mkv"] = Field(default="mp4")
|
||||||
|
codec: Literal["h264", "h265", "vp9"] = Field(default="h264")
|
||||||
|
crf: int = Field(default=18, ge=0, le=51, description="Constant rate factor (lower = higher quality)")
|
||||||
|
preset: str = Field(default="medium", description="Encoding preset")
|
||||||
|
upscale: UpscaleSettings = Field(default_factory=UpscaleSettings)
|
||||||
|
|
||||||
|
|
||||||
|
class Storyboard(BaseModel):
|
||||||
|
"""Complete storyboard document."""
|
||||||
|
schema_version: str = Field(default="1.0", description="Schema version for compatibility")
|
||||||
|
project: ProjectSettings
|
||||||
|
characters: List[Character] = Field(default_factory=list)
|
||||||
|
locations: List[Location] = Field(default_factory=list)
|
||||||
|
shots: List[Shot] = Field(..., description="List of shots to generate")
|
||||||
|
output: OutputSettings = Field(default_factory=OutputSettings)
|
||||||
|
|
||||||
|
@validator('shots')
|
||||||
|
def validate_shot_ids_unique(cls, v):
|
||||||
|
"""Ensure all shot IDs are unique and at least one shot exists."""
|
||||||
|
if len(v) < 1:
|
||||||
|
raise ValueError("At least one shot is required")
|
||||||
|
ids = [shot.id for shot in v]
|
||||||
|
if len(ids) != len(set(ids)):
|
||||||
|
raise ValueError("Shot IDs must be unique")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator('characters')
|
||||||
|
def validate_character_ids_unique(cls, v):
|
||||||
|
"""Ensure all character IDs are unique."""
|
||||||
|
ids = [char.id for char in v]
|
||||||
|
if len(ids) != len(set(ids)):
|
||||||
|
raise ValueError("Character IDs must be unique")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator('locations')
|
||||||
|
def validate_location_ids_unique(cls, v):
|
||||||
|
"""Ensure all location IDs are unique."""
|
||||||
|
ids = [loc.id for loc in v]
|
||||||
|
if len(ids) != len(set(ids)):
|
||||||
|
raise ValueError("Location IDs must be unique")
|
||||||
|
return v
|
||||||
|
|
||||||
|
def get_character(self, char_id: str) -> Optional[Character]:
|
||||||
|
"""Get character by ID."""
|
||||||
|
for char in self.characters:
|
||||||
|
if char.id == char_id:
|
||||||
|
return char
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_location(self, loc_id: str) -> Optional[Location]:
|
||||||
|
"""Get location by ID."""
|
||||||
|
for loc in self.locations:
|
||||||
|
if loc.id == loc_id:
|
||||||
|
return loc
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_total_duration(self) -> int:
|
||||||
|
"""Calculate total duration in seconds."""
|
||||||
|
return sum(shot.duration_s for shot in self.shots)
|
||||||
|
|
||||||
|
def get_total_frames(self) -> int:
|
||||||
|
"""Calculate total frame count."""
|
||||||
|
return self.get_total_duration() * self.project.fps
|
||||||
252
src/storyboard/shot_planner.py
Normal file
252
src/storyboard/shot_planner.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""
|
||||||
|
Shot planner for converting duration to frames and chunking strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ..storyboard.schema import Storyboard, Shot
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Chunk:
|
||||||
|
"""A single chunk of a shot."""
|
||||||
|
chunk_index: int
|
||||||
|
start_frame: int
|
||||||
|
end_frame: int
|
||||||
|
num_frames: int
|
||||||
|
duration_s: float
|
||||||
|
use_init_frame: bool
|
||||||
|
init_frame_source: Optional[Path] # Path to frame file or previous chunk
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for metadata."""
|
||||||
|
return {
|
||||||
|
"chunk_index": self.chunk_index,
|
||||||
|
"start_frame": self.start_frame,
|
||||||
|
"end_frame": self.end_frame,
|
||||||
|
"num_frames": self.num_frames,
|
||||||
|
"duration_s": self.duration_s,
|
||||||
|
"use_init_frame": self.use_init_frame,
|
||||||
|
"init_frame_source": str(self.init_frame_source) if self.init_frame_source else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShotPlan:
|
||||||
|
"""Complete plan for generating a shot."""
|
||||||
|
shot_id: str
|
||||||
|
total_frames: int
|
||||||
|
total_duration_s: int
|
||||||
|
fps: int
|
||||||
|
chunks: List[Chunk]
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for metadata."""
|
||||||
|
return {
|
||||||
|
"shot_id": self.shot_id,
|
||||||
|
"total_frames": self.total_frames,
|
||||||
|
"total_duration_s": self.total_duration_s,
|
||||||
|
"fps": self.fps,
|
||||||
|
"chunks": [c.to_dict() for c in self.chunks]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ShotPlanner:
|
||||||
|
"""
|
||||||
|
Plans shot generation including frame counts and chunking strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fps: int = 24,
|
||||||
|
max_chunk_seconds: int = 4,
|
||||||
|
overlap_seconds: int = 0,
|
||||||
|
chunking_mode: str = "sequential"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize shot planner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fps: Frames per second
|
||||||
|
max_chunk_seconds: Maximum duration per chunk
|
||||||
|
overlap_seconds: Overlap between chunks (for blending mode)
|
||||||
|
chunking_mode: "sequential" or "overlapping"
|
||||||
|
"""
|
||||||
|
self.fps = fps
|
||||||
|
self.max_chunk_seconds = max_chunk_seconds
|
||||||
|
self.overlap_seconds = overlap_seconds
|
||||||
|
self.chunking_mode = chunking_mode
|
||||||
|
|
||||||
|
def duration_to_frames(self, duration_s: int) -> int:
|
||||||
|
"""
|
||||||
|
Convert duration in seconds to frame count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
duration_s: Duration in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of frames
|
||||||
|
"""
|
||||||
|
return duration_s * self.fps
|
||||||
|
|
||||||
|
def frames_to_duration(self, num_frames: int) -> float:
|
||||||
|
"""
|
||||||
|
Convert frame count to duration in seconds.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_frames: Number of frames
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds
|
||||||
|
"""
|
||||||
|
return num_frames / self.fps
|
||||||
|
|
||||||
|
def plan_shot(
|
||||||
|
self,
|
||||||
|
shot: Shot,
|
||||||
|
fps: Optional[int] = None,
|
||||||
|
use_init_from_prev: bool = False
|
||||||
|
) -> ShotPlan:
|
||||||
|
"""
|
||||||
|
Create a generation plan for a shot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot: Shot to plan
|
||||||
|
fps: Override FPS (uses storyboard default if None)
|
||||||
|
use_init_from_prev: Whether to use previous shot's last frame
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ShotPlan with chunk breakdown
|
||||||
|
"""
|
||||||
|
fps = fps or self.fps
|
||||||
|
total_frames = shot.duration_s * fps
|
||||||
|
|
||||||
|
# Determine chunk size
|
||||||
|
chunk_frames = self.max_chunk_seconds * fps
|
||||||
|
overlap_frames = self.overlap_seconds * fps if self.chunking_mode == "overlapping" else 0
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
if total_frames <= chunk_frames:
|
||||||
|
# Single chunk - no need to split
|
||||||
|
chunk = Chunk(
|
||||||
|
chunk_index=0,
|
||||||
|
start_frame=0,
|
||||||
|
end_frame=total_frames - 1,
|
||||||
|
num_frames=total_frames,
|
||||||
|
duration_s=shot.duration_s,
|
||||||
|
use_init_frame=use_init_from_prev and shot.generation.use_init_frame_from_prev,
|
||||||
|
init_frame_source=None # Will be set by pipeline
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
else:
|
||||||
|
# Multiple chunks needed
|
||||||
|
remaining_frames = total_frames
|
||||||
|
current_frame = 0
|
||||||
|
chunk_index = 0
|
||||||
|
|
||||||
|
while remaining_frames > 0:
|
||||||
|
# Calculate frames for this chunk
|
||||||
|
if self.chunking_mode == "sequential":
|
||||||
|
# Sequential: Each chunk starts where previous ended
|
||||||
|
this_chunk_frames = min(chunk_frames, remaining_frames)
|
||||||
|
|
||||||
|
chunk = Chunk(
|
||||||
|
chunk_index=chunk_index,
|
||||||
|
start_frame=current_frame,
|
||||||
|
end_frame=current_frame + this_chunk_frames - 1,
|
||||||
|
num_frames=this_chunk_frames,
|
||||||
|
duration_s=this_chunk_frames / fps,
|
||||||
|
use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0,
|
||||||
|
init_frame_source=None
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
current_frame += this_chunk_frames
|
||||||
|
remaining_frames -= this_chunk_frames
|
||||||
|
|
||||||
|
else: # overlapping
|
||||||
|
# Overlapping: Chunks overlap for blending
|
||||||
|
if remaining_frames <= chunk_frames:
|
||||||
|
# Last chunk - take all remaining frames
|
||||||
|
this_chunk_frames = remaining_frames
|
||||||
|
else:
|
||||||
|
# Not last chunk - include overlap
|
||||||
|
this_chunk_frames = min(chunk_frames + overlap_frames, remaining_frames)
|
||||||
|
|
||||||
|
chunk = Chunk(
|
||||||
|
chunk_index=chunk_index,
|
||||||
|
start_frame=current_frame,
|
||||||
|
end_frame=current_frame + this_chunk_frames - 1,
|
||||||
|
num_frames=this_chunk_frames,
|
||||||
|
duration_s=this_chunk_frames / fps,
|
||||||
|
use_init_frame=(chunk_index == 0 and use_init_from_prev and shot.generation.use_init_frame_from_prev) or chunk_index > 0,
|
||||||
|
init_frame_source=None
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
# Move forward by chunk size (minus overlap)
|
||||||
|
current_frame += chunk_frames - overlap_frames
|
||||||
|
remaining_frames -= chunk_frames - overlap_frames
|
||||||
|
|
||||||
|
chunk_index += 1
|
||||||
|
|
||||||
|
return ShotPlan(
|
||||||
|
shot_id=shot.id,
|
||||||
|
total_frames=total_frames,
|
||||||
|
total_duration_s=shot.duration_s,
|
||||||
|
fps=fps,
|
||||||
|
chunks=chunks
|
||||||
|
)
|
||||||
|
|
||||||
|
def plan_storyboard(
|
||||||
|
self,
|
||||||
|
storyboard: Storyboard,
|
||||||
|
override_chunk_config: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[ShotPlan]:
|
||||||
|
"""
|
||||||
|
Create generation plans for all shots in a storyboard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storyboard: Storyboard to plan
|
||||||
|
override_chunk_config: Optional config to override defaults
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ShotPlan objects
|
||||||
|
"""
|
||||||
|
# Apply overrides if provided
|
||||||
|
fps = override_chunk_config.get("fps", storyboard.project.fps) if override_chunk_config else storyboard.project.fps
|
||||||
|
max_chunk = override_chunk_config.get("max_chunk_seconds", self.max_chunk_seconds) if override_chunk_config else self.max_chunk_seconds
|
||||||
|
overlap = override_chunk_config.get("overlap_seconds", self.overlap_seconds) if override_chunk_config else self.overlap_seconds
|
||||||
|
mode = override_chunk_config.get("mode", self.chunking_mode) if override_chunk_config else self.chunking_mode
|
||||||
|
|
||||||
|
planner = ShotPlanner(
|
||||||
|
fps=fps,
|
||||||
|
max_chunk_seconds=max_chunk,
|
||||||
|
overlap_seconds=overlap,
|
||||||
|
chunking_mode=mode
|
||||||
|
)
|
||||||
|
|
||||||
|
plans = []
|
||||||
|
prev_shot_id = None
|
||||||
|
|
||||||
|
for i, shot in enumerate(storyboard.shots):
|
||||||
|
# Determine if we should use init frame from previous shot
|
||||||
|
use_init = i > 0 and shot.generation.use_init_frame_from_prev
|
||||||
|
|
||||||
|
plan = planner.plan_shot(shot, fps=fps, use_init_from_prev=use_init)
|
||||||
|
plans.append(plan)
|
||||||
|
|
||||||
|
prev_shot_id = shot.id
|
||||||
|
|
||||||
|
return plans
|
||||||
|
|
||||||
|
def get_total_frames(self, plans: List[ShotPlan]) -> int:
|
||||||
|
"""Calculate total frames across all plans."""
|
||||||
|
return sum(plan.total_frames for plan in plans)
|
||||||
|
|
||||||
|
def get_total_duration(self, plans: List[ShotPlan]) -> float:
|
||||||
|
"""Calculate total duration in seconds across all plans."""
|
||||||
|
return sum(plan.total_duration_s for plan in plans)
|
||||||
3
src/upscaling/__init__.py
Normal file
3
src/upscaling/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
Upscaling module.
|
||||||
|
"""
|
||||||
189
tests/unit/test_checkpoint.py
Normal file
189
tests/unit/test_checkpoint.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Tests for checkpoint system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.core.checkpoint import (
|
||||||
|
CheckpointManager, ShotCheckpoint, ProjectCheckpoint,
|
||||||
|
ShotStatus, ProjectStatus
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointManager:
|
||||||
|
"""Test checkpoint management."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db(self):
|
||||||
|
"""Create a temporary database for testing."""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||||
|
db_path = f.name
|
||||||
|
|
||||||
|
yield db_path
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self, temp_db):
|
||||||
|
"""Create a checkpoint manager with temp database."""
|
||||||
|
return CheckpointManager(db_path=temp_db)
|
||||||
|
|
||||||
|
def test_create_project(self, manager):
|
||||||
|
"""Test creating a project checkpoint."""
|
||||||
|
checkpoint = manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/to/storyboard.json",
|
||||||
|
output_dir="/path/to/output",
|
||||||
|
backend_name="wan_t2v_14b"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert checkpoint.project_id == "test_001"
|
||||||
|
assert checkpoint.status == ProjectStatus.INITIALIZED
|
||||||
|
assert checkpoint.backend_name == "wan_t2v_14b"
|
||||||
|
|
||||||
|
def test_get_project(self, manager):
|
||||||
|
"""Test retrieving a project checkpoint."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/to/storyboard.json",
|
||||||
|
output_dir="/path/to/output",
|
||||||
|
backend_name="wan_t2v_14b"
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieved = manager.get_project("test_001")
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.project_id == "test_001"
|
||||||
|
|
||||||
|
missing = manager.get_project("nonexistent")
|
||||||
|
assert missing is None
|
||||||
|
|
||||||
|
def test_update_project_status(self, manager):
|
||||||
|
"""Test updating project status."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/to/storyboard.json",
|
||||||
|
output_dir="/path/to/output",
|
||||||
|
backend_name="wan_t2v_14b"
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.update_project_status("test_001", ProjectStatus.RUNNING)
|
||||||
|
|
||||||
|
retrieved = manager.get_project("test_001")
|
||||||
|
assert retrieved.status == ProjectStatus.RUNNING
|
||||||
|
|
||||||
|
def test_create_shot(self, manager):
|
||||||
|
"""Test creating a shot checkpoint."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/storyboard.json",
|
||||||
|
output_dir="/path/output",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
|
||||||
|
shot = manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||||
|
|
||||||
|
assert shot.shot_id == "S01"
|
||||||
|
assert shot.status == ShotStatus.PENDING
|
||||||
|
|
||||||
|
def test_update_shot(self, manager):
|
||||||
|
"""Test updating shot checkpoint."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/storyboard.json",
|
||||||
|
output_dir="/path/output",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
manager.create_shot("test_001", "S01")
|
||||||
|
|
||||||
|
manager.update_shot(
|
||||||
|
"test_001",
|
||||||
|
"S01",
|
||||||
|
status=ShotStatus.IN_PROGRESS,
|
||||||
|
output_path="/path/output.mp4",
|
||||||
|
metadata={"seed": 12345}
|
||||||
|
)
|
||||||
|
|
||||||
|
shot = manager.get_shot("test_001", "S01")
|
||||||
|
assert shot.status == ShotStatus.IN_PROGRESS
|
||||||
|
assert shot.output_path == "/path/output.mp4"
|
||||||
|
assert shot.metadata["seed"] == 12345
|
||||||
|
|
||||||
|
def test_get_pending_shots(self, manager):
|
||||||
|
"""Test retrieving pending shots."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/storyboard.json",
|
||||||
|
output_dir="/path/output",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
|
||||||
|
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||||
|
manager.create_shot("test_001", "S02", ShotStatus.COMPLETED)
|
||||||
|
manager.create_shot("test_001", "S03", ShotStatus.PENDING)
|
||||||
|
|
||||||
|
pending = manager.get_pending_shots("test_001")
|
||||||
|
assert len(pending) == 2
|
||||||
|
assert "S01" in pending
|
||||||
|
assert "S03" in pending
|
||||||
|
|
||||||
|
def test_can_resume(self, manager):
|
||||||
|
"""Test checking if project can be resumed."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/storyboard.json",
|
||||||
|
output_dir="/path/output",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
|
||||||
|
# No shots yet - can't resume
|
||||||
|
assert not manager.can_resume("test_001")
|
||||||
|
|
||||||
|
# Add pending shot - can resume
|
||||||
|
manager.create_shot("test_001", "S01", ShotStatus.PENDING)
|
||||||
|
assert manager.can_resume("test_001")
|
||||||
|
|
||||||
|
# Complete all shots - can't resume
|
||||||
|
manager.update_shot("test_001", "S01", status=ShotStatus.COMPLETED)
|
||||||
|
assert not manager.can_resume("test_001")
|
||||||
|
|
||||||
|
def test_delete_project(self, manager):
|
||||||
|
"""Test deleting a project."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/storyboard.json",
|
||||||
|
output_dir="/path/output",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
manager.create_shot("test_001", "S01")
|
||||||
|
|
||||||
|
manager.delete_project("test_001")
|
||||||
|
|
||||||
|
assert manager.get_project("test_001") is None
|
||||||
|
assert manager.get_shot("test_001", "S01") is None
|
||||||
|
|
||||||
|
def test_list_projects(self, manager):
|
||||||
|
"""Test listing all projects."""
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_001",
|
||||||
|
storyboard_path="/path/1.json",
|
||||||
|
output_dir="/out/1",
|
||||||
|
backend_name="wan"
|
||||||
|
)
|
||||||
|
manager.create_project(
|
||||||
|
project_id="test_002",
|
||||||
|
storyboard_path="/path/2.json",
|
||||||
|
output_dir="/out/2",
|
||||||
|
backend_name="svd"
|
||||||
|
)
|
||||||
|
|
||||||
|
projects = manager.list_projects()
|
||||||
|
assert len(projects) == 2
|
||||||
|
|
||||||
|
project_ids = [p.project_id for p in projects]
|
||||||
|
assert "test_001" in project_ids
|
||||||
|
assert "test_002" in project_ids
|
||||||
163
tests/unit/test_config.py
Normal file
163
tests/unit/test_config.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""
|
||||||
|
Tests for configuration system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.core.config import Config, BackendConfig, ConfigLoader, get_config, reload_config
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigLoader:
|
||||||
|
"""Test configuration loading."""
|
||||||
|
|
||||||
|
def test_load_default_config(self):
|
||||||
|
"""Test loading with no config files (uses defaults)."""
|
||||||
|
config = ConfigLoader.load(
|
||||||
|
config_path=Path("nonexistent.yaml"),
|
||||||
|
env_file=Path("nonexistent.env")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(config, Config)
|
||||||
|
assert config.active_backend == "wan_t2v_14b"
|
||||||
|
assert len(config.backends) == 0 # No YAML loaded
|
||||||
|
|
||||||
|
def test_load_yaml_config(self):
|
||||||
|
"""Test loading from YAML file."""
|
||||||
|
yaml_content = """
|
||||||
|
backends:
|
||||||
|
test_backend:
|
||||||
|
name: "Test Backend"
|
||||||
|
class: "test.TestBackend"
|
||||||
|
model_id: "test/model"
|
||||||
|
vram_gb: 8
|
||||||
|
dtype: "fp16"
|
||||||
|
enable_vae_slicing: true
|
||||||
|
enable_vae_tiling: false
|
||||||
|
chunking:
|
||||||
|
enabled: true
|
||||||
|
mode: "sequential"
|
||||||
|
max_chunk_seconds: 4
|
||||||
|
overlap_seconds: 1
|
||||||
|
|
||||||
|
active_backend: "test_backend"
|
||||||
|
defaults:
|
||||||
|
fps: 30
|
||||||
|
checkpoint_db: "test.db"
|
||||||
|
model_cache_dir: "~/test_cache"
|
||||||
|
"""
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
yaml_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = ConfigLoader.load(config_path=Path(yaml_path))
|
||||||
|
|
||||||
|
assert config.active_backend == "test_backend"
|
||||||
|
assert "test_backend" in config.backends
|
||||||
|
|
||||||
|
backend = config.get_backend("test_backend")
|
||||||
|
assert backend.name == "Test Backend"
|
||||||
|
assert backend.vram_gb == 8
|
||||||
|
assert backend.chunking_mode == "sequential"
|
||||||
|
assert config.defaults["fps"] == 30
|
||||||
|
finally:
|
||||||
|
os.unlink(yaml_path)
|
||||||
|
|
||||||
|
def test_load_env_file(self, monkeypatch):
|
||||||
|
"""Test loading from .env file."""
|
||||||
|
# Clear any existing env vars first
|
||||||
|
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
|
||||||
|
monkeypatch.delenv(key, raising=False)
|
||||||
|
|
||||||
|
env_content = """
|
||||||
|
ACTIVE_BACKEND=env_backend
|
||||||
|
MODEL_CACHE_DIR=/env/cache
|
||||||
|
LOG_LEVEL=DEBUG
|
||||||
|
"""
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.env', delete=False) as f:
|
||||||
|
f.write(env_content)
|
||||||
|
env_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
config = ConfigLoader.load(
|
||||||
|
config_path=Path("nonexistent.yaml"),
|
||||||
|
env_file=Path(env_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.active_backend == "env_backend"
|
||||||
|
assert "/env/cache" in config.model_cache_dir or "\\env\\cache" in config.model_cache_dir
|
||||||
|
assert config.log_level == "DEBUG"
|
||||||
|
finally:
|
||||||
|
os.unlink(env_path)
|
||||||
|
# Clean up env vars
|
||||||
|
for key in ['ACTIVE_BACKEND', 'MODEL_CACHE_DIR', 'LOG_LEVEL']:
|
||||||
|
monkeypatch.delenv(key, raising=False)
|
||||||
|
|
||||||
|
def test_env_variable_override(self, monkeypatch):
|
||||||
|
"""Test that environment variables override config files."""
|
||||||
|
# Clear env var first
|
||||||
|
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
|
||||||
|
|
||||||
|
yaml_content = """
|
||||||
|
active_backend: yaml_backend
|
||||||
|
model_cache_dir: ~/yaml_cache
|
||||||
|
"""
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||||
|
f.write(yaml_content)
|
||||||
|
yaml_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
monkeypatch.setenv("ACTIVE_BACKEND", "env_override")
|
||||||
|
|
||||||
|
config = ConfigLoader.load(config_path=Path(yaml_path))
|
||||||
|
|
||||||
|
# Environment variable should override
|
||||||
|
assert config.active_backend == "env_override"
|
||||||
|
# But other settings from YAML should remain
|
||||||
|
assert "yaml_cache" in config.model_cache_dir
|
||||||
|
finally:
|
||||||
|
os.unlink(yaml_path)
|
||||||
|
monkeypatch.delenv("ACTIVE_BACKEND", raising=False)
|
||||||
|
|
||||||
|
def test_path_expansion(self):
|
||||||
|
"""Test that paths are properly expanded."""
|
||||||
|
config = ConfigLoader.load(
|
||||||
|
config_path=Path("nonexistent.yaml"),
|
||||||
|
env_file=Path("nonexistent.env")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default paths should be expanded
|
||||||
|
assert not config.model_cache_dir.startswith("~")
|
||||||
|
assert not config.checkpoint_db.startswith("~")
|
||||||
|
|
||||||
|
def test_get_backend_not_found(self):
|
||||||
|
"""Test getting a non-existent backend."""
|
||||||
|
config = Config()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
config.get_backend("nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendConfig:
|
||||||
|
"""Test backend configuration."""
|
||||||
|
|
||||||
|
def test_backend_config_defaults(self):
|
||||||
|
"""Test backend config with defaults."""
|
||||||
|
config = BackendConfig(
|
||||||
|
name="Test",
|
||||||
|
model_class="test.Test",
|
||||||
|
model_id="test/model",
|
||||||
|
vram_gb=12,
|
||||||
|
dtype="fp16"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enable_vae_slicing is True
|
||||||
|
assert config.enable_vae_tiling is False
|
||||||
|
assert config.chunking_enabled is True
|
||||||
|
assert config.chunking_mode == "sequential"
|
||||||
268
tests/unit/test_storyboard.py
Normal file
268
tests/unit/test_storyboard.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""
|
||||||
|
Tests for storyboard schema validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
|
||||||
|
from src.storyboard.schema import (
|
||||||
|
Storyboard, ProjectSettings, Shot, Character, Location,
|
||||||
|
GlobalStyle, CameraSettings, GenerationSettings, OutputSettings,
|
||||||
|
Resolution
|
||||||
|
)
|
||||||
|
from src.storyboard.loader import StoryboardValidator, StoryboardLoadError
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoryboardSchema:
|
||||||
|
"""Test storyboard schema validation."""
|
||||||
|
|
||||||
|
def test_create_minimal_storyboard(self):
|
||||||
|
"""Test creating a minimal valid storyboard."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test Video",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
shots=[
|
||||||
|
Shot(
|
||||||
|
id="S01",
|
||||||
|
duration_s=5,
|
||||||
|
prompt="A test shot"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert storyboard.project.title == "Test Video"
|
||||||
|
assert len(storyboard.shots) == 1
|
||||||
|
assert storyboard.shots[0].id == "S01"
|
||||||
|
|
||||||
|
def test_shot_validation_unique_ids(self):
|
||||||
|
"""Test that duplicate shot IDs are rejected."""
|
||||||
|
with pytest.raises(ValueError, match="Shot IDs must be unique"):
|
||||||
|
Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
shots=[
|
||||||
|
Shot(id="S01", duration_s=5, prompt="Shot 1"),
|
||||||
|
Shot(id="S01", duration_s=5, prompt="Shot 2")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_character_validation_unique_ids(self):
|
||||||
|
"""Test that duplicate character IDs are rejected."""
|
||||||
|
with pytest.raises(ValueError, match="Character IDs must be unique"):
|
||||||
|
Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
characters=[
|
||||||
|
Character(id="C01", name="Character 1"),
|
||||||
|
Character(id="C01", name="Character 2")
|
||||||
|
],
|
||||||
|
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_location_validation_unique_ids(self):
|
||||||
|
"""Test that duplicate location IDs are rejected."""
|
||||||
|
with pytest.raises(ValueError, match="Location IDs must be unique"):
|
||||||
|
Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
locations=[
|
||||||
|
Location(id="L01", name="Location 1"),
|
||||||
|
Location(id="L01", name="Location 2")
|
||||||
|
],
|
||||||
|
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_character_by_id(self):
|
||||||
|
"""Test retrieving character by ID."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
characters=[
|
||||||
|
Character(id="C01", name="Hero", description="The protagonist")
|
||||||
|
],
|
||||||
|
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
|
||||||
|
)
|
||||||
|
|
||||||
|
char = storyboard.get_character("C01")
|
||||||
|
assert char is not None
|
||||||
|
assert char.name == "Hero"
|
||||||
|
|
||||||
|
missing = storyboard.get_character("C99")
|
||||||
|
assert missing is None
|
||||||
|
|
||||||
|
def test_get_location_by_id(self):
|
||||||
|
"""Test retrieving location by ID."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
locations=[
|
||||||
|
Location(id="L01", name="Street", description="A city street")
|
||||||
|
],
|
||||||
|
shots=[Shot(id="S01", duration_s=5, prompt="Test")]
|
||||||
|
)
|
||||||
|
|
||||||
|
loc = storyboard.get_location("L01")
|
||||||
|
assert loc is not None
|
||||||
|
assert loc.name == "Street"
|
||||||
|
|
||||||
|
missing = storyboard.get_location("L99")
|
||||||
|
assert missing is None
|
||||||
|
|
||||||
|
def test_total_duration_calculation(self):
|
||||||
|
"""Test total duration calculation."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
shots=[
|
||||||
|
Shot(id="S01", duration_s=4, prompt="Shot 1"),
|
||||||
|
Shot(id="S02", duration_s=6, prompt="Shot 2")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert storyboard.get_total_duration() == 10
|
||||||
|
|
||||||
|
def test_total_frames_calculation(self):
|
||||||
|
"""Test total frames calculation."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
fps=24,
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
shots=[
|
||||||
|
Shot(id="S01", duration_s=4, prompt="Shot 1"),
|
||||||
|
Shot(id="S02", duration_s=6, prompt="Shot 2")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert storyboard.get_total_frames() == 240 # 10 seconds * 24 fps
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoryboardLoader:
|
||||||
|
"""Test storyboard loading functionality."""
|
||||||
|
|
||||||
|
def test_load_valid_storyboard(self):
|
||||||
|
"""Test loading a valid storyboard JSON file."""
|
||||||
|
data = {
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"project": {
|
||||||
|
"title": "Test Video",
|
||||||
|
"fps": 24,
|
||||||
|
"target_duration_s": 10,
|
||||||
|
"resolution": {"width": 1920, "height": 1080},
|
||||||
|
"aspect_ratio": "16:9",
|
||||||
|
"global_style": {
|
||||||
|
"visual_style": "cinematic",
|
||||||
|
"negative_prompt": "blurry"
|
||||||
|
},
|
||||||
|
"audio": {"add_music": False}
|
||||||
|
},
|
||||||
|
"characters": [],
|
||||||
|
"locations": [],
|
||||||
|
"shots": [
|
||||||
|
{
|
||||||
|
"id": "S01",
|
||||||
|
"duration_s": 5,
|
||||||
|
"prompt": "A test shot",
|
||||||
|
"camera": {"framing": "wide"},
|
||||||
|
"generation": {"seed": 12345, "steps": 30}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output": {"container": "mp4", "codec": "h264"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
temp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
storyboard = StoryboardValidator.load(temp_path)
|
||||||
|
assert storyboard.project.title == "Test Video"
|
||||||
|
assert len(storyboard.shots) == 1
|
||||||
|
assert storyboard.shots[0].id == "S01"
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_load_nonexistent_file(self):
|
||||||
|
"""Test loading a file that doesn't exist."""
|
||||||
|
with pytest.raises(StoryboardLoadError, match="not found"):
|
||||||
|
StoryboardValidator.load("/nonexistent/path/storyboard.json")
|
||||||
|
|
||||||
|
def test_load_invalid_json(self):
|
||||||
|
"""Test loading invalid JSON."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
f.write("{invalid json}")
|
||||||
|
temp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(StoryboardLoadError, match="Invalid JSON"):
|
||||||
|
StoryboardValidator.load(temp_path)
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_load_wrong_extension(self):
|
||||||
|
"""Test loading a file with wrong extension."""
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
||||||
|
f.write('{}')
|
||||||
|
temp_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(StoryboardLoadError, match="must be JSON"):
|
||||||
|
StoryboardValidator.load(temp_path)
|
||||||
|
finally:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
|
||||||
|
def test_validate_references(self):
|
||||||
|
"""Test reference validation."""
|
||||||
|
storyboard = Storyboard(
|
||||||
|
project=ProjectSettings(
|
||||||
|
title="Test",
|
||||||
|
resolution=Resolution(width=1920, height=1080),
|
||||||
|
target_duration_s=10
|
||||||
|
),
|
||||||
|
characters=[
|
||||||
|
Character(id="C01", name="Hero")
|
||||||
|
],
|
||||||
|
locations=[
|
||||||
|
Location(id="L01", name="Street")
|
||||||
|
],
|
||||||
|
shots=[
|
||||||
|
Shot(
|
||||||
|
id="S01",
|
||||||
|
duration_s=5,
|
||||||
|
prompt="Test",
|
||||||
|
characters=["C01", "C99"], # C99 doesn't exist
|
||||||
|
location_id="L99" # Doesn't exist
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
issues = StoryboardValidator.validate_references(storyboard)
|
||||||
|
assert len(issues) == 2
|
||||||
|
assert any("C99" in issue for issue in issues)
|
||||||
|
assert any("L99" in issue for issue in issues)
|
||||||
Reference in New Issue
Block a user