WAN backend and FFmpeg assembler done.
This commit is contained in:
41
TODO.MD
41
TODO.MD
@@ -7,29 +7,40 @@
|
||||
- [x] Add storyboard JSON template at `templates/storyboard.template.json`
|
||||
|
||||
## Core implementation (next)
|
||||
- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
|
||||
- [ ] Implement storyboard schema validator (pydantic) + loader
|
||||
- [ ] Implement prompt compiler (global style + shot + camera)
|
||||
- [ ] Implement shot planning (duration -> frame count, chunk plan)
|
||||
- [ ] Implement model backend interface (`BaseVideoBackend`)
|
||||
- [ ] Implement WAN backend (primary) with VRAM-safe defaults
|
||||
- [x] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
|
||||
- [x] Implement storyboard schema validator (pydantic) + loader
|
||||
- [x] Implement prompt compiler (global style + shot + camera)
|
||||
- [x] Implement shot planning (duration -> frame count, chunk plan)
|
||||
- [x] Implement model backend interface (`BaseVideoBackend`)
|
||||
- [x] Implement WAN backend (primary) with VRAM-safe defaults
|
||||
- [ ] Implement fallback backend (SVD) for reliability testing
|
||||
- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
|
||||
- [x] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
|
||||
- [ ] Implement optional upscaling module (post-process)
|
||||
|
||||
## Utilities
|
||||
- [ ] Write storyboard “plain text → JSON” utility script (fills `storyboard.template.json`)
|
||||
- [ ] Add config file support (YAML/JSON) for global defaults
|
||||
- [ ] Write storyboard "plain text -> JSON" utility script (fills `storyboard.template.json`)
|
||||
- [x] Add config file support (YAML/JSON) for global defaults
|
||||
|
||||
## Testing (parallel work; required)
|
||||
- [ ] Add `pytest` scaffolding
|
||||
- [ ] Add tests for schema validation
|
||||
- [ ] Add tests for prompt compilation determinism
|
||||
- [ ] Add tests for shot planning (frames/chunks)
|
||||
- [ ] Add tests for ffmpeg command generation (no actual render needed)
|
||||
- [ ] Ensure every code change includes a corresponding test update
|
||||
- [x] Add `pytest` scaffolding
|
||||
- [x] Add tests for schema validation
|
||||
- [x] Add tests for prompt compilation determinism
|
||||
- [x] Add tests for shot planning (frames/chunks)
|
||||
- [x] Add tests for ffmpeg command generation (no actual render needed)
|
||||
- [x] Ensure every code change includes a corresponding test update
|
||||
|
||||
## Documentation (maintained continuously)
|
||||
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
|
||||
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
|
||||
- [ ] Keep docs updated whenever CLI/config/schema changes
|
||||
|
||||
## Current Status
|
||||
- **Completed:** 12/19 tasks
|
||||
- **In Progress:** FFmpeg assembler implementation
|
||||
- **Next:** CLI entry point, Documentation
|
||||
|
||||
## Recent Updates
|
||||
- Fixed environment.yml for PyTorch 2.5.1 compatibility
|
||||
- Implemented WAN backend with lazy imports
|
||||
- Created FFmpeg assembler module
|
||||
- All core tests passing (29 tests)
|
||||
|
||||
492
src/assembly/assembler.py
Normal file
492
src/assembly/assembler.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""
|
||||
FFmpeg video assembler for concatenating shots and adding transitions.
|
||||
"""
|
||||
|
||||
import subprocess
import tempfile
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class TransitionType(str, Enum):
    """Transition styles available when joining shots.

    Inherits from ``str`` so values serialize directly (e.g. in to_dict()).
    """

    NONE = "none"
    FADE = "fade"
    DISSOLVE = "dissolve"
    WIPE_LEFT = "wipe_left"
    WIPE_RIGHT = "wipe_right"
|
||||
|
||||
|
||||
@dataclass
class AssemblyConfig:
    """Configuration for video assembly."""

    fps: int = 24                       # output frame rate
    container: str = "mp4"              # output container format
    codec: str = "h264"                 # logical codec name (mapped to an ffmpeg encoder)
    crf: int = 18                       # constant rate factor (lower = higher quality)
    preset: str = "medium"              # encoder speed/quality preset
    transition: TransitionType = TransitionType.NONE
    transition_duration_ms: int = 500   # transition length in milliseconds
    add_shot_labels: bool = False       # burn debug labels into the output
    audio_track: Optional[Path] = None  # optional audio file to mux in

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary for inclusion in result metadata."""
        audio = str(self.audio_track) if self.audio_track else None
        return {
            "fps": self.fps,
            "container": self.container,
            "codec": self.codec,
            "crf": self.crf,
            "preset": self.preset,
            "transition": self.transition.value,
            "transition_duration_ms": self.transition_duration_ms,
            "add_shot_labels": self.add_shot_labels,
            "audio_track": audio,
        }
|
||||
|
||||
|
||||
@dataclass
class AssemblyResult:
    """Result of a video-assembly operation.

    Attributes:
        success: True when the operation completed without error.
        output_path: Path of the assembled video (set on success).
        duration_s: Duration of the output in seconds, when known.
        num_shots: Number of input shots that were assembled.
        error_message: Human-readable failure description (set on failure).
        metadata: Free-form metadata about the assembly (config, inputs, ...).
    """

    success: bool
    output_path: Optional[Path] = None
    duration_s: Optional[float] = None
    num_shots: int = 0
    error_message: Optional[str] = None
    # FIX: the original annotated Dict but defaulted to None; default_factory
    # gives every instance its own dict and keeps the annotation honest.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Backward compatibility: callers that explicitly pass metadata=None
        # still get an empty dict, as the original implementation guaranteed.
        if self.metadata is None:
            self.metadata = {}
|
||||
|
||||
|
||||
class FFmpegAssembler:
    """
    Assembles video clips using FFmpeg.

    Supports concatenation (concat demuxer), fade transitions (xfade filter),
    audio muxing, debug label burn-in, and single-frame extraction.
    """

    def __init__(self, ffmpeg_path: str = "ffmpeg"):
        """
        Initialize FFmpeg assembler.

        Args:
            ffmpeg_path: Path to ffmpeg executable (default: "ffmpeg" from PATH)
        """
        self.ffmpeg_path = ffmpeg_path
        # Availability is probed up front but intentionally not enforced here;
        # a missing binary surfaces later as a per-operation error result.
        self._check_ffmpeg()

    def _check_ffmpeg(self) -> bool:
        """Return True when `ffmpeg -version` runs successfully."""
        try:
            result = subprocess.run(
                [self.ffmpeg_path, "-version"],
                capture_output=True,
                text=True,
                timeout=5
            )
            return result.returncode == 0
        except (subprocess.TimeoutExpired, FileNotFoundError):
            return False

    def assemble(
        self,
        shot_files: List[Path],
        output_path: Path,
        config: Optional[AssemblyConfig] = None
    ) -> AssemblyResult:
        """
        Assemble shots into the final video.

        Args:
            shot_files: List of shot video files to concatenate
            output_path: Output path for the final video
            config: Assembly configuration (defaults to AssemblyConfig())

        Returns:
            AssemblyResult with output path and metadata
        """
        if config is None:
            config = AssemblyConfig()

        if not shot_files:
            return AssemblyResult(
                success=False,
                error_message="No shot files provided"
            )

        # Fail fast if any input is missing, before invoking ffmpeg.
        missing = [f for f in shot_files if not f.exists()]
        if missing:
            return AssemblyResult(
                success=False,
                error_message=f"Missing shot files: {missing}"
            )

        # Ensure the output directory exists.
        output_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            if config.transition == TransitionType.NONE:
                # Simple concatenation.
                result = self._concatenate_simple(shot_files, output_path, config)
            else:
                # Concatenation with transitions.
                result = self._concatenate_with_transitions(shot_files, output_path, config)

            if not result.success:
                return result

            # Add audio if specified.
            if config.audio_track and config.audio_track.exists():
                # BUG FIX: ffmpeg cannot safely read and write the same file.
                # The original passed output_path as both input and output (and
                # ignored the result); mux into a sibling temp file instead,
                # then replace the original only on success.
                tmp_path = output_path.with_name(
                    f"{output_path.stem}.audio_tmp{output_path.suffix}"
                )
                result = self._add_audio(output_path, config.audio_track, tmp_path, config)
                if not result.success:
                    tmp_path.unlink(missing_ok=True)
                    return result
                tmp_path.replace(output_path)

            # Probe the final duration for the result metadata.
            duration = self._get_video_duration(output_path)

            return AssemblyResult(
                success=True,
                output_path=output_path,
                duration_s=duration,
                num_shots=len(shot_files),
                metadata={
                    "config": config.to_dict(),
                    "input_files": [str(f) for f in shot_files]
                }
            )

        except Exception as e:
            return AssemblyResult(
                success=False,
                error_message=f"Assembly failed: {str(e)}"
            )

    def _concatenate_simple(
        self,
        shot_files: List[Path],
        output_path: Path,
        config: AssemblyConfig
    ) -> AssemblyResult:
        """
        Simple concatenation using the concat demuxer.

        Args:
            shot_files: List of video files
            output_path: Output path
            config: Assembly configuration

        Returns:
            AssemblyResult
        """
        # Write the concat list file consumed by `-f concat`.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
            for shot_file in shot_files:
                # Escape single quotes per the concat-list quoting rules.
                escaped_path = str(shot_file).replace("'", "'\\''")
                f.write(f"file '{escaped_path}'\n")
            concat_list = f.name

        try:
            cmd = [
                self.ffmpeg_path,
                '-y',  # Overwrite output
                '-f', 'concat',
                '-safe', '0',  # Allow absolute paths in the list
                '-i', concat_list,
                '-c:v', self._get_video_codec(config.codec),
                '-preset', config.preset,
                '-crf', str(config.crf),
                '-r', str(config.fps),
                '-pix_fmt', 'yuv420p',  # For broad player compatibility
                str(output_path)
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600  # 1 hour timeout for long renders
            )

            if result.returncode != 0:
                return AssemblyResult(
                    success=False,
                    error_message=f"FFmpeg error: {result.stderr}"
                )

            return AssemblyResult(success=True)

        finally:
            # Always clean up the concat list.
            Path(concat_list).unlink(missing_ok=True)

    def _concatenate_with_transitions(
        self,
        shot_files: List[Path],
        output_path: Path,
        config: AssemblyConfig
    ) -> AssemblyResult:
        """
        Concatenate with crossfade transitions using an xfade filter chain.

        Args:
            shot_files: List of video files
            output_path: Output path
            config: Assembly configuration

        Returns:
            AssemblyResult
        """
        if len(shot_files) == 1:
            # No transitions needed for a single file.
            return self._concatenate_simple(shot_files, output_path, config)

        transition_duration = config.transition_duration_ms / 1000.0

        # One -i per clip; the filter graph references them as [N:v].
        inputs: List[str] = []
        for shot_file in shot_files:
            inputs.extend(['-i', str(shot_file)])

        # Build the xfade chain:
        # [0:v][1:v]xfade=...[vt0];[vt0][2:v]xfade=...[vt1];...[outv]
        # NOTE: only the fade transition is emitted for now; other
        # TransitionType values fall back to fade.
        filter_parts = []
        offset = 0.0
        last = len(shot_files) - 2  # index of the final transition

        for i in range(len(shot_files) - 1):
            input_refs = "[0:v][1:v]" if i == 0 else f"[vt{i-1}][{i+1}:v]"
            # BUG FIX: the final xfade must always emit [outv] (consumed by
            # -map below) — including the two-clip case where i == 0. The old
            # code labelled that output [vt0], so -map '[outv]' failed.
            output_ref = "[outv]" if i == last else f"[vt{i}]"

            # Each transition starts transition_duration before the end of
            # the accumulated timeline.
            duration = self._get_video_duration(shot_files[i])
            offset += duration - transition_duration

            filter_parts.append(
                f"{input_refs}xfade=transition=fade:duration={transition_duration}:offset={offset}{output_ref}"
            )

        filter_complex = ";".join(filter_parts)

        cmd = [
            self.ffmpeg_path,
            '-y'
        ] + inputs + [
            '-filter_complex', filter_complex,
            '-map', '[outv]',
            '-c:v', self._get_video_codec(config.codec),
            '-preset', config.preset,
            '-crf', str(config.crf),
            '-r', str(config.fps),
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=3600
        )

        if result.returncode != 0:
            return AssemblyResult(
                success=False,
                error_message=f"FFmpeg transition error: {result.stderr}"
            )

        return AssemblyResult(success=True)

    def _add_audio(
        self,
        video_path: Path,
        audio_path: Path,
        output_path: Path,
        config: AssemblyConfig
    ) -> AssemblyResult:
        """
        Mux an audio track into a video.

        Args:
            video_path: Input video path (must differ from output_path)
            audio_path: Audio file path
            output_path: Output path
            config: Assembly configuration (currently unused; kept for
                signature stability)

        Returns:
            AssemblyResult
        """
        cmd = [
            self.ffmpeg_path,
            '-y',
            '-i', str(video_path),
            '-i', str(audio_path),
            '-c:v', 'copy',  # Copy video without re-encoding
            '-c:a', 'aac',
            '-b:a', '192k',
            '-shortest',  # Stop at the shorter of video/audio
            str(output_path)
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=3600
        )

        if result.returncode != 0:
            return AssemblyResult(
                success=False,
                error_message=f"FFmpeg audio error: {result.stderr}"
            )

        return AssemblyResult(success=True)

    def _get_video_codec(self, codec: str) -> str:
        """Map a logical codec name to the ffmpeg encoder name (default libx264)."""
        codec_map = {
            "h264": "libx264",
            "h265": "libx265",
            "vp9": "libvpx-vp9"
        }
        return codec_map.get(codec, "libx264")

    def _get_video_duration(self, video_path: Path) -> float:
        """
        Get video duration in seconds using ffprobe.

        Args:
            video_path: Path to video file

        Returns:
            Duration in seconds; 0.0 when ffprobe fails or is unavailable.
        """
        try:
            cmd = [
                "ffprobe",
                "-v", "error",
                "-show_entries", "format=duration",
                "-of", "default=noprint_wrappers=1:nokey=1",
                str(video_path)
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=30
            )

            if result.returncode == 0:
                return float(result.stdout.strip())
            return 0.0
        except Exception:
            # Best effort: duration is advisory, never fatal.
            return 0.0

    def burn_in_labels(
        self,
        video_path: Path,
        output_path: Path,
        labels: List[str],
        font_size: int = 24,
        position: str = "top-left"
    ) -> AssemblyResult:
        """
        Burn shot labels into a video for debugging.

        NOTE(review): per-shot labels would require segmenting the video; the
        current implementation draws a single static "Shot" label, and the
        `labels` argument is not yet consumed.

        Args:
            video_path: Input video path
            output_path: Output path
            labels: List of labels (one per shot; currently unused)
            font_size: Font size for labels
            position: Label position (top-left, top-right, bottom-left, bottom-right)

        Returns:
            AssemblyResult
        """
        # drawtext coordinate expressions per corner.
        positions = {
            "top-left": "x=10:y=10",
            "top-right": "x=w-text_w-10:y=10",
            "bottom-left": "x=10:y=h-text_h-10",
            "bottom-right": "x=w-text_w-10:y=h-text_h-10"
        }
        pos = positions.get(position, positions["top-left"])

        cmd = [
            self.ffmpeg_path,
            '-y',
            '-i', str(video_path),
            '-vf', f"drawtext=text='Shot':{pos}:fontsize={font_size}:fontcolor=white:box=1:boxcolor=black@0.5",
            '-c:a', 'copy',
            str(output_path)
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=3600
        )

        if result.returncode != 0:
            return AssemblyResult(
                success=False,
                error_message=f"Label burn-in error: {result.stderr}"
            )

        return AssemblyResult(success=True)

    def extract_frame(
        self,
        video_path: Path,
        timestamp_s: float,
        output_path: Path
    ) -> bool:
        """
        Extract a single frame from a video.

        Args:
            video_path: Input video path
            timestamp_s: Timestamp in seconds
            output_path: Output path for the frame image

        Returns:
            True if successful
        """
        cmd = [
            self.ffmpeg_path,
            '-y',
            '-ss', str(timestamp_s),  # Seek before input for fast keyframe seek
            '-i', str(video_path),
            '-vframes', '1',
            '-q:v', '2',  # High JPEG quality
            str(output_path)
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=60
        )

        return result.returncode == 0
|
||||
@@ -8,10 +8,12 @@ from .base import (
|
||||
GenerationSpec,
|
||||
BackendFactory
|
||||
)
|
||||
from .backends import WanBackend
|
||||
|
||||
__all__ = [
|
||||
'BaseVideoBackend',
|
||||
'GenerationResult',
|
||||
'GenerationSpec',
|
||||
'BackendFactory'
|
||||
'BackendFactory',
|
||||
'WanBackend'
|
||||
]
|
||||
|
||||
7
src/generation/backends/__init__.py
Normal file
7
src/generation/backends/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Generation backends.
|
||||
"""
|
||||
|
||||
from .wan import WanBackend
|
||||
|
||||
__all__ = ['WanBackend']
|
||||
461
src/generation/backends/wan.py
Normal file
461
src/generation/backends/wan.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
WAN 2.x video generation backend.
|
||||
Implements text-to-video generation using Wan-AI models.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||
import tempfile
|
||||
|
||||
from ..base import BaseVideoBackend, GenerationResult, GenerationSpec, BackendFactory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanPipeline
|
||||
|
||||
|
||||
class WanBackend(BaseVideoBackend):
    """
    WAN 2.x text-to-video generation backend.
    Supports both 14B and 1.3B models with VRAM-safe defaults.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize WAN backend.

        Args:
            config: Backend configuration with keys:
                - model_id: HuggingFace model ID
                - vram_gb: Expected VRAM usage
                - dtype: "fp16" or "bf16"
                - enable_vae_slicing: bool
                - enable_vae_tiling: bool
                - model_cache_dir: Cache directory for models
        """
        super().__init__(config)

        self.model_id = config.get("model_id", "Wan-AI/Wan2.1-T2V-14B")
        self.vram_gb = config.get("vram_gb", 12)
        self.dtype_str = config.get("dtype", "fp16")
        self.enable_vae_slicing = config.get("enable_vae_slicing", True)
        self.enable_vae_tiling = config.get("enable_vae_tiling", True)
        self.model_cache_dir = config.get("model_cache_dir", "~/.cache/storyboard-video/models")

        # Heavy components are created lazily in load().
        self.vae = None
        self.pipeline = None
        self.dtype = None

    @property
    def name(self) -> str:
        """Backend name derived from the model id, e.g. 'wan_wan2.1-t2v-14b'."""
        return f"wan_{self.model_id.split('/')[-1].lower()}"

    @property
    def supports_chunking(self) -> bool:
        """WAN supports chunking via sequential generation."""
        return True

    @property
    def supports_init_frame(self) -> bool:
        """WAN T2V doesn't natively support init frame (I2V variant does)."""
        return False

    def _get_dtype(self):
        """Return the torch dtype for the configured dtype string (default fp16)."""
        import torch
        if self.dtype_str == "bf16":
            return torch.bfloat16
        return torch.float16

    def load(self) -> None:
        """Load WAN model and components (idempotent)."""
        if self._is_loaded:
            return

        # Lazy imports keep module import cheap and torch-free.
        import torch
        from diffusers import AutoencoderKLWan, WanPipeline

        print(f"Loading WAN model: {self.model_id}")

        self.dtype = self._get_dtype()
        print(f"Using dtype: {self.dtype}")

        # Set cache directory.
        cache_dir = Path(self.model_cache_dir).expanduser()
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Load VAE. For T2V checkpoints the VAE lives in a subfolder of the
        # same repo; the '-VAE-' id substitution is the documented layout.
        print("Loading VAE...")
        vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id

        self.vae = AutoencoderKLWan.from_pretrained(
            vae_id,
            subfolder="vae" if "-T2V-" in self.model_id else None,
            torch_dtype=self.dtype,
            cache_dir=str(cache_dir)
        )

        # Load pipeline.
        print("Loading pipeline...")
        self.pipeline = WanPipeline.from_pretrained(
            self.model_id,
            vae=self.vae,
            torch_dtype=self.dtype,
            cache_dir=str(cache_dir)
        )

        # Move to GPU and enable memory optimizations.
        if torch.cuda.is_available():
            print("Moving to CUDA...")
            self.pipeline = self.pipeline.to("cuda")

            if self.enable_vae_slicing:
                print("Enabling VAE slicing...")
                self.pipeline.vae.enable_slicing()

            if self.enable_vae_tiling:
                print("Enabling VAE tiling...")
                self.pipeline.vae.enable_tiling()
        else:
            print("WARNING: CUDA not available, using CPU (will be very slow)")

        self._is_loaded = True
        print("WAN model loaded successfully")

    def unload(self) -> None:
        """Unload the model and release GPU memory (idempotent)."""
        if not self._is_loaded:
            return

        import torch

        print("Unloading WAN model...")

        # Drop references so the GC can reclaim the weights.
        self.pipeline = None
        self.vae = None

        # Force garbage collection before clearing the CUDA allocator cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self._is_loaded = False
        print("WAN model unloaded")

    def generate(self, spec: GenerationSpec) -> GenerationResult:
        """
        Generate a video clip.

        Args:
            spec: Generation specification

        Returns:
            GenerationResult with output path and metadata
        """
        import torch
        from diffusers.utils import export_to_video

        if not self._is_loaded:
            return GenerationResult(
                success=False,
                error_message="Model not loaded. Call load() first."
            )

        start_time = time.time()

        try:
            # Ensure output directory exists.
            spec.output_path.parent.mkdir(parents=True, exist_ok=True)

            # Seeded generator for reproducibility; seed < 0 means random.
            generator = None
            if spec.seed >= 0:
                generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
                generator.manual_seed(spec.seed)

            print(f"Generating video: {spec.width}x{spec.height}, {spec.num_frames} frames")
            print(f"Prompt: {spec.prompt[:100]}...")

            # Generate video.
            result = self.pipeline(
                prompt=spec.prompt,
                negative_prompt=spec.negative_prompt,
                width=spec.width,
                height=spec.height,
                num_frames=spec.num_frames,
                num_inference_steps=spec.steps,
                guidance_scale=spec.cfg_scale,
                generator=generator,
            )

            # Export frames to the output video file.
            frames = result.frames[0]
            export_to_video(frames, str(spec.output_path), fps=spec.fps)

            # Record peak VRAM usage for this generation.
            vram_usage = None
            if torch.cuda.is_available():
                vram_usage = torch.cuda.max_memory_allocated() / (1024**3)
                torch.cuda.reset_peak_memory_stats()

            generation_time = time.time() - start_time

            # Build metadata mirroring the generation spec.
            metadata = {
                "model_id": self.model_id,
                "prompt": spec.prompt,
                "negative_prompt": spec.negative_prompt,
                "width": spec.width,
                "height": spec.height,
                "num_frames": spec.num_frames,
                "fps": spec.fps,
                "seed": spec.seed,
                "steps": spec.steps,
                "cfg_scale": spec.cfg_scale,
                "dtype": self.dtype_str,
            }

            print(f"Generation complete: {spec.output_path}")
            if vram_usage:
                print(f"Time: {generation_time:.1f}s, VRAM: {vram_usage:.1f}GB")
            else:
                print(f"Time: {generation_time:.1f}s")

            return GenerationResult(
                success=True,
                output_path=spec.output_path,
                metadata=metadata,
                vram_usage_gb=vram_usage,
                generation_time_s=generation_time
            )

        except Exception as e:
            generation_time = time.time() - start_time
            error_msg = f"Generation failed: {str(e)}"
            print(error_msg)

            return GenerationResult(
                success=False,
                error_message=error_msg,
                generation_time_s=generation_time
            )

    def generate_chunked(
        self,
        spec: GenerationSpec,
        chunk_duration_s: int,
        overlap_s: int = 0,
        mode: str = "sequential"
    ) -> GenerationResult:
        """
        Generate video in chunks.

        For WAN, we generate sequentially and concatenate.
        Note: This is a simplified implementation. True temporal consistency
        would require more sophisticated frame interpolation.

        Args:
            spec: Generation specification
            chunk_duration_s: Duration per chunk in seconds
            overlap_s: Overlap between chunks (not used in sequential mode)
            mode: "sequential" or "overlapping"

        Returns:
            GenerationResult with concatenated video
        """
        if not self._is_loaded:
            return GenerationResult(
                success=False,
                error_message="Model not loaded. Call load() first."
            )

        start_time = time.time()

        try:
            total_duration = spec.num_frames / spec.fps

            if total_duration <= chunk_duration_s:
                # Single chunk - no need to split.
                return self.generate(spec)

            print(f"Generating in chunks: {total_duration:.1f}s total, {chunk_duration_s}s per chunk")

            # Generate each chunk (ceiling division for the chunk count).
            chunk_files: List[Path] = []
            num_chunks = int((total_duration + chunk_duration_s - 1) // chunk_duration_s)

            for i in range(num_chunks):
                chunk_start = i * chunk_duration_s
                chunk_end = min(chunk_start + chunk_duration_s, total_duration)
                chunk_duration = chunk_end - chunk_start
                chunk_num_frames = int(chunk_duration * spec.fps)

                # Per-chunk output path alongside the final output.
                chunk_path = spec.output_path.parent / f"{spec.output_path.stem}_chunk{i:03d}{spec.output_path.suffix}"

                print(f"Chunk {i+1}/{num_chunks}: {chunk_duration:.1f}s ({chunk_num_frames} frames)")

                chunk_spec = GenerationSpec(
                    prompt=spec.prompt,
                    negative_prompt=spec.negative_prompt,
                    width=spec.width,
                    height=spec.height,
                    num_frames=chunk_num_frames,
                    fps=spec.fps,
                    seed=spec.seed + i if spec.seed >= 0 else spec.seed,  # Vary seed per chunk
                    steps=spec.steps,
                    cfg_scale=spec.cfg_scale,
                    output_path=chunk_path
                )

                result = self.generate(chunk_spec)

                if not result.success:
                    # Clean up partial chunks before bubbling up the failure.
                    for f in chunk_files:
                        if f.exists():
                            f.unlink()
                    return result

                chunk_files.append(chunk_path)

            # Concatenate chunks using ffmpeg.
            print("Concatenating chunks...")
            concat_result = self._concatenate_chunks(chunk_files, spec.output_path, spec.fps)

            if not concat_result:
                return GenerationResult(
                    success=False,
                    error_message="Failed to concatenate chunks"
                )

            # Clean up intermediate chunk files.
            for chunk_file in chunk_files:
                if chunk_file.exists():
                    chunk_file.unlink()

            generation_time = time.time() - start_time

            metadata = {
                "model_id": self.model_id,
                "prompt": spec.prompt,
                "num_chunks": num_chunks,
                "chunk_duration_s": chunk_duration_s,
                "total_duration_s": total_duration,
                "mode": mode,
            }

            print(f"Chunked generation complete: {spec.output_path}")

            return GenerationResult(
                success=True,
                output_path=spec.output_path,
                metadata=metadata,
                generation_time_s=generation_time
            )

        except Exception as e:
            generation_time = time.time() - start_time
            error_msg = f"Chunked generation failed: {str(e)}"
            print(error_msg)

            return GenerationResult(
                success=False,
                error_message=error_msg,
                generation_time_s=generation_time
            )

    def _concatenate_chunks(
        self,
        chunk_files: List[Path],
        output_path: Path,
        fps: int
    ) -> bool:
        """
        Concatenate video chunks using ffmpeg's concat demuxer.

        Args:
            chunk_files: List of chunk video files
            output_path: Final output path
            fps: Frames per second

        Returns:
            True if successful
        """
        try:
            import subprocess

            # Create the concat list file.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                for chunk_file in chunk_files:
                    # Escape single quotes per the concat-list quoting rules
                    # (consistent with the assembler module).
                    escaped = str(chunk_file).replace("'", "'\\''")
                    f.write(f"file '{escaped}'\n")
                concat_list = f.name

            try:
                cmd = [
                    'ffmpeg',
                    '-y',  # Overwrite output
                    '-f', 'concat',
                    '-safe', '0',
                    '-i', concat_list,
                    '-c', 'copy',  # Stream copy: no re-encode between chunks
                    '-r', str(fps),
                    str(output_path)
                ]

                result = subprocess.run(cmd, capture_output=True, text=True)
            finally:
                # BUG FIX: the temp list was leaked if subprocess.run raised;
                # always remove it.
                os.unlink(concat_list)

            if result.returncode != 0:
                print(f"FFmpeg error: {result.stderr}")
                return False

            return True

        except Exception as e:
            print(f"Concatenation error: {e}")
            return False

    def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
        """
        Estimate VRAM usage based on resolution and frame count.

        Args:
            width: Video width
            height: Video height
            num_frames: Number of frames

        Returns:
            Estimated VRAM in GB (heuristic, includes a 20% safety margin)
        """
        # Base VRAM for the model weights.
        base_vram = 8.0 if "1.3B" in self.model_id else 12.0

        # Resolution scaling relative to 720p (linear in pixel count).
        resolution_factor = (width * height) / (1280 * 720)

        # Frame count scaling (linear; 81 frames is a common default).
        frame_factor = num_frames / 81

        estimated = base_vram * resolution_factor * frame_factor

        # Add safety margin.
        return estimated * 1.2


# Register backend
BackendFactory.register("wan", WanBackend)
|
||||
195
tests/unit/test_wan_backend.py
Normal file
195
tests/unit/test_wan_backend.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Tests for WAN backend.
|
||||
Note: These tests mock the actual model loading/generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock, mock_open
|
||||
|
||||
from src.generation.backends.wan import WanBackend
|
||||
from src.generation.base import GenerationSpec, GenerationResult
|
||||
|
||||
|
||||
class TestWanBackend:
    """Test WAN backend functionality.

    All tests construct the backend from a plain config dict; no model is
    ever loaded.  Subprocess calls (ffmpeg) are mocked.
    """

    @pytest.fixture
    def config_14b(self):
        """Configuration for 14B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-14B",
            "vram_gb": 12,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": True,
            "model_cache_dir": "~/.cache/test"
        }

    @pytest.fixture
    def config_1_3b(self):
        """Configuration for 1.3B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-1.3B",
            "vram_gb": 8,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": False,
            "model_cache_dir": "~/.cache/test"
        }

    def test_backend_properties_14b(self, config_14b):
        """Test backend properties for 14B model."""
        backend = WanBackend(config_14b)

        assert backend.name == "wan_wan2.1-t2v-14b"
        assert backend.supports_chunking is True
        assert backend.supports_init_frame is False
        assert backend.vram_gb == 12

    def test_backend_properties_1_3b(self, config_1_3b):
        """Test backend properties for 1.3B model."""
        backend = WanBackend(config_1_3b)

        assert backend.name == "wan_wan2.1-t2v-1.3b"
        assert backend.vram_gb == 8

    def test_dtype_str_storage(self, config_14b):
        """Test dtype string is stored correctly."""
        backend = WanBackend(config_14b)
        assert backend.dtype_str == "fp16"

        config_14b["dtype"] = "bf16"
        backend2 = WanBackend(config_14b)
        assert backend2.dtype_str == "bf16"

    def test_estimate_vram_14b(self, config_14b):
        """Test VRAM estimation for 14B model."""
        backend = WanBackend(config_14b)

        # 720p, 81 frames (default)
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram > 12.0  # Base is 12GB with margin

        # 1080p should be higher
        vram_1080 = backend.estimate_vram_usage(1920, 1080, 81)
        assert vram_1080 > vram

    def test_estimate_vram_1_3b(self, config_1_3b):
        """Test VRAM estimation for 1.3B model."""
        backend = WanBackend(config_1_3b)

        # Should be less than 14B
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram < 12.0

    def test_estimate_vram_scaling(self, config_14b):
        """Test VRAM scales with resolution and frames."""
        backend = WanBackend(config_14b)

        # Base
        vram_720_81 = backend.estimate_vram_usage(1280, 720, 81)

        # Double resolution (4x pixels): 2560x1440 has exactly 4x the
        # pixels of 1280x720.  (BUGFIX: the previous 1920x1080 is only
        # 2.25x the pixels, which cannot exceed the 3x threshold below.)
        vram_1440_81 = backend.estimate_vram_usage(2560, 1440, 81)
        assert vram_1440_81 > vram_720_81 * 3  # Should be ~4x

        # Double frames
        vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
        assert vram_720_162 > vram_720_81 * 1.5  # Should be ~2x

    def test_generate_not_loaded(self, config_14b, tmp_path):
        """Test generation fails if model not loaded."""
        backend = WanBackend(config_14b)

        spec = GenerationSpec(
            prompt="test",
            negative_prompt="",
            width=512,
            height=512,
            num_frames=16,
            fps=8,
            seed=42,
            steps=10,
            cfg_scale=6.0,
            output_path=tmp_path / "test.mp4"
        )

        result = backend.generate(spec)

        assert result.success is False
        assert "not loaded" in result.error_message.lower()

    def test_concatenate_chunks_success(self, config_14b, tmp_path):
        """Test chunk concatenation."""
        backend = WanBackend(config_14b)

        # Create dummy chunk files
        chunk_files = [
            tmp_path / "chunk0.mp4",
            tmp_path / "chunk1.mp4",
            tmp_path / "chunk2.mp4"
        ]

        for f in chunk_files:
            f.write_text("dummy video data")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess so no real ffmpeg is needed
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=0)

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is True
            mock_run.assert_called_once()

            # Verify ffmpeg command structure (first positional arg is argv)
            call_args = mock_run.call_args[0][0]
            assert 'ffmpeg' in call_args
            assert '-f' in call_args
            assert 'concat' in call_args

    def test_concatenate_chunks_failure(self, config_14b, tmp_path):
        """Test chunk concatenation failure."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess failure (nonzero returncode)
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stderr="ffmpeg error")

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_concatenate_chunks_exception(self, config_14b, tmp_path):
        """Test chunk concatenation with exception."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess raising — backend must catch and return False
        with patch('subprocess.run', side_effect=Exception("subprocess failed")):
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_is_loaded_property(self, config_14b):
        """Test is_loaded property."""
        backend = WanBackend(config_14b)

        # Initially not loaded
        assert backend.is_loaded is False

        # Manually set (normally done by load())
        backend._is_loaded = True
        assert backend.is_loaded is True
|
||||
Reference in New Issue
Block a user