Adjust the resolution based on available VRAM; add elapsed time.

commit c33c1d2f36 (parent 1252518832)
2026-02-04 03:06:31 -05:00
7 changed files with 347 additions and 174 deletions


@@ -11,13 +11,13 @@ The owner (user) is not an ML expert. The system must:
--- ---
## 1) High-Level Goal ## 1) High-Level Goal
Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by: Build a local pipeline that converts text-only storyboards into 15-30 second videos by:
1) converting storyboard -> shot plan 1) converting storyboard -> shot plan
2) generating shot clips (T2V or I2V when possible) 2) generating shot clips (T2V or I2V when possible)
3) assembling clips into a final MP4 3) assembling clips into a final MP4
4) upscaling to 2K/4K if desired 4) upscaling to 2K/4K if desired
This is a **shot-based** system, not “one prompt makes a whole movie”. This is a shot-based system, not "one prompt makes a whole movie".
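The four steps above can be read as a thin function pipeline. A minimal sketch, with hypothetical names and stub bodies (this is not the project's actual API, just the shape of the flow):

```python
from pathlib import Path

# Hypothetical stage signatures matching steps 1-4 above.
def storyboard_to_shot_plan(storyboard: dict) -> list[dict]:
    """Step 1: storyboard -> per-shot plan (id + prompt)."""
    return [{"id": s["id"], "prompt": s["description"]} for s in storyboard["shots"]]

def generate_clip(shot: dict) -> Path:
    """Step 2: one clip per shot (T2V, or I2V when an init frame exists)."""
    return Path(f"shots/{shot['id']}.mp4")

def assemble(clips: list[Path]) -> Path:
    """Step 3: concatenate clips into the final MP4 (ffmpeg in practice)."""
    return Path("assembled/final.mp4")
```

Step 4 (upscaling) would be one more optional stage taking and returning a `Path`.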
--- ---
@@ -38,9 +38,9 @@ Design must be stable under 12GB VRAM using:
--- ---
## 3) Output Targets (Realistic) ## 3) Output Targets (Realistic)
- Native generation: 720p–1080p (preferred) - Native generation: 720p-1080p (preferred)
- Final delivery: 1080p required; 2K/4K via upscaling - Final delivery: 1080p required; 2K/4K via upscaling
- Duration: 15–30s per video (may be segmented) - Duration: 15-30s per video (may be segmented)
- FPS: 24 default - FPS: 24 default
- Output: MP4 (H.264/H.265) - Output: MP4 (H.264/H.265)
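The duration and FPS targets above translate directly into per-video frame counts; a quick sanity check (plain arithmetic, hypothetical helper name):

```python
def frames_for(duration_s: float, fps: int = 24) -> int:
    """Frame count implied by a clip duration at the default 24 fps."""
    return round(duration_s * fps)

# A 15-30 s video at 24 fps spans 360-720 frames total, which is why
# generation is chunked into shots rather than done in one pass.
print(frames_for(15), frames_for(30))  # 360 720
```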
@@ -50,13 +50,15 @@ Design must be stable under 12GB VRAM using:
User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes. User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit. We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit.
**Plan:** Plan:
- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124). - Use PyTorch prebuilt binaries that bundle CUDA runtime (cu121/cu124/cu128).
- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version. - Rely on NVIDIA driver compatibility rather than local CUDA toolkit version.
- Avoid compiling custom CUDA extensions unless necessary. - Avoid compiling custom CUDA extensions unless necessary.
Implementation notes: Implementation notes:
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds. - For RTX 5070 (sm_120), use CUDA 12.8 wheels via pip:
`pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`
- Prefer conda for Python, ffmpeg, and general deps; use pip for torch if sm_120 support is required.
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers. - If xFormers causes build issues, use PyTorch SDPA and disable xFormers.
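The wheel-selection rule in the notes above can be written down as plain logic. A hypothetical helper, not an official PyTorch API; the sm_120 → cu128 mapping is the one stated in this document, the fallback entry is illustrative:

```python
# Hypothetical: choose a PyTorch pip index URL from the GPU's compute
# capability, per the notes above. Only the sm_120 rule comes from the doc.
def wheel_index_for_sm(major: int, minor: int) -> str:
    base = "https://download.pytorch.org/whl/"
    if (major, minor) >= (12, 0):   # Blackwell (sm_120), e.g. RTX 5070
        return base + "cu128"
    return base + "cu124"           # older architectures run on cu12.x builds

print(wheel_index_for_sm(12, 0))
```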
--- ---
@@ -64,12 +66,13 @@ Implementation notes:
## 5) Approved Stack (Do Not Deviate) ## 5) Approved Stack (Do Not Deviate)
### Core ### Core
- Python 3.10 or 3.11 (conda env) - Python 3.10 or 3.11 (conda env)
- PyTorch (CUDA 12.x build: cu121 or cu124) - PyTorch (CUDA 12.x build, cu121/cu124/cu128)
- diffusers + transformers + accelerate + safetensors - diffusers + transformers + accelerate + safetensors
- ffmpeg for assembly - ffmpeg for assembly
- opencv-python for frame IO (if needed) - opencv-python for frame IO (if needed)
- pydantic for config/schema validation - pydantic for config/schema validation
- rich / loguru for logs - rich / loguru for logs
- ftfy for text normalization (required by WAN)
### Testing ### Testing
- pytest - pytest
@@ -124,7 +127,7 @@ We will later build a utility script:
- For each shot: generate clip - For each shot: generate clip
- Support: - Support:
- seed control - seed control
- chunking (e.g., generate 4–6 seconds then continue) - chunking (e.g., generate 4-6 seconds then continue)
- optional init frame handoff between shots - optional init frame handoff between shots
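The chunking bullet can be sketched as a greedy splitter that respects the `(num_frames - 1) % 4 == 0` constraint this commit introduces elsewhere. A hypothetical sketch only; a real planner would rebalance to avoid tiny trailing chunks:

```python
def plan_chunks(total_frames: int, max_chunk_frames: int = 97) -> list[int]:
    """Split a shot into chunk frame counts, each with (n - 1) % 4 == 0.

    97 frames is ~4 s at 24 fps, matching the cap used in this commit.
    Consecutive chunks would share an init frame for continuity
    (the handoff bullet above). Naive greedy sketch, not production code.
    """
    chunks = []
    remaining = total_frames
    while remaining > 0:
        n = min(remaining, max_chunk_frames)
        n -= (n - 1) % 4          # round DOWN to the nearest valid count
        chunks.append(n)
        remaining -= n
    return chunks
```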
### D) Assembly ### D) Assembly
@@ -149,7 +152,7 @@ For each shot and final render, save:
- timing + VRAM notes if possible - timing + VRAM notes if possible
Every run produces a folder: Every run produces a folder:
- outputs/<project>/<timestamp>/ - outputs/<project>/
- shots/ - shots/
- assembled/ - assembled/
- metadata/ - metadata/
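The run layout above is a one-liner with `pathlib`; a sketch with a hypothetical helper name (the directory names are the ones listed in the doc):

```python
from pathlib import Path

def make_run_dirs(root: Path, project: str) -> dict[str, Path]:
    """Create the per-run layout described above: shots/, assembled/, metadata/."""
    base = root / "outputs" / project
    dirs = {name: base / name for name in ("shots", "assembled", "metadata")}
    for d in dirs.values():
        d.mkdir(parents=True, exist_ok=True)
    return dirs
```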
@@ -166,7 +169,7 @@ Every run produces a folder:
- assembly command lines are correct - assembly command lines are correct
- metadata is generated correctly - metadata is generated correctly
Do not require visual quality assertions. Test structure and determinism. Do not require visual quality assertions. Test structure and determinism.
--- ---
@@ -201,7 +204,7 @@ Required:
--- ---
## 13) Definition of Done ## 13) Definition of Done
A feature is “done” only if: A feature is "done" only if:
- implemented - implemented
- tests added/updated - tests added/updated
- docs updated - docs updated


@@ -6,7 +6,7 @@ backends:
wan_t2v_14b: wan_t2v_14b:
name: "Wan 2.1 T2V 14B" name: "Wan 2.1 T2V 14B"
class: "generation.backends.wan.WanBackend" class: "generation.backends.wan.WanBackend"
model_id: "Wan-AI/Wan2.1-T2V-14B" model_id: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vram_gb: 12 vram_gb: 12
dtype: "fp16" dtype: "fp16"
enable_vae_slicing: true enable_vae_slicing: true
@@ -20,7 +20,7 @@ backends:
wan_t2v_1_3b: wan_t2v_1_3b:
name: "Wan 2.1 T2V 1.3B" name: "Wan 2.1 T2V 1.3B"
class: "generation.backends.wan.WanBackend" class: "generation.backends.wan.WanBackend"
model_id: "Wan-AI/Wan2.1-T2V-1.3B" model_id: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vram_gb: 8 vram_gb: 8
dtype: "fp16" dtype: "fp16"
enable_vae_slicing: true enable_vae_slicing: true


@@ -1,8 +1,6 @@
# environment.yml # environment.yml
name: storyboard-video name: storyboard-video
channels: channels:
- pytorch
- nvidia
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
@@ -12,13 +10,13 @@ dependencies:
# FFmpeg in env is easiest on Windows if you rely on conda-forge # FFmpeg in env is easiest on Windows if you rely on conda-forge
- ffmpeg=7.1 - ffmpeg=7.1
# PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
# Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
- pytorch=2.5.1
- pytorch-cuda=12.4
- torchvision=0.20.1
- torchaudio=2.5.1
# pip deps (mirrors requirements.txt) # pip deps (mirrors requirements.txt)
- pip: - pip:
# RTX 50xx (sm_120) requires CUDA 12.8 wheels
- --index-url https://download.pytorch.org/whl/cu128
- torch
- torchvision
- torchaudio
- -r requirements.txt - -r requirements.txt
- diffusers==0.36.0
- ftfy==6.3.1


@@ -1,10 +1,11 @@
# requirements.txt # requirements.txt
# Install PyTorch separately (see environment.yml and docs). # Install PyTorch separately (see environment.yml and docs).
diffusers==0.32.2 diffusers==0.36.0
transformers==4.48.3 transformers==4.48.3
accelerate==1.3.0 accelerate==1.3.0
safetensors==0.5.2 safetensors==0.5.2
ftfy==6.3.1
pydantic==2.10.6 pydantic==2.10.6
pyyaml==6.0.2 pyyaml==6.0.2


@@ -2,34 +2,77 @@
CLI entry point for storyboard video generation. CLI entry point for storyboard video generation.
""" """
import sys import os
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Dict
import typer import typer
from rich.console import Console from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table from rich.table import Table
import time
from src.storyboard.loader import StoryboardValidator from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager from src.core.checkpoint import CheckpointManager, ShotStatus, ProjectStatus
from src.generation.base import BackendFactory from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig from src.upscaling.upscaler import UpscaleManager, UpscaleConfig, UpscalerType, UpscaleFactor
if "KMP_DUPLICATE_LIB_OK" not in os.environ:
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
if "PYTORCH_ALLOC_CONF" not in os.environ:
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
app = typer.Typer(help="Storyboard to Video Generation Pipeline") app = typer.Typer(help="Storyboard to Video Generation Pipeline")
console = Console() console = Console()
def align_resolution(width: int, height: int, multiple: int = 16) -> tuple[int, int]:
"""Align resolution to model-required multiples."""
aligned_width = ((width + multiple - 1) // multiple) * multiple
aligned_height = ((height + multiple - 1) // multiple) * multiple
return aligned_width, aligned_height
def align_num_frames(num_frames: int) -> int:
"""Align frame count so (num_frames - 1) is divisible by 4."""
if (num_frames - 1) % 4 == 0:
return num_frames
return ((num_frames - 1 + 3) // 4) * 4 + 1
def cap_resolution_for_backend(backend_id: str, width: int, height: int) -> tuple[int, int]:
"""Apply conservative resolution caps for smaller backends."""
if backend_id == "wan_t2v_1_3b":
max_width = 1280
max_height = 720
if width > max_width or height > max_height:
return max_width, max_height
return width, height
def cap_num_frames_for_backend(backend_id: str, num_frames: int) -> int:
"""Apply conservative frame caps for smaller backends."""
if backend_id == "wan_t2v_1_3b":
max_frames = 97 # (97 - 1) divisible by 4, ~4s at 24fps
if num_frames > max_frames:
return max_frames
return num_frames
def format_duration(seconds: float) -> str:
"""Format seconds into XmYYs."""
total = int(seconds)
minutes = total // 60
secs = total % 60
return f"{minutes}m{secs:02d}s"
@app.command() @app.command()
def generate( def generate(
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"), storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"), output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"), config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"), backend: Optional[str] = typer.Option(None, "--backend", "-b", help="Generation backend (wan_t2v_14b, wan_t2v_1_3b, svd)"),
resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"), resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"), skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"), skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
@@ -38,9 +81,11 @@ def generate(
): ):
"""Generate video from storyboard.""" """Generate video from storyboard."""
start_time = time.time()
# Validate storyboard exists # Validate storyboard exists
if not storyboard.exists(): if not storyboard.exists():
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]") console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1) raise typer.Exit(1)
# Load configuration # Load configuration
@@ -51,20 +96,31 @@ def generate(
app_config = ConfigLoader.load() app_config = ConfigLoader.load()
except Exception as e: except Exception as e:
console.print(f"[red]Error loading config: {e}[/red]") console.print(f"[red]Error loading config: {e}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1) raise typer.Exit(1)
# Normalize backend selection
backend_aliases: Dict[str, str] = {
"wan": "wan_t2v_14b",
"wan-1.3b": "wan_t2v_1_3b",
"wan_1_3b": "wan_t2v_1_3b",
}
backend_id = backend_aliases.get(backend, backend) if backend else app_config.active_backend
# Validate storyboard # Validate storyboard
console.print("[bold blue]Validating storyboard...[/bold blue]") console.print("[bold blue]Validating storyboard...[/bold blue]")
validator = StoryboardValidator() validator = StoryboardValidator()
try: try:
storyboard_data = validator.load(storyboard) storyboard_data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots") console.print(f"[green]OK[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
except Exception as e: except Exception as e:
console.print(f"[red]Error validating storyboard: {e}[/red]") console.print(f"[red]Error validating storyboard: {e}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1) raise typer.Exit(1)
if dry_run: if dry_run:
console.print("[yellow]Dry run complete. Exiting.[/yellow]") console.print("[yellow]Dry run complete. Exiting.[/yellow]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
return return
# Setup output directories # Setup output directories
@@ -79,31 +135,46 @@ def generate(
# Initialize checkpoint manager # Initialize checkpoint manager
checkpoint_db = project_dir / "checkpoints.db" checkpoint_db = project_dir / "checkpoints.db"
checkpoint_mgr = CheckpointManager(str(checkpoint_db)) checkpoint_mgr = CheckpointManager(str(checkpoint_db))
project_id = storyboard_data.project.title
if checkpoint_mgr.get_project(project_id) is None:
checkpoint_mgr.create_project(
project_id=project_id,
storyboard_path=str(storyboard),
output_dir=str(project_dir),
backend_name=backend_id
)
# Initialize components # Initialize components
prompt_compiler = PromptCompiler(storyboard_data) prompt_compiler = PromptCompiler(storyboard_data)
# Get backend config for shot planner # Get backend config for shot planner
try: try:
backend_config = app_config.get_backend(backend) backend_config = app_config.get_backend(backend_id)
max_chunk_duration = backend_config.max_chunk_seconds or 6.0 max_chunk_seconds = backend_config.max_chunk_seconds or 6
except Exception: except Exception:
max_chunk_duration = 6.0 max_chunk_seconds = 6
shot_planner = ShotPlanner( shot_planner = ShotPlanner(
fps=storyboard_data.project.fps or 24, fps=storyboard_data.project.fps or 24,
max_chunk_duration=max_chunk_duration max_chunk_seconds=max_chunk_seconds
) )
# Initialize generation backend # Initialize generation backend
if not skip_generation: if not skip_generation:
console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]") console.print(f"[bold blue]Initializing {backend_id} backend...[/bold blue]")
try: try:
backend_config = app_config.get_backend(backend) backend_config = app_config.get_backend(backend_id)
video_backend = BackendFactory.create_backend(backend, backend_config) backend_factory_name = "wan" if backend_id.startswith("wan") else backend_id
console.print(f"[green]✓[/green] Backend initialized: {backend}") backend_config_dict = backend_config.__dict__.copy()
backend_config_dict["model_cache_dir"] = app_config.model_cache_dir
video_backend = BackendFactory.create(backend_factory_name, backend_config_dict)
video_backend.load()
console.print(f"[green]OK[/green] Backend initialized: {backend_id}")
except Exception as e: except Exception as e:
console.print(f"[red]Error initializing backend: {e}[/red]") import traceback
console.print(f"[red]Error initializing backend: {repr(e)}[/red]")
console.print(traceback.format_exc())
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1) raise typer.Exit(1)
# Generate shots # Generate shots
@@ -118,69 +189,108 @@ def generate(
console=console console=console
) as progress: ) as progress:
for i, shot in enumerate(storyboard_data.shots): for shot in storyboard_data.shots:
task_id = progress.add_task(f"Shot {shot.id}", total=None) task_id = progress.add_task(f"Shot {shot.id}", total=None)
# Ensure shot checkpoint exists
if checkpoint_mgr.get_shot(project_id, shot.id) is None:
checkpoint_mgr.create_shot(project_id, shot.id, ShotStatus.PENDING)
# Check checkpoint # Check checkpoint
if resume: if resume:
checkpoint = checkpoint_mgr.get_shot_checkpoint( checkpoint = checkpoint_mgr.get_shot(project_id, shot.id)
storyboard_data.project.title, shot.id if checkpoint and checkpoint.status == ShotStatus.COMPLETED and checkpoint.output_path:
) progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id} (cached)")
if checkpoint and checkpoint.status == "completed": generated_shots.append(Path(checkpoint.output_path))
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
generated_shots.append(Path(checkpoint.video_path))
continue continue
# Compile prompt # Compile prompt
prompt = prompt_compiler.compile_shot_prompt(shot) compiled = prompt_compiler.compile_shot(shot)
negative_prompt = prompt_compiler.compile_negative_prompt(shot) prompt = compiled.positive
negative_prompt = compiled.negative
# Plan shot # Plan shot
shot_plan = shot_planner.plan_shot(shot) shot_plan = shot_planner.plan_shot(shot)
num_frames = align_num_frames(shot_plan.total_frames)
if num_frames != shot_plan.total_frames:
console.print(
f"[yellow]Adjusting frames to {num_frames} "
f"(from {shot_plan.total_frames}) to meet model requirements.[/yellow]"
)
capped_frames = cap_num_frames_for_backend(backend_id, num_frames)
if capped_frames != num_frames:
console.print(
f"[yellow]Capping frames to {capped_frames} "
f"(from {num_frames}) for {backend_id} VRAM safety.[/yellow]"
)
num_frames = capped_frames
# Generate # Generate
try: try:
from src.generation.base import GenerationSpec from src.generation.base import GenerationSpec
output_path = shots_dir / f"{shot.id}.mp4"
width = storyboard_data.project.resolution.width
height = storyboard_data.project.resolution.height
capped_width, capped_height = cap_resolution_for_backend(backend_id, width, height)
if capped_width != width or capped_height != height:
console.print(
f"[yellow]Capping resolution to {capped_width}x{capped_height} "
f"(from {width}x{height}) for {backend_id} VRAM safety.[/yellow]"
)
aligned_width, aligned_height = align_resolution(capped_width, capped_height)
if aligned_width != width or aligned_height != height:
console.print(
f"[yellow]Adjusting resolution to {aligned_width}x{aligned_height} "
f"(from {width}x{height}) to meet model requirements.[/yellow]"
)
spec = GenerationSpec( spec = GenerationSpec(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
width=shot_plan.width, width=aligned_width,
height=shot_plan.height, height=aligned_height,
num_frames=shot_plan.total_frames, num_frames=num_frames,
fps=shot_plan.fps, fps=shot_plan.fps,
seed=shot.generation.seed seed=shot.generation.seed,
steps=shot.generation.steps,
cfg_scale=shot.generation.cfg_scale,
output_path=output_path
) )
result = video_backend.generate(spec, output_dir=shots_dir) checkpoint_mgr.update_shot(project_id, shot.id, status=ShotStatus.IN_PROGRESS)
result = video_backend.generate(spec)
if result.success: if result.success:
generated_shots.append(result.video_path) generated_shots.append(result.output_path)
checkpoint_mgr.save_shot_checkpoint( checkpoint_mgr.update_shot(
project_name=storyboard_data.project.title, project_id,
shot_id=shot.id, shot.id,
status="completed", status=ShotStatus.COMPLETED,
video_path=str(result.video_path), output_path=str(result.output_path),
metadata=result.metadata metadata=result.metadata
) )
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}") progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id}")
else: else:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}") progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {result.error_message}")
checkpoint_mgr.save_shot_checkpoint( checkpoint_mgr.update_shot(
project_name=storyboard_data.project.title, project_id,
shot_id=shot.id, shot.id,
status="failed", status=ShotStatus.FAILED,
error_message=result.error_message error_message=result.error_message
) )
except Exception as e: except Exception as e:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}") progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {e}")
checkpoint_mgr.save_shot_checkpoint( checkpoint_mgr.update_shot(
project_name=storyboard_data.project.title, project_id,
shot_id=shot.id, shot.id,
status="failed", status=ShotStatus.FAILED,
error_message=str(e) error_message=str(e)
) )
try:
video_backend.unload()
except Exception:
pass
# Assembly # Assembly
if not skip_assembly and generated_shots: if not skip_assembly and generated_shots:
@@ -197,15 +307,17 @@ def generate(
result = assembler.assemble(generated_shots, final_output, assembly_config) result = assembler.assemble(generated_shots, final_output, assembly_config)
if result.success: if result.success:
console.print(f"[green]✓[/green] Video assembled: {final_output}") console.print(f"[green]OK[/green] Video assembled: {final_output}")
# Upscale if requested # Upscale if requested
if upscale: if upscale:
console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]") console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")
if upscale not in (2, 4):
console.print("[yellow]Warning: Upscale factor must be 2 or 4. Skipping upscaling.[/yellow]")
else:
upscale_config = UpscaleConfig( upscale_config = UpscaleConfig(
factor=upscale, factor=UpscaleFactor.X2 if upscale == 2 else UpscaleFactor.X4,
upscaler_type="ffmpeg_sr" upscaler_type=UpscalerType.FFMPEG_SR
) )
upscale_mgr = UpscaleManager(upscale_config) upscale_mgr = UpscaleManager(upscale_config)
@@ -213,24 +325,21 @@ def generate(
upscale_result = upscale_mgr.upscale(final_output, upscaled_output) upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
if upscale_result.success: if upscale_result.success:
console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}") console.print(f"[green]OK[/green] Upscaled video: {upscaled_output}")
final_output = upscaled_output final_output = upscaled_output
else: else:
console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]") console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
# Save project metadata checkpoint_mgr.update_project_status(project_id, ProjectStatus.COMPLETED)
checkpoint_mgr.save_project_checkpoint(
project_name=storyboard_data.project.title,
storyboard_path=str(storyboard),
output_path=str(final_output),
status="completed",
num_shots=len(generated_shots)
)
console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}") console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
else: else:
console.print(f"[red]Error assembling video: {result.error_message}[/red]") console.print(f"[red]Error assembling video: {result.error_message}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1) raise typer.Exit(1)
else:
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
@app.command() @app.command()
@@ -249,7 +358,7 @@ def validate(
try: try:
data = validator.load(storyboard) data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard is valid") console.print("[green]OK[/green] Storyboard is valid")
console.print(f"\n[bold]Title:[/bold] {data.project.title}") console.print(f"\n[bold]Title:[/bold] {data.project.title}")
console.print(f"[bold]Shots:[/bold] {len(data.shots)}") console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}") console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
@@ -287,10 +396,10 @@ def resume(
raise typer.Exit(1) raise typer.Exit(1)
checkpoint_mgr = CheckpointManager(str(checkpoint_db)) checkpoint_mgr = CheckpointManager(str(checkpoint_db))
project_checkpoint = checkpoint_mgr.get_project_checkpoint(project) project_checkpoint = checkpoint_mgr.get_project(project)
if not project_checkpoint: if not project_checkpoint:
console.print(f"[red]Error: No project checkpoint found[/red]") console.print("[red]Error: No project checkpoint found[/red]")
raise typer.Exit(1) raise typer.Exit(1)
console.print(f"[bold blue]Resuming project:[/bold blue] {project}") console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
@@ -313,8 +422,8 @@ def list_backends():
table.add_column("Type", style="green") table.add_column("Type", style="green")
table.add_column("Description", style="white") table.add_column("Description", style="white")
table.add_row("wan", "T2V", "WAN 2.x text-to-video") table.add_row("wan_t2v_14b", "T2V", "WAN 2.x text-to-video (14B)")
table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)") table.add_row("wan_t2v_1_3b", "T2V", "WAN 1.3B (faster, lower quality)")
table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)") table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")
console.print(table) console.print(table)


@@ -79,7 +79,8 @@ class WanBackend(BaseVideoBackend):
return return
import torch import torch
from diffusers import AutoencoderKLWan, WanPipeline from diffusers import WanPipeline
from diffusers.models import AutoencoderKLWan
print(f"Loading WAN model: {self.model_id}") print(f"Loading WAN model: {self.model_id}")
@@ -92,13 +93,12 @@ class WanBackend(BaseVideoBackend):
# Load VAE # Load VAE
print("Loading VAE...") print("Loading VAE...")
vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id
self.vae = AutoencoderKLWan.from_pretrained( self.vae = AutoencoderKLWan.from_pretrained(
vae_id, self.model_id,
subfolder="vae" if "-T2V-" in self.model_id else None, subfolder="vae",
torch_dtype=self.dtype, torch_dtype=self.dtype,
cache_dir=str(cache_dir) cache_dir=str(cache_dir),
low_cpu_mem_usage=True
) )
# Load pipeline # Load pipeline
@@ -107,12 +107,20 @@ class WanBackend(BaseVideoBackend):
self.model_id, self.model_id,
vae=self.vae, vae=self.vae,
torch_dtype=self.dtype, torch_dtype=self.dtype,
cache_dir=str(cache_dir) cache_dir=str(cache_dir),
low_cpu_mem_usage=True
) )
# Move to GPU # Move to GPU
if torch.cuda.is_available(): if not torch.cuda.is_available():
print("Moving to CUDA...") raise RuntimeError(
"CUDA is not available. This backend requires an NVIDIA GPU. "
"Install a CUDA-enabled PyTorch build (cu121/cu124) and ensure the "
"correct conda environment is active."
)
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"Moving to CUDA... (GPU VRAM: {total_vram:.1f} GB)")
self.pipeline = self.pipeline.to("cuda") self.pipeline = self.pipeline.to("cuda")
# Enable memory optimizations # Enable memory optimizations
@@ -123,8 +131,6 @@ class WanBackend(BaseVideoBackend):
if self.enable_vae_tiling: if self.enable_vae_tiling:
print("Enabling VAE tiling...") print("Enabling VAE tiling...")
self.pipeline.vae.enable_tiling() self.pipeline.vae.enable_tiling()
else:
print("WARNING: CUDA not available, using CPU (will be very slow)")
self._is_loaded = True self._is_loaded = True
print("WAN model loaded successfully") print("WAN model loaded successfully")


@@ -8,7 +8,14 @@ from pathlib import Path
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from typer.testing import CliRunner from typer.testing import CliRunner
from src.cli.main import app from src.cli.main import (
app,
align_resolution,
align_num_frames,
cap_resolution_for_backend,
cap_num_frames_for_backend,
format_duration,
)
runner = CliRunner() runner = CliRunner()
@@ -82,7 +89,7 @@ class TestListBackendsCommand:
"""Test listing available backends.""" """Test listing available backends."""
result = runner.invoke(app, ["list-backends"]) result = runner.invoke(app, ["list-backends"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "wan" in result.output assert "wan_t2v_14b" in result.output
assert "Available Backends" in result.output assert "Available Backends" in result.output
@@ -140,7 +147,7 @@ class TestResumeCommand:
with patch('src.cli.main.CheckpointManager') as mock_mgr: with patch('src.cli.main.CheckpointManager') as mock_mgr:
mock_checkpoint = MagicMock() mock_checkpoint = MagicMock()
mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json") mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json")
mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint mock_mgr.return_value.get_project.return_value = mock_checkpoint
# Create storyboard file # Create storyboard file
storyboard_file = Path(tmpdir) / "storyboard.json" storyboard_file = Path(tmpdir) / "storyboard.json"
@@ -185,3 +192,52 @@ class TestCLIHelp:
result = runner.invoke(app, ["resume", "--help"]) result = runner.invoke(app, ["resume", "--help"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "Resume a failed or interrupted project" in result.output assert "Resume a failed or interrupted project" in result.output
class TestCLIAlignmentHelpers:
"""Test CLI alignment helpers."""
def test_align_resolution_no_change(self):
"""Resolution already aligned to 16."""
width, height = align_resolution(1920, 1088)
assert width == 1920
assert height == 1088
def test_align_resolution_adjusts(self):
"""Resolution is rounded up to multiple of 16."""
width, height = align_resolution(1920, 1080)
assert width == 1920
assert height == 1088
def test_align_num_frames_no_change(self):
"""Frame count already valid."""
assert align_num_frames(97) == 97 # 97 - 1 is divisible by 4
def test_align_num_frames_adjusts(self):
"""Frame count is rounded up to valid value."""
assert align_num_frames(96) == 97
def test_cap_resolution_for_small_backend(self):
"""Smaller backend caps resolution."""
width, height = cap_resolution_for_backend("wan_t2v_1_3b", 1920, 1080)
assert width == 1280
assert height == 720
def test_cap_resolution_for_other_backends(self):
"""No cap for other backends."""
width, height = cap_resolution_for_backend("wan_t2v_14b", 1920, 1080)
assert width == 1920
assert height == 1080
def test_cap_num_frames_for_small_backend(self):
"""Smaller backend caps frame count."""
assert cap_num_frames_for_backend("wan_t2v_1_3b", 121) == 97
def test_cap_num_frames_for_other_backends(self):
"""No cap for other backends."""
assert cap_num_frames_for_backend("wan_t2v_14b", 121) == 121
def test_format_duration(self):
"""Format duration for display."""
assert format_duration(5) == "0m05s"
assert format_duration(65) == "1m05s"