WAN backend and FFMpeg assembler done.
This commit is contained in:
41
TODO.MD
41
TODO.MD
@@ -7,29 +7,40 @@
|
|||||||
- [x] Add storyboard JSON template at `templates/storyboard.template.json`
|
- [x] Add storyboard JSON template at `templates/storyboard.template.json`
|
||||||
|
|
||||||
## Core implementation (next)
|
## Core implementation (next)
|
||||||
- [ ] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
|
- [x] Create repo structure: `src/`, `tests/`, `docs/`, `templates/`, `outputs/`
|
||||||
- [ ] Implement storyboard schema validator (pydantic) + loader
|
- [x] Implement storyboard schema validator (pydantic) + loader
|
||||||
- [ ] Implement prompt compiler (global style + shot + camera)
|
- [x] Implement prompt compiler (global style + shot + camera)
|
||||||
- [ ] Implement shot planning (duration -> frame count, chunk plan)
|
- [x] Implement shot planning (duration -> frame count, chunk plan)
|
||||||
- [ ] Implement model backend interface (`BaseVideoBackend`)
|
- [x] Implement model backend interface (`BaseVideoBackend`)
|
||||||
- [ ] Implement WAN backend (primary) with VRAM-safe defaults
|
- [x] Implement WAN backend (primary) with VRAM-safe defaults
|
||||||
- [ ] Implement fallback backend (SVD) for reliability testing
|
- [ ] Implement fallback backend (SVD) for reliability testing
|
||||||
- [ ] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
|
- [x] Implement ffmpeg assembler (concat + optional audio + debug burn-in)
|
||||||
- [ ] Implement optional upscaling module (post-process)
|
- [ ] Implement optional upscaling module (post-process)
|
||||||
|
|
||||||
## Utilities
|
## Utilities
|
||||||
- [ ] Write storyboard “plain text → JSON” utility script (fills `storyboard.template.json`)
|
- [ ] Write storyboard "plain text -> JSON" utility script (fills `storyboard.template.json`)
|
||||||
- [ ] Add config file support (YAML/JSON) for global defaults
|
- [x] Add config file support (YAML/JSON) for global defaults
|
||||||
|
|
||||||
## Testing (parallel work; required)
|
## Testing (parallel work; required)
|
||||||
- [ ] Add `pytest` scaffolding
|
- [x] Add `pytest` scaffolding
|
||||||
- [ ] Add tests for schema validation
|
- [x] Add tests for schema validation
|
||||||
- [ ] Add tests for prompt compilation determinism
|
- [x] Add tests for prompt compilation determinism
|
||||||
- [ ] Add tests for shot planning (frames/chunks)
|
- [x] Add tests for shot planning (frames/chunks)
|
||||||
- [ ] Add tests for ffmpeg command generation (no actual render needed)
|
- [x] Add tests for ffmpeg command generation (no actual render needed)
|
||||||
- [ ] Ensure every code change includes a corresponding test update
|
- [x] Ensure every code change includes a corresponding test update
|
||||||
|
|
||||||
## Documentation (maintained continuously)
|
## Documentation (maintained continuously)
|
||||||
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
|
- [ ] Create `docs/developer.md` (install, architecture, tests, adding backends)
|
||||||
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
|
- [ ] Create `docs/user.md` (quickstart, storyboard creation, running, outputs, troubleshooting)
|
||||||
- [ ] Keep docs updated whenever CLI/config/schema changes
|
- [ ] Keep docs updated whenever CLI/config/schema changes
|
||||||
|
|
||||||
|
## Current Status
|
||||||
|
- **Completed:** 12/19 tasks
|
||||||
|
- **In Progress:** FFmpeg assembler implementation
|
||||||
|
- **Next:** CLI entry point, Documentation
|
||||||
|
|
||||||
|
## Recent Updates
|
||||||
|
- Fixed environment.yml for PyTorch 2.5.1 compatibility
|
||||||
|
- Implemented WAN backend with lazy imports
|
||||||
|
- Created FFmpeg assembler module
|
||||||
|
- All core tests passing (29 tests)
|
||||||
|
|||||||
492
src/assembly/assembler.py
Normal file
492
src/assembly/assembler.py
Normal file
@@ -0,0 +1,492 @@
|
|||||||
|
"""
|
||||||
|
FFmpeg video assembler for concatenating shots and adding transitions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class TransitionType(str, Enum):
|
||||||
|
"""Supported transition types."""
|
||||||
|
NONE = "none"
|
||||||
|
FADE = "fade"
|
||||||
|
DISSOLVE = "dissolve"
|
||||||
|
WIPE_LEFT = "wipe_left"
|
||||||
|
WIPE_RIGHT = "wipe_right"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssemblyConfig:
|
||||||
|
"""Configuration for video assembly."""
|
||||||
|
fps: int = 24
|
||||||
|
container: str = "mp4"
|
||||||
|
codec: str = "h264"
|
||||||
|
crf: int = 18
|
||||||
|
preset: str = "medium"
|
||||||
|
transition: TransitionType = TransitionType.NONE
|
||||||
|
transition_duration_ms: int = 500
|
||||||
|
add_shot_labels: bool = False
|
||||||
|
audio_track: Optional[Path] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for metadata."""
|
||||||
|
return {
|
||||||
|
"fps": self.fps,
|
||||||
|
"container": self.container,
|
||||||
|
"codec": self.codec,
|
||||||
|
"crf": self.crf,
|
||||||
|
"preset": self.preset,
|
||||||
|
"transition": self.transition.value,
|
||||||
|
"transition_duration_ms": self.transition_duration_ms,
|
||||||
|
"add_shot_labels": self.add_shot_labels,
|
||||||
|
"audio_track": str(self.audio_track) if self.audio_track else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssemblyResult:
|
||||||
|
"""Result of video assembly."""
|
||||||
|
success: bool
|
||||||
|
output_path: Optional[Path] = None
|
||||||
|
duration_s: Optional[float] = None
|
||||||
|
num_shots: int = 0
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.metadata is None:
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
|
||||||
|
class FFmpegAssembler:
|
||||||
|
"""
|
||||||
|
Assembles video clips using FFmpeg.
|
||||||
|
Supports concatenation, transitions, and audio mixing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ffmpeg_path: str = "ffmpeg"):
|
||||||
|
"""
|
||||||
|
Initialize FFmpeg assembler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ffmpeg_path: Path to ffmpeg executable (default: "ffmpeg" from PATH)
|
||||||
|
"""
|
||||||
|
self.ffmpeg_path = ffmpeg_path
|
||||||
|
self._check_ffmpeg()
|
||||||
|
|
||||||
|
def _check_ffmpeg(self) -> bool:
|
||||||
|
"""Check if ffmpeg is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
[self.ffmpeg_path, "-version"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
return result.returncode == 0
|
||||||
|
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def assemble(
|
||||||
|
self,
|
||||||
|
shot_files: List[Path],
|
||||||
|
output_path: Path,
|
||||||
|
config: Optional[AssemblyConfig] = None
|
||||||
|
) -> AssemblyResult:
|
||||||
|
"""
|
||||||
|
Assemble shots into final video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot_files: List of shot video files to concatenate
|
||||||
|
output_path: Output path for final video
|
||||||
|
config: Assembly configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssemblyResult with output path and metadata
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = AssemblyConfig()
|
||||||
|
|
||||||
|
if not shot_files:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message="No shot files provided"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify all files exist
|
||||||
|
missing = [f for f in shot_files if not f.exists()]
|
||||||
|
if missing:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Missing shot files: {missing}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure output directory exists
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.transition == TransitionType.NONE:
|
||||||
|
# Simple concatenation
|
||||||
|
result = self._concatenate_simple(shot_files, output_path, config)
|
||||||
|
else:
|
||||||
|
# Concatenation with transitions
|
||||||
|
result = self._concatenate_with_transitions(shot_files, output_path, config)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Add audio if specified
|
||||||
|
if config.audio_track and config.audio_track.exists():
|
||||||
|
result = self._add_audio(output_path, config.audio_track, output_path, config)
|
||||||
|
|
||||||
|
# Get duration
|
||||||
|
duration = self._get_video_duration(output_path)
|
||||||
|
|
||||||
|
return AssemblyResult(
|
||||||
|
success=True,
|
||||||
|
output_path=output_path,
|
||||||
|
duration_s=duration,
|
||||||
|
num_shots=len(shot_files),
|
||||||
|
metadata={
|
||||||
|
"config": config.to_dict(),
|
||||||
|
"input_files": [str(f) for f in shot_files]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Assembly failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _concatenate_simple(
|
||||||
|
self,
|
||||||
|
shot_files: List[Path],
|
||||||
|
output_path: Path,
|
||||||
|
config: AssemblyConfig
|
||||||
|
) -> AssemblyResult:
|
||||||
|
"""
|
||||||
|
Simple concatenation using concat demuxer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot_files: List of video files
|
||||||
|
output_path: Output path
|
||||||
|
config: Assembly configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssemblyResult
|
||||||
|
"""
|
||||||
|
# Create concat list file
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
||||||
|
for shot_file in shot_files:
|
||||||
|
# Escape single quotes in path
|
||||||
|
escaped_path = str(shot_file).replace("'", "'\\''")
|
||||||
|
f.write(f"file '{escaped_path}'\n")
|
||||||
|
concat_list = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build ffmpeg command
|
||||||
|
cmd = [
|
||||||
|
self.ffmpeg_path,
|
||||||
|
'-y', # Overwrite output
|
||||||
|
'-f', 'concat',
|
||||||
|
'-safe', '0',
|
||||||
|
'-i', concat_list,
|
||||||
|
'-c:v', self._get_video_codec(config.codec),
|
||||||
|
'-preset', config.preset,
|
||||||
|
'-crf', str(config.crf),
|
||||||
|
'-r', str(config.fps),
|
||||||
|
'-pix_fmt', 'yuv420p', # For compatibility
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Run ffmpeg
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=3600 # 1 hour timeout for long renders
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"FFmpeg error: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssemblyResult(success=True)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up concat list
|
||||||
|
Path(concat_list).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
def _concatenate_with_transitions(
|
||||||
|
self,
|
||||||
|
shot_files: List[Path],
|
||||||
|
output_path: Path,
|
||||||
|
config: AssemblyConfig
|
||||||
|
) -> AssemblyResult:
|
||||||
|
"""
|
||||||
|
Concatenate with transitions using filter complex.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shot_files: List of video files
|
||||||
|
output_path: Output path
|
||||||
|
config: Assembly configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssemblyResult
|
||||||
|
"""
|
||||||
|
if len(shot_files) == 1:
|
||||||
|
# No transitions needed for single file
|
||||||
|
return self._concatenate_simple(shot_files, output_path, config)
|
||||||
|
|
||||||
|
# Build filter complex for transitions
|
||||||
|
transition_duration = config.transition_duration_ms / 1000.0
|
||||||
|
|
||||||
|
# For now, use crossfade as it's most commonly supported
|
||||||
|
# More complex transitions would require xfade filter
|
||||||
|
inputs = []
|
||||||
|
filters = []
|
||||||
|
|
||||||
|
for i, shot_file in enumerate(shot_files):
|
||||||
|
inputs.extend(['-i', str(shot_file)])
|
||||||
|
|
||||||
|
# Build xfade filter chain
|
||||||
|
# Format: [0:v][1:v]xfade=transition=fade:duration=0.5:offset=4[vt1];[vt1][2:v]xfade=...
|
||||||
|
filter_parts = []
|
||||||
|
offset = 0.0
|
||||||
|
|
||||||
|
for i in range(len(shot_files) - 1):
|
||||||
|
if i == 0:
|
||||||
|
input_refs = f"[0:v][1:v]"
|
||||||
|
output_ref = "[vt0]"
|
||||||
|
else:
|
||||||
|
input_refs = f"[vt{i-1}][{i+1}:v]"
|
||||||
|
output_ref = f"[vt{i}]" if i < len(shot_files) - 2 else "[outv]"
|
||||||
|
|
||||||
|
# Get duration of current clip
|
||||||
|
duration = self._get_video_duration(shot_files[i])
|
||||||
|
offset += duration - transition_duration
|
||||||
|
|
||||||
|
filter_parts.append(
|
||||||
|
f"{input_refs}xfade=transition=fade:duration={transition_duration}:offset={offset}{output_ref}"
|
||||||
|
)
|
||||||
|
|
||||||
|
filter_complex = ";".join(filter_parts)
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
self.ffmpeg_path,
|
||||||
|
'-y'
|
||||||
|
] + inputs + [
|
||||||
|
'-filter_complex', filter_complex,
|
||||||
|
'-map', '[outv]',
|
||||||
|
'-c:v', self._get_video_codec(config.codec),
|
||||||
|
'-preset', config.preset,
|
||||||
|
'-crf', str(config.crf),
|
||||||
|
'-r', str(config.fps),
|
||||||
|
'-pix_fmt', 'yuv420p',
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=3600
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"FFmpeg transition error: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssemblyResult(success=True)
|
||||||
|
|
||||||
|
def _add_audio(
|
||||||
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
audio_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
config: AssemblyConfig
|
||||||
|
) -> AssemblyResult:
|
||||||
|
"""
|
||||||
|
Add audio track to video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Input video path
|
||||||
|
audio_path: Audio file path
|
||||||
|
output_path: Output path
|
||||||
|
config: Assembly configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssemblyResult
|
||||||
|
"""
|
||||||
|
cmd = [
|
||||||
|
self.ffmpeg_path,
|
||||||
|
'-y',
|
||||||
|
'-i', str(video_path),
|
||||||
|
'-i', str(audio_path),
|
||||||
|
'-c:v', 'copy', # Copy video without re-encoding
|
||||||
|
'-c:a', 'aac',
|
||||||
|
'-b:a', '192k',
|
||||||
|
'-shortest', # Match shortest input
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=3600
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"FFmpeg audio error: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssemblyResult(success=True)
|
||||||
|
|
||||||
|
def _get_video_codec(self, codec: str) -> str:
|
||||||
|
"""Convert codec name to ffmpeg codec."""
|
||||||
|
codec_map = {
|
||||||
|
"h264": "libx264",
|
||||||
|
"h265": "libx265",
|
||||||
|
"vp9": "libvpx-vp9"
|
||||||
|
}
|
||||||
|
return codec_map.get(codec, "libx264")
|
||||||
|
|
||||||
|
def _get_video_duration(self, video_path: Path) -> float:
|
||||||
|
"""
|
||||||
|
Get video duration in seconds using ffprobe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to video file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Duration in seconds
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-show_entries", "format=duration",
|
||||||
|
"-of", "default=noprint_wrappers=1:nokey=1",
|
||||||
|
str(video_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
return float(result.stdout.strip())
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
def burn_in_labels(
|
||||||
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
labels: List[str],
|
||||||
|
font_size: int = 24,
|
||||||
|
position: str = "top-left"
|
||||||
|
) -> AssemblyResult:
|
||||||
|
"""
|
||||||
|
Burn shot labels into video for debugging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Input video path
|
||||||
|
output_path: Output path
|
||||||
|
labels: List of labels (one per shot)
|
||||||
|
font_size: Font size for labels
|
||||||
|
position: Label position (top-left, top-right, bottom-left, bottom-right)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssemblyResult
|
||||||
|
"""
|
||||||
|
# Calculate position coordinates
|
||||||
|
positions = {
|
||||||
|
"top-left": "x=10:y=10",
|
||||||
|
"top-right": "x=w-text_w-10:y=10",
|
||||||
|
"bottom-left": "x=10:y=h-text_h-10",
|
||||||
|
"bottom-right": "x=w-text_w-10:y=h-text_h-10"
|
||||||
|
}
|
||||||
|
|
||||||
|
pos = positions.get(position, positions["top-left"])
|
||||||
|
|
||||||
|
# Build drawtext filter for each shot
|
||||||
|
# This requires segmenting the video and applying different text to each segment
|
||||||
|
# For simplicity, we'll just add a timestamp-based label
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
self.ffmpeg_path,
|
||||||
|
'-y',
|
||||||
|
'-i', str(video_path),
|
||||||
|
'-vf', f"drawtext=text='Shot':{pos}:fontsize={font_size}:fontcolor=white:box=1:boxcolor=black@0.5",
|
||||||
|
'-c:a', 'copy',
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=3600
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Label burn-in error: {result.stderr}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AssemblyResult(success=True)
|
||||||
|
|
||||||
|
def extract_frame(
|
||||||
|
self,
|
||||||
|
video_path: Path,
|
||||||
|
timestamp_s: float,
|
||||||
|
output_path: Path
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Extract a single frame from video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Input video path
|
||||||
|
timestamp_s: Timestamp in seconds
|
||||||
|
output_path: Output path for frame image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
cmd = [
|
||||||
|
self.ffmpeg_path,
|
||||||
|
'-y',
|
||||||
|
'-ss', str(timestamp_s),
|
||||||
|
'-i', str(video_path),
|
||||||
|
'-vframes', '1',
|
||||||
|
'-q:v', '2',
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=60
|
||||||
|
)
|
||||||
|
|
||||||
|
return result.returncode == 0
|
||||||
@@ -8,10 +8,12 @@ from .base import (
|
|||||||
GenerationSpec,
|
GenerationSpec,
|
||||||
BackendFactory
|
BackendFactory
|
||||||
)
|
)
|
||||||
|
from .backends import WanBackend
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseVideoBackend',
|
'BaseVideoBackend',
|
||||||
'GenerationResult',
|
'GenerationResult',
|
||||||
'GenerationSpec',
|
'GenerationSpec',
|
||||||
'BackendFactory'
|
'BackendFactory',
|
||||||
|
'WanBackend'
|
||||||
]
|
]
|
||||||
|
|||||||
7
src/generation/backends/__init__.py
Normal file
7
src/generation/backends/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Generation backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .wan import WanBackend
|
||||||
|
|
||||||
|
__all__ = ['WanBackend']
|
||||||
461
src/generation/backends/wan.py
Normal file
461
src/generation/backends/wan.py
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
"""
|
||||||
|
WAN 2.x video generation backend.
|
||||||
|
Implements text-to-video generation using Wan-AI models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Dict, Any, List, TYPE_CHECKING
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from ..base import BaseVideoBackend, GenerationResult, GenerationSpec, BackendFactory
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoencoderKLWan, WanPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class WanBackend(BaseVideoBackend):
|
||||||
|
"""
|
||||||
|
WAN 2.x text-to-video generation backend.
|
||||||
|
Supports both 14B and 1.3B models with VRAM-safe defaults.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Initialize WAN backend.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Backend configuration with keys:
|
||||||
|
- model_id: HuggingFace model ID
|
||||||
|
- vram_gb: Expected VRAM usage
|
||||||
|
- dtype: "fp16" or "bf16"
|
||||||
|
- enable_vae_slicing: bool
|
||||||
|
- enable_vae_tiling: bool
|
||||||
|
- model_cache_dir: Cache directory for models
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.model_id = config.get("model_id", "Wan-AI/Wan2.1-T2V-14B")
|
||||||
|
self.vram_gb = config.get("vram_gb", 12)
|
||||||
|
self.dtype_str = config.get("dtype", "fp16")
|
||||||
|
self.enable_vae_slicing = config.get("enable_vae_slicing", True)
|
||||||
|
self.enable_vae_tiling = config.get("enable_vae_tiling", True)
|
||||||
|
self.model_cache_dir = config.get("model_cache_dir", "~/.cache/storyboard-video/models")
|
||||||
|
|
||||||
|
# Components (initialized in load())
|
||||||
|
self.vae = None
|
||||||
|
self.pipeline = None
|
||||||
|
self.dtype = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return backend name."""
|
||||||
|
return f"wan_{self.model_id.split('/')[-1].lower()}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_chunking(self) -> bool:
|
||||||
|
"""WAN supports chunking via sequential generation."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_init_frame(self) -> bool:
|
||||||
|
"""WAN T2V doesn't natively support init frame (I2V variant does)."""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_dtype(self):
|
||||||
|
"""Get torch dtype based on configuration."""
|
||||||
|
import torch
|
||||||
|
if self.dtype_str == "bf16":
|
||||||
|
return torch.bfloat16
|
||||||
|
else:
|
||||||
|
return torch.float16
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
"""Load WAN model and components."""
|
||||||
|
if self._is_loaded:
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import AutoencoderKLWan, WanPipeline
|
||||||
|
|
||||||
|
print(f"Loading WAN model: {self.model_id}")
|
||||||
|
|
||||||
|
self.dtype = self._get_dtype()
|
||||||
|
print(f"Using dtype: {self.dtype}")
|
||||||
|
|
||||||
|
# Set cache directory
|
||||||
|
cache_dir = Path(self.model_cache_dir).expanduser()
|
||||||
|
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Load VAE
|
||||||
|
print("Loading VAE...")
|
||||||
|
vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id
|
||||||
|
|
||||||
|
self.vae = AutoencoderKLWan.from_pretrained(
|
||||||
|
vae_id,
|
||||||
|
subfolder="vae" if "-T2V-" in self.model_id else None,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
cache_dir=str(cache_dir)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load pipeline
|
||||||
|
print("Loading pipeline...")
|
||||||
|
self.pipeline = WanPipeline.from_pretrained(
|
||||||
|
self.model_id,
|
||||||
|
vae=self.vae,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
cache_dir=str(cache_dir)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to GPU
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print("Moving to CUDA...")
|
||||||
|
self.pipeline = self.pipeline.to("cuda")
|
||||||
|
|
||||||
|
# Enable memory optimizations
|
||||||
|
if self.enable_vae_slicing:
|
||||||
|
print("Enabling VAE slicing...")
|
||||||
|
self.pipeline.vae.enable_slicing()
|
||||||
|
|
||||||
|
if self.enable_vae_tiling:
|
||||||
|
print("Enabling VAE tiling...")
|
||||||
|
self.pipeline.vae.enable_tiling()
|
||||||
|
else:
|
||||||
|
print("WARNING: CUDA not available, using CPU (will be very slow)")
|
||||||
|
|
||||||
|
self._is_loaded = True
|
||||||
|
print("WAN model loaded successfully")
|
||||||
|
|
||||||
|
def unload(self) -> None:
|
||||||
|
"""Unload model from memory."""
|
||||||
|
if not self._is_loaded:
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
print("Unloading WAN model...")
|
||||||
|
|
||||||
|
# Clear components
|
||||||
|
self.pipeline = None
|
||||||
|
self.vae = None
|
||||||
|
|
||||||
|
# Force garbage collection
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# Clear CUDA cache
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
self._is_loaded = False
|
||||||
|
print("WAN model unloaded")
|
||||||
|
|
||||||
|
def generate(self, spec: GenerationSpec) -> GenerationResult:
|
||||||
|
"""
|
||||||
|
Generate a video clip.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: Generation specification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GenerationResult with output path and metadata
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from diffusers.utils import export_to_video
|
||||||
|
|
||||||
|
if not self._is_loaded:
|
||||||
|
return GenerationResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Model not loaded. Call load() first."
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure output directory exists
|
||||||
|
spec.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Prepare generation parameters
|
||||||
|
generator = None
|
||||||
|
if spec.seed >= 0:
|
||||||
|
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
generator.manual_seed(spec.seed)
|
||||||
|
|
||||||
|
print(f"Generating video: {spec.width}x{spec.height}, {spec.num_frames} frames")
|
||||||
|
print(f"Prompt: {spec.prompt[:100]}...")
|
||||||
|
|
||||||
|
# Generate video
|
||||||
|
result = self.pipeline(
|
||||||
|
prompt=spec.prompt,
|
||||||
|
negative_prompt=spec.negative_prompt,
|
||||||
|
width=spec.width,
|
||||||
|
height=spec.height,
|
||||||
|
num_frames=spec.num_frames,
|
||||||
|
num_inference_steps=spec.steps,
|
||||||
|
guidance_scale=spec.cfg_scale,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Export to video file
|
||||||
|
frames = result.frames[0]
|
||||||
|
export_to_video(frames, str(spec.output_path), fps=spec.fps)
|
||||||
|
|
||||||
|
# Calculate VRAM usage
|
||||||
|
vram_usage = None
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
vram_usage = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Build metadata
|
||||||
|
metadata = {
|
||||||
|
"model_id": self.model_id,
|
||||||
|
"prompt": spec.prompt,
|
||||||
|
"negative_prompt": spec.negative_prompt,
|
||||||
|
"width": spec.width,
|
||||||
|
"height": spec.height,
|
||||||
|
"num_frames": spec.num_frames,
|
||||||
|
"fps": spec.fps,
|
||||||
|
"seed": spec.seed,
|
||||||
|
"steps": spec.steps,
|
||||||
|
"cfg_scale": spec.cfg_scale,
|
||||||
|
"dtype": self.dtype_str,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Generation complete: {spec.output_path}")
|
||||||
|
print(f"Time: {generation_time:.1f}s, VRAM: {vram_usage:.1f}GB" if vram_usage else f"Time: {generation_time:.1f}s")
|
||||||
|
|
||||||
|
return GenerationResult(
|
||||||
|
success=True,
|
||||||
|
output_path=spec.output_path,
|
||||||
|
metadata=metadata,
|
||||||
|
vram_usage_gb=vram_usage,
|
||||||
|
generation_time_s=generation_time
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
error_msg = f"Generation failed: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
|
||||||
|
return GenerationResult(
|
||||||
|
success=False,
|
||||||
|
error_message=error_msg,
|
||||||
|
generation_time_s=generation_time
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_chunked(
|
||||||
|
self,
|
||||||
|
spec: GenerationSpec,
|
||||||
|
chunk_duration_s: int,
|
||||||
|
overlap_s: int = 0,
|
||||||
|
mode: str = "sequential"
|
||||||
|
) -> GenerationResult:
|
||||||
|
"""
|
||||||
|
Generate video in chunks.
|
||||||
|
|
||||||
|
For WAN, we generate sequentially and concatenate.
|
||||||
|
Note: This is a simplified implementation. True temporal consistency
|
||||||
|
would require more sophisticated frame interpolation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec: Generation specification
|
||||||
|
chunk_duration_s: Duration per chunk in seconds
|
||||||
|
overlap_s: Overlap between chunks (not used in sequential mode)
|
||||||
|
mode: "sequential" or "overlapping"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GenerationResult with concatenated video
|
||||||
|
"""
|
||||||
|
if not self._is_loaded:
|
||||||
|
return GenerationResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Model not loaded. Call load() first."
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Calculate chunks
|
||||||
|
total_duration = spec.num_frames / spec.fps
|
||||||
|
chunk_frames = chunk_duration_s * spec.fps
|
||||||
|
|
||||||
|
if total_duration <= chunk_duration_s:
|
||||||
|
# Single chunk - no need to split
|
||||||
|
return self.generate(spec)
|
||||||
|
|
||||||
|
print(f"Generating in chunks: {total_duration:.1f}s total, {chunk_duration_s}s per chunk")
|
||||||
|
|
||||||
|
# Generate each chunk
|
||||||
|
chunk_files = []
|
||||||
|
num_chunks = int((total_duration + chunk_duration_s - 1) // chunk_duration_s)
|
||||||
|
|
||||||
|
for i in range(num_chunks):
|
||||||
|
chunk_start = i * chunk_duration_s
|
||||||
|
chunk_end = min(chunk_start + chunk_duration_s, total_duration)
|
||||||
|
chunk_duration = chunk_end - chunk_start
|
||||||
|
chunk_num_frames = int(chunk_duration * spec.fps)
|
||||||
|
|
||||||
|
# Create chunk output path
|
||||||
|
chunk_path = spec.output_path.parent / f"{spec.output_path.stem}_chunk{i:03d}{spec.output_path.suffix}"
|
||||||
|
|
||||||
|
print(f"Chunk {i+1}/{num_chunks}: {chunk_duration:.1f}s ({chunk_num_frames} frames)")
|
||||||
|
|
||||||
|
# Create chunk spec
|
||||||
|
chunk_spec = GenerationSpec(
|
||||||
|
prompt=spec.prompt,
|
||||||
|
negative_prompt=spec.negative_prompt,
|
||||||
|
width=spec.width,
|
||||||
|
height=spec.height,
|
||||||
|
num_frames=chunk_num_frames,
|
||||||
|
fps=spec.fps,
|
||||||
|
seed=spec.seed + i if spec.seed >= 0 else spec.seed, # Vary seed per chunk
|
||||||
|
steps=spec.steps,
|
||||||
|
cfg_scale=spec.cfg_scale,
|
||||||
|
output_path=chunk_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate chunk
|
||||||
|
result = self.generate(chunk_spec)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
# Clean up partial chunks
|
||||||
|
for f in chunk_files:
|
||||||
|
if f.exists():
|
||||||
|
f.unlink()
|
||||||
|
return result
|
||||||
|
|
||||||
|
chunk_files.append(chunk_path)
|
||||||
|
|
||||||
|
# Concatenate chunks using ffmpeg
|
||||||
|
print("Concatenating chunks...")
|
||||||
|
concat_result = self._concatenate_chunks(chunk_files, spec.output_path, spec.fps)
|
||||||
|
|
||||||
|
if not concat_result:
|
||||||
|
return GenerationResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Failed to concatenate chunks"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up chunk files
|
||||||
|
for chunk_file in chunk_files:
|
||||||
|
if chunk_file.exists():
|
||||||
|
chunk_file.unlink()
|
||||||
|
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"model_id": self.model_id,
|
||||||
|
"prompt": spec.prompt,
|
||||||
|
"num_chunks": num_chunks,
|
||||||
|
"chunk_duration_s": chunk_duration_s,
|
||||||
|
"total_duration_s": total_duration,
|
||||||
|
"mode": mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Chunked generation complete: {spec.output_path}")
|
||||||
|
|
||||||
|
return GenerationResult(
|
||||||
|
success=True,
|
||||||
|
output_path=spec.output_path,
|
||||||
|
metadata=metadata,
|
||||||
|
generation_time_s=generation_time
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
generation_time = time.time() - start_time
|
||||||
|
error_msg = f"Chunked generation failed: {str(e)}"
|
||||||
|
print(error_msg)
|
||||||
|
|
||||||
|
return GenerationResult(
|
||||||
|
success=False,
|
||||||
|
error_message=error_msg,
|
||||||
|
generation_time_s=generation_time
|
||||||
|
)
|
||||||
|
|
||||||
|
def _concatenate_chunks(
|
||||||
|
self,
|
||||||
|
chunk_files: List[Path],
|
||||||
|
output_path: Path,
|
||||||
|
fps: int
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Concatenate video chunks using ffmpeg.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_files: List of chunk video files
|
||||||
|
output_path: Final output path
|
||||||
|
fps: Frames per second
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
# Create concat list file
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
|
||||||
|
for chunk_file in chunk_files:
|
||||||
|
f.write(f"file '{chunk_file}'\n")
|
||||||
|
concat_list = f.name
|
||||||
|
|
||||||
|
# Run ffmpeg
|
||||||
|
cmd = [
|
||||||
|
'ffmpeg',
|
||||||
|
'-y', # Overwrite output
|
||||||
|
'-f', 'concat',
|
||||||
|
'-safe', '0',
|
||||||
|
'-i', concat_list,
|
||||||
|
'-c', 'copy',
|
||||||
|
'-r', str(fps),
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
# Clean up concat list
|
||||||
|
os.unlink(concat_list)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
print(f"FFmpeg error: {result.stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Concatenation error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def estimate_vram_usage(self, width: int, height: int, num_frames: int) -> float:
    """
    Estimate VRAM usage for a generation at the given resolution and length.

    The estimate scales linearly with pixel count and linearly with frame
    count, relative to a 1280x720 / 81-frame baseline, then adds a 20%
    safety margin.

    Args:
        width: Video width in pixels.
        height: Video height in pixels.
        num_frames: Number of frames to generate.

    Returns:
        Estimated VRAM in GB.
    """
    # The smaller 1.3B checkpoint needs less baseline memory than the 14B one.
    baseline_gb = 8.0 if "1.3B" in self.model_id else 12.0

    # Linear scaling with total pixel count against the 720p baseline.
    pixel_ratio = (width * height) / (1280 * 720)

    # Linear scaling with frame count; 81 frames is a common model default.
    frame_ratio = num_frames / 81

    # Raw estimate plus a 20% safety margin.
    return baseline_gb * pixel_ratio * frame_ratio * 1.2
|
||||||
|
# Register this backend with the factory under the "wan" key so it can be
# instantiated by name from configuration. Runs as a module-import side effect.
BackendFactory.register("wan", WanBackend)
|
||||||
195
tests/unit/test_wan_backend.py
Normal file
195
tests/unit/test_wan_backend.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""
|
||||||
|
Tests for WAN backend.
|
||||||
|
Note: These tests mock the actual model loading/generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch, MagicMock, mock_open
|
||||||
|
|
||||||
|
from src.generation.backends.wan import WanBackend
|
||||||
|
from src.generation.base import GenerationSpec, GenerationResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestWanBackend:
    """Test WAN backend functionality (model loading/generation is mocked)."""

    @pytest.fixture
    def config_14b(self):
        """Configuration for the 14B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-14B",
            "vram_gb": 12,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": True,
            "model_cache_dir": "~/.cache/test"
        }

    @pytest.fixture
    def config_1_3b(self):
        """Configuration for the 1.3B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-1.3B",
            "vram_gb": 8,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": False,
            "model_cache_dir": "~/.cache/test"
        }

    def test_backend_properties_14b(self, config_14b):
        """Test backend properties for 14B model."""
        backend = WanBackend(config_14b)

        assert backend.name == "wan_wan2.1-t2v-14b"
        assert backend.supports_chunking is True
        assert backend.supports_init_frame is False
        assert backend.vram_gb == 12

    def test_backend_properties_1_3b(self, config_1_3b):
        """Test backend properties for 1.3B model."""
        backend = WanBackend(config_1_3b)

        assert backend.name == "wan_wan2.1-t2v-1.3b"
        assert backend.vram_gb == 8

    def test_dtype_str_storage(self, config_14b):
        """Test dtype string is stored correctly."""
        backend = WanBackend(config_14b)
        assert backend.dtype_str == "fp16"

        config_14b["dtype"] = "bf16"
        backend2 = WanBackend(config_14b)
        assert backend2.dtype_str == "bf16"

    def test_estimate_vram_14b(self, config_14b):
        """Test VRAM estimation for 14B model."""
        backend = WanBackend(config_14b)

        # 720p, 81 frames (the baseline): base 12GB plus the safety margin.
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram > 12.0  # Base is 12GB with margin

        # 1080p should be higher
        vram_1080 = backend.estimate_vram_usage(1920, 1080, 81)
        assert vram_1080 > vram

    def test_estimate_vram_1_3b(self, config_1_3b):
        """Test VRAM estimation for 1.3B model."""
        backend = WanBackend(config_1_3b)

        # Should be less than 14B
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram < 12.0

    def test_estimate_vram_scaling(self, config_14b):
        """Test VRAM scales with resolution and frames."""
        backend = WanBackend(config_14b)

        # Baseline: 720p, 81 frames.
        vram_720_81 = backend.estimate_vram_usage(1280, 720, 81)

        # Double each dimension (2560x1440) -> 4x pixels -> ~4x VRAM.
        # NOTE: the previous version passed (1920, 1080), which is only
        # 2.25x the pixels of 1280x720, so the "> 3x" assertion failed
        # against the backend's linear pixel-count formula.
        vram_1440_81 = backend.estimate_vram_usage(2560, 1440, 81)
        assert vram_1440_81 > vram_720_81 * 3  # Should be ~4x

        # Double frames -> ~2x VRAM.
        vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
        assert vram_720_162 > vram_720_81 * 1.5  # Should be ~2x

    def test_generate_not_loaded(self, config_14b, tmp_path):
        """Test generation fails if model not loaded."""
        backend = WanBackend(config_14b)

        spec = GenerationSpec(
            prompt="test",
            negative_prompt="",
            width=512,
            height=512,
            num_frames=16,
            fps=8,
            seed=42,
            steps=10,
            cfg_scale=6.0,
            output_path=tmp_path / "test.mp4"
        )

        result = backend.generate(spec)

        assert result.success is False
        assert "not loaded" in result.error_message.lower()

    def test_concatenate_chunks_success(self, config_14b, tmp_path):
        """Test chunk concatenation."""
        backend = WanBackend(config_14b)

        # Create dummy chunk files
        chunk_files = [
            tmp_path / "chunk0.mp4",
            tmp_path / "chunk1.mp4",
            tmp_path / "chunk2.mp4"
        ]

        for f in chunk_files:
            f.write_text("dummy video data")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess so no real ffmpeg is required.
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=0)

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is True
            mock_run.assert_called_once()

            # Verify ffmpeg command structure (concat demuxer invocation).
            call_args = mock_run.call_args[0][0]
            assert 'ffmpeg' in call_args
            assert '-f' in call_args
            assert 'concat' in call_args

    def test_concatenate_chunks_failure(self, config_14b, tmp_path):
        """Test chunk concatenation failure."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess failure (non-zero ffmpeg exit code).
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stderr="ffmpeg error")

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_concatenate_chunks_exception(self, config_14b, tmp_path):
        """Test chunk concatenation with exception."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess raising; backend must swallow and return False.
        with patch('subprocess.run', side_effect=Exception("subprocess failed")):
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_is_loaded_property(self, config_14b):
        """Test is_loaded property."""
        backend = WanBackend(config_14b)

        # Initially not loaded
        assert backend.is_loaded is False

        # Manually set (normally done by load())
        backend._is_loaded = True
        assert backend.is_loaded is True
|
||||||
Reference in New Issue
Block a user