Adjust the resolution based on available VRAM; add elapsed time.
AGENTS.MD (27 lines changed)
@@ -11,13 +11,13 @@ The owner (user) is not an ML expert. The system must:
---

## 1) High-Level Goal

Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by:
Build a local pipeline that converts text-only storyboards into 15-30 second videos by:
1) converting storyboard -> shot plan
2) generating shot clips (T2V or I2V when possible)
3) assembling clips into a final MP4
4) upscaling to 2K/4K if desired

This is a **shot-based** system, not “one prompt makes a whole movie”.
This is a shot-based system, not "one prompt makes a whole movie".

---

@@ -38,9 +38,9 @@ Design must be stable under 12GB VRAM using:
---

## 3) Output Targets (Realistic)
- Native generation: 720p–1080p (preferred)
- Native generation: 720p-1080p (preferred)
- Final delivery: 1080p required; 2K/4K via upscaling
- Duration: 15–30s per video (may be segmented)
- Duration: 15-30s per video (may be segmented)
- FPS: 24 default
- Output: MP4 (H.264/H.265)

@@ -50,13 +50,15 @@ Design must be stable under 12GB VRAM using:
User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit.

**Plan:**
- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124).
Plan:
- Use PyTorch prebuilt binaries that bundle CUDA runtime (cu121/cu124/cu128).
- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version.
- Avoid compiling custom CUDA extensions unless necessary.

Implementation notes:
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds.
- For RTX 5070 (sm_120), use CUDA 12.8 wheels via pip:
  `pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`
- Prefer conda for Python, ffmpeg, and general deps; use pip for torch if sm_120 support is required.
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers.
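The wheel-selection rule in the notes above can be sketched as a tiny helper. This is illustrative only: `pick_cuda_wheel_tag` and the cu124 fallback are assumptions for the sketch, not part of this repo.

```python
# Illustrative sketch of the wheel-selection rule from the notes above:
# sm_120 (RTX 50xx) needs cu128 wheels; older architectures can use cu121/cu124.
PYTORCH_WHEEL_INDEX = "https://download.pytorch.org/whl/{tag}"

def pick_cuda_wheel_tag(sm_major: int, sm_minor: int) -> str:
    """Pick a PyTorch CUDA wheel tag from the GPU compute capability."""
    if (sm_major, sm_minor) >= (12, 0):  # e.g. RTX 5070 reports sm_120
        return "cu128"
    return "cu124"  # assumed default for pre-Blackwell cards

def pip_install_command(sm_major: int, sm_minor: int) -> str:
    """Build the pip command line for the matching wheel index."""
    tag = pick_cuda_wheel_tag(sm_major, sm_minor)
    return ("pip install torch torchvision torchaudio "
            f"--index-url {PYTORCH_WHEEL_INDEX.format(tag=tag)}")

print(pip_install_command(12, 0))
```

In practice the compute capability would come from `torch.cuda.get_device_capability()`, but that requires torch to already be installed, so a bootstrap script would query `nvidia-smi` instead.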

---

@@ -64,12 +66,13 @@ Implementation notes:
## 5) Approved Stack (Do Not Deviate)
### Core
- Python 3.10 or 3.11 (conda env)
- PyTorch (CUDA 12.x build: cu121 or cu124)
- PyTorch (CUDA 12.x build, cu121/cu124/cu128)
- diffusers + transformers + accelerate + safetensors
- ffmpeg for assembly
- opencv-python for frame IO (if needed)
- pydantic for config/schema validation
- rich / loguru for logs
- ftfy for text normalization (required by WAN)

### Testing
- pytest
@@ -124,7 +127,7 @@ We will later build a utility script:
- For each shot: generate clip
- Support:
  - seed control
  - chunking (e.g., generate 4–6 seconds then continue)
  - chunking (e.g., generate 4-6 seconds then continue)
  - optional init frame handoff between shots

### D) Assembly
@@ -149,7 +152,7 @@ For each shot and final render, save:
- timing + VRAM notes if possible

Every run produces a folder:
- outputs/<project>/<timestamp>/
- outputs/<project>/
  - shots/
  - assembled/
  - metadata/
@@ -166,7 +169,7 @@ Every run produces a folder:
- assembly command lines are correct
- metadata is generated correctly

Do not require “visual quality” assertions. Test structure and determinism.
Do not require visual quality assertions. Test structure and determinism.
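A structural test in that spirit can pin down determinism without judging pixels. The sketch below assumes a hypothetical seeded stub, `fake_generate`, standing in for a real backend call; it hashes what the backend would emit so two runs with the same seed must match byte-for-byte.

```python
import hashlib
import random

def fake_generate(prompt: str, seed: int, num_frames: int) -> str:
    """Hypothetical stand-in for a backend call: digests the 'frames' it would emit."""
    rng = random.Random(seed)  # seeded RNG -> reproducible output
    frames = [rng.random() for _ in range(num_frames)]
    payload = f"{prompt}|{frames}".encode()
    return hashlib.sha256(payload).hexdigest()

def test_same_seed_is_deterministic():
    a = fake_generate("a red kite over dunes", seed=42, num_frames=97)
    b = fake_generate("a red kite over dunes", seed=42, num_frames=97)
    assert a == b

def test_different_seed_changes_output():
    a = fake_generate("a red kite over dunes", seed=42, num_frames=97)
    b = fake_generate("a red kite over dunes", seed=43, num_frames=97)
    assert a != b
```

The same shape works against the real pipeline by hashing the generated MP4 (or per-frame tensors) instead of stub values.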

---

@@ -201,7 +204,7 @@ Required:
---

## 13) Definition of Done
A feature is “done” only if:
A feature is "done" only if:
- implemented
- tests added/updated
- docs updated

@@ -6,7 +6,7 @@ backends:
  wan_t2v_14b:
    name: "Wan 2.1 T2V 14B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-14B"
    model_id: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
    vram_gb: 12
    dtype: "fp16"
    enable_vae_slicing: true
@@ -20,7 +20,7 @@ backends:
  wan_t2v_1_3b:
    name: "Wan 2.1 T2V 1.3B"
    class: "generation.backends.wan.WanBackend"
    model_id: "Wan-AI/Wan2.1-T2V-1.3B"
    model_id: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    vram_gb: 8
    dtype: "fp16"
    enable_vae_slicing: true

@@ -1,8 +1,6 @@
# environment.yml
name: storyboard-video
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
@@ -12,13 +10,13 @@ dependencies:
  # FFmpeg in env is easiest on Windows if you rely on conda-forge
  - ffmpeg=7.1

  # PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
  # Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
  - pytorch=2.5.1
  - pytorch-cuda=12.4
  - torchvision=0.20.1
  - torchaudio=2.5.1

  # pip deps (mirrors requirements.txt)
  - pip:
      # RTX 50xx (sm_120) requires CUDA 12.8 wheels
      - --index-url https://download.pytorch.org/whl/cu128
      - torch
      - torchvision
      - torchaudio
      - -r requirements.txt
      - diffusers==0.36.0
      - ftfy==6.3.1

@@ -1,10 +1,11 @@
# requirements.txt
# Install PyTorch separately (see environment.yml and docs).

diffusers==0.32.2
diffusers==0.36.0
transformers==4.48.3
accelerate==1.3.0
safetensors==0.5.2
ftfy==6.3.1

pydantic==2.10.6
pyyaml==6.0.2
src/cli/main.py (259 lines changed)
@@ -2,34 +2,77 @@
CLI entry point for storyboard video generation.
"""

import sys
import os

from pathlib import Path
from typing import Optional
from typing import Optional, Dict

import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
import time

from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.core.checkpoint import CheckpointManager, ShotStatus, ProjectStatus
from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig, UpscalerType, UpscaleFactor

if "KMP_DUPLICATE_LIB_OK" not in os.environ:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
if "PYTORCH_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

app = typer.Typer(help="Storyboard to Video Generation Pipeline")
console = Console()

def align_resolution(width: int, height: int, multiple: int = 16) -> tuple[int, int]:
    """Align resolution to model-required multiples."""
    aligned_width = ((width + multiple - 1) // multiple) * multiple
    aligned_height = ((height + multiple - 1) // multiple) * multiple
    return aligned_width, aligned_height

def align_num_frames(num_frames: int) -> int:
    """Align frame count so (num_frames - 1) is divisible by 4."""
    if (num_frames - 1) % 4 == 0:
        return num_frames
    return ((num_frames - 1 + 3) // 4) * 4 + 1

def cap_resolution_for_backend(backend_id: str, width: int, height: int) -> tuple[int, int]:
    """Apply conservative resolution caps for smaller backends."""
    if backend_id == "wan_t2v_1_3b":
        max_width = 1280
        max_height = 720
        if width > max_width or height > max_height:
            return max_width, max_height
    return width, height

def cap_num_frames_for_backend(backend_id: str, num_frames: int) -> int:
    """Apply conservative frame caps for smaller backends."""
    if backend_id == "wan_t2v_1_3b":
        max_frames = 97  # (97 - 1) divisible by 4, ~4s at 24fps
        if num_frames > max_frames:
            return max_frames
    return num_frames

def format_duration(seconds: float) -> str:
    """Format seconds into XmYYs."""
    total = int(seconds)
    minutes = total // 60
    secs = total % 60
    return f"{minutes}m{secs:02d}s"
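Restated outside the CLI, the two alignment rules above behave as follows. This is a standalone sketch of the same logic; the example values mirror the unit tests added elsewhere in this commit.

```python
# Standalone restatement of the two alignment rules from main.py:
def align_resolution(width: int, height: int, multiple: int = 16) -> tuple[int, int]:
    """Round both dimensions up to the next multiple of 16."""
    aligned_width = ((width + multiple - 1) // multiple) * multiple
    aligned_height = ((height + multiple - 1) // multiple) * multiple
    return aligned_width, aligned_height

def align_num_frames(num_frames: int) -> int:
    """Round the frame count up so (num_frames - 1) is divisible by 4."""
    if (num_frames - 1) % 4 == 0:
        return num_frames
    return ((num_frames - 1 + 3) // 4) * 4 + 1

# 1080 is not a multiple of 16, so the height is padded to 1088
print(align_resolution(1920, 1080))  # (1920, 1088)
# a 5 s shot at 24 fps is 120 frames; the model wants 4k+1, so it becomes 121
print(align_num_frames(120))  # 121
```

Note the helpers only round up, so a generated clip may be a few pixels taller or a few frames longer than requested; the assembly step can crop or trim back to the storyboard's exact spec if needed.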


@app.command()
def generate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
    backend: Optional[str] = typer.Option(None, "--backend", "-b", help="Generation backend (wan_t2v_14b, wan_t2v_1_3b, svd)"),
    resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
    skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
    skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
@@ -38,9 +81,11 @@ def generate(
):
    """Generate video from storyboard."""

    start_time = time.time()
    # Validate storyboard exists
    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
        console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
        raise typer.Exit(1)

    # Load configuration
@@ -51,20 +96,31 @@ def generate(
        app_config = ConfigLoader.load()
    except Exception as e:
        console.print(f"[red]Error loading config: {e}[/red]")
        console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
        raise typer.Exit(1)

    # Normalize backend selection
    backend_aliases: Dict[str, str] = {
        "wan": "wan_t2v_14b",
        "wan-1.3b": "wan_t2v_1_3b",
        "wan_1_3b": "wan_t2v_1_3b",
    }
    backend_id = backend_aliases.get(backend, backend) if backend else app_config.active_backend

    # Validate storyboard
    console.print("[bold blue]Validating storyboard...[/bold blue]")
    validator = StoryboardValidator()
    try:
        storyboard_data = validator.load(storyboard)
        console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
        console.print(f"[green]OK[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
    except Exception as e:
        console.print(f"[red]Error validating storyboard: {e}[/red]")
        console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
        raise typer.Exit(1)

    if dry_run:
        console.print("[yellow]Dry run complete. Exiting.[/yellow]")
        console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
        return

    # Setup output directories
@@ -79,32 +135,47 @@ def generate(
    # Initialize checkpoint manager
    checkpoint_db = project_dir / "checkpoints.db"
    checkpoint_mgr = CheckpointManager(str(checkpoint_db))
    project_id = storyboard_data.project.title
    if checkpoint_mgr.get_project(project_id) is None:
        checkpoint_mgr.create_project(
            project_id=project_id,
            storyboard_path=str(storyboard),
            output_dir=str(project_dir),
            backend_name=backend_id
        )

    # Initialize components
    prompt_compiler = PromptCompiler(storyboard_data)

    # Get backend config for shot planner
    try:
        backend_config = app_config.get_backend(backend)
        max_chunk_duration = backend_config.max_chunk_seconds or 6.0
        backend_config = app_config.get_backend(backend_id)
        max_chunk_seconds = backend_config.max_chunk_seconds or 6
    except Exception:
        max_chunk_duration = 6.0
        max_chunk_seconds = 6

    shot_planner = ShotPlanner(
        fps=storyboard_data.project.fps or 24,
        max_chunk_duration=max_chunk_duration
        max_chunk_seconds=max_chunk_seconds
    )

    # Initialize generation backend
    if not skip_generation:
        console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
        console.print(f"[bold blue]Initializing {backend_id} backend...[/bold blue]")
        try:
            backend_config = app_config.get_backend(backend)
            video_backend = BackendFactory.create_backend(backend, backend_config)
            console.print(f"[green]✓[/green] Backend initialized: {backend}")
            backend_config = app_config.get_backend(backend_id)
            backend_factory_name = "wan" if backend_id.startswith("wan") else backend_id
            backend_config_dict = backend_config.__dict__.copy()
            backend_config_dict["model_cache_dir"] = app_config.model_cache_dir
            video_backend = BackendFactory.create(backend_factory_name, backend_config_dict)
            video_backend.load()
            console.print(f"[green]OK[/green] Backend initialized: {backend_id}")
        except Exception as e:
            console.print(f"[red]Error initializing backend: {e}[/red]")
            raise typer.Exit(1)
            import traceback
            console.print(f"[red]Error initializing backend: {repr(e)}[/red]")
            console.print(traceback.format_exc())
            console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
            raise typer.Exit(1)

    # Generate shots
    generated_shots = []
@@ -118,69 +189,108 @@ def generate(
        console=console
    ) as progress:

        for i, shot in enumerate(storyboard_data.shots):
        for shot in storyboard_data.shots:
            task_id = progress.add_task(f"Shot {shot.id}", total=None)

            # Ensure shot checkpoint exists
            if checkpoint_mgr.get_shot(project_id, shot.id) is None:
                checkpoint_mgr.create_shot(project_id, shot.id, ShotStatus.PENDING)

            # Check checkpoint
            if resume:
                checkpoint = checkpoint_mgr.get_shot_checkpoint(
                    storyboard_data.project.title, shot.id
                )
                if checkpoint and checkpoint.status == "completed":
                    progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
                    generated_shots.append(Path(checkpoint.video_path))
                checkpoint = checkpoint_mgr.get_shot(project_id, shot.id)
                if checkpoint and checkpoint.status == ShotStatus.COMPLETED and checkpoint.output_path:
                    progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id} (cached)")
                    generated_shots.append(Path(checkpoint.output_path))
                    continue

            # Compile prompt
            prompt = prompt_compiler.compile_shot_prompt(shot)
            negative_prompt = prompt_compiler.compile_negative_prompt(shot)
            compiled = prompt_compiler.compile_shot(shot)
            prompt = compiled.positive
            negative_prompt = compiled.negative

            # Plan shot
            shot_plan = shot_planner.plan_shot(shot)
            num_frames = align_num_frames(shot_plan.total_frames)
            if num_frames != shot_plan.total_frames:
                console.print(
                    f"[yellow]Adjusting frames to {num_frames} "
                    f"(from {shot_plan.total_frames}) to meet model requirements.[/yellow]"
                )
            capped_frames = cap_num_frames_for_backend(backend_id, num_frames)
            if capped_frames != num_frames:
                console.print(
                    f"[yellow]Capping frames to {capped_frames} "
                    f"(from {num_frames}) for {backend_id} VRAM safety.[/yellow]"
                )
                num_frames = capped_frames

            # Generate
            try:
                from src.generation.base import GenerationSpec

                output_path = shots_dir / f"{shot.id}.mp4"
                width = storyboard_data.project.resolution.width
                height = storyboard_data.project.resolution.height
                capped_width, capped_height = cap_resolution_for_backend(backend_id, width, height)
                if capped_width != width or capped_height != height:
                    console.print(
                        f"[yellow]Capping resolution to {capped_width}x{capped_height} "
                        f"(from {width}x{height}) for {backend_id} VRAM safety.[/yellow]"
                    )
                aligned_width, aligned_height = align_resolution(capped_width, capped_height)
                if aligned_width != width or aligned_height != height:
                    console.print(
                        f"[yellow]Adjusting resolution to {aligned_width}x{aligned_height} "
                        f"(from {width}x{height}) to meet model requirements.[/yellow]"
                    )
                spec = GenerationSpec(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=shot_plan.width,
                    height=shot_plan.height,
                    num_frames=shot_plan.total_frames,
                    width=aligned_width,
                    height=aligned_height,
                    num_frames=num_frames,
                    fps=shot_plan.fps,
                    seed=shot.generation.seed
                    seed=shot.generation.seed,
                    steps=shot.generation.steps,
                    cfg_scale=shot.generation.cfg_scale,
                    output_path=output_path
                )

                result = video_backend.generate(spec, output_dir=shots_dir)
                checkpoint_mgr.update_shot(project_id, shot.id, status=ShotStatus.IN_PROGRESS)
                result = video_backend.generate(spec)

                if result.success:
                    generated_shots.append(result.video_path)
                    checkpoint_mgr.save_shot_checkpoint(
                        project_name=storyboard_data.project.title,
                        shot_id=shot.id,
                        status="completed",
                        video_path=str(result.video_path),
                    generated_shots.append(result.output_path)
                    checkpoint_mgr.update_shot(
                        project_id,
                        shot.id,
                        status=ShotStatus.COMPLETED,
                        output_path=str(result.output_path),
                        metadata=result.metadata
                    )
                    progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}")
                    progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id}")
                else:
                    progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}")
                    checkpoint_mgr.save_shot_checkpoint(
                        project_name=storyboard_data.project.title,
                        shot_id=shot.id,
                        status="failed",
                    progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {result.error_message}")
                    checkpoint_mgr.update_shot(
                        project_id,
                        shot.id,
                        status=ShotStatus.FAILED,
                        error_message=result.error_message
                    )

            except Exception as e:
                progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}")
                checkpoint_mgr.save_shot_checkpoint(
                    project_name=storyboard_data.project.title,
                    shot_id=shot.id,
                    status="failed",
                progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {e}")
                checkpoint_mgr.update_shot(
                    project_id,
                    shot.id,
                    status=ShotStatus.FAILED,
                    error_message=str(e)
                )
    try:
        video_backend.unload()
    except Exception:
        pass

    # Assembly
    if not skip_assembly and generated_shots:
@@ -197,40 +307,39 @@ def generate(
        result = assembler.assemble(generated_shots, final_output, assembly_config)

        if result.success:
            console.print(f"[green]✓[/green] Video assembled: {final_output}")
            console.print(f"[green]OK[/green] Video assembled: {final_output}")

            # Upscale if requested
            if upscale:
                console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")

                upscale_config = UpscaleConfig(
                    factor=upscale,
                    upscaler_type="ffmpeg_sr"
                )
                upscale_mgr = UpscaleManager(upscale_config)

                upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
                upscale_result = upscale_mgr.upscale(final_output, upscaled_output)

                if upscale_result.success:
                    console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}")
                    final_output = upscaled_output
                if upscale not in (2, 4):
                    console.print("[yellow]Warning: Upscale factor must be 2 or 4. Skipping upscaling.[/yellow]")
                else:
                    console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
                    upscale_config = UpscaleConfig(
                        factor=UpscaleFactor.X2 if upscale == 2 else UpscaleFactor.X4,
                        upscaler_type=UpscalerType.FFMPEG_SR
                    )
                    upscale_mgr = UpscaleManager(upscale_config)

            # Save project metadata
            checkpoint_mgr.save_project_checkpoint(
                project_name=storyboard_data.project.title,
                storyboard_path=str(storyboard),
                output_path=str(final_output),
                status="completed",
                num_shots=len(generated_shots)
            )
                    upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
                    upscale_result = upscale_mgr.upscale(final_output, upscaled_output)

                    if upscale_result.success:
                        console.print(f"[green]OK[/green] Upscaled video: {upscaled_output}")
                        final_output = upscaled_output
                    else:
                        console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")

            checkpoint_mgr.update_project_status(project_id, ProjectStatus.COMPLETED)

            console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
            console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
        else:
            console.print(f"[red]Error assembling video: {result.error_message}[/red]")
            console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
            raise typer.Exit(1)
    else:
        console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")


@app.command()
@@ -249,7 +358,7 @@ def validate(
    try:
        data = validator.load(storyboard)

        console.print(f"[green]✓[/green] Storyboard is valid")
        console.print("[green]OK[/green] Storyboard is valid")
        console.print(f"\n[bold]Title:[/bold] {data.project.title}")
        console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
        console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
@@ -287,10 +396,10 @@ def resume(
        raise typer.Exit(1)

    checkpoint_mgr = CheckpointManager(str(checkpoint_db))
    project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)
    project_checkpoint = checkpoint_mgr.get_project(project)

    if not project_checkpoint:
        console.print(f"[red]Error: No project checkpoint found[/red]")
        console.print("[red]Error: No project checkpoint found[/red]")
        raise typer.Exit(1)

    console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
@@ -313,8 +422,8 @@ def list_backends():
    table.add_column("Type", style="green")
    table.add_column("Description", style="white")

    table.add_row("wan", "T2V", "WAN 2.x text-to-video")
    table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)")
    table.add_row("wan_t2v_14b", "T2V", "WAN 2.x text-to-video (14B)")
    table.add_row("wan_t2v_1_3b", "T2V", "WAN 1.3B (faster, lower quality)")
    table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")

    console.print(table)
@@ -79,7 +79,8 @@ class WanBackend(BaseVideoBackend):
            return

        import torch
        from diffusers import AutoencoderKLWan, WanPipeline
        from diffusers import WanPipeline
        from diffusers.models import AutoencoderKLWan

        print(f"Loading WAN model: {self.model_id}")

@@ -92,13 +93,12 @@ class WanBackend(BaseVideoBackend):

        # Load VAE
        print("Loading VAE...")
        vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id

        self.vae = AutoencoderKLWan.from_pretrained(
            vae_id,
            subfolder="vae" if "-T2V-" in self.model_id else None,
            self.model_id,
            subfolder="vae",
            torch_dtype=self.dtype,
            cache_dir=str(cache_dir)
            cache_dir=str(cache_dir),
            low_cpu_mem_usage=True
        )

        # Load pipeline
@@ -107,24 +107,30 @@ class WanBackend(BaseVideoBackend):
            self.model_id,
            vae=self.vae,
            torch_dtype=self.dtype,
            cache_dir=str(cache_dir)
            cache_dir=str(cache_dir),
            low_cpu_mem_usage=True
        )

        # Move to GPU
        if torch.cuda.is_available():
            print("Moving to CUDA...")
            self.pipeline = self.pipeline.to("cuda")
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is not available. This backend requires an NVIDIA GPU. "
                "Install a CUDA-enabled PyTorch build (cu121/cu124) and ensure the "
                "correct conda environment is active."
            )

            # Enable memory optimizations
            if self.enable_vae_slicing:
                print("Enabling VAE slicing...")
                self.pipeline.vae.enable_slicing()
        total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        print(f"Moving to CUDA... (GPU VRAM: {total_vram:.1f} GB)")
        self.pipeline = self.pipeline.to("cuda")

            if self.enable_vae_tiling:
                print("Enabling VAE tiling...")
                self.pipeline.vae.enable_tiling()
        else:
            print("WARNING: CUDA not available, using CPU (will be very slow)")
        # Enable memory optimizations
        if self.enable_vae_slicing:
            print("Enabling VAE slicing...")
            self.pipeline.vae.enable_slicing()

        if self.enable_vae_tiling:
            print("Enabling VAE tiling...")
            self.pipeline.vae.enable_tiling()

        self._is_loaded = True
        print("WAN model loaded successfully")
@@ -8,7 +8,14 @@ from pathlib import Path
from unittest.mock import patch, MagicMock
from typer.testing import CliRunner

from src.cli.main import app
from src.cli.main import (
    app,
    align_resolution,
    align_num_frames,
    cap_resolution_for_backend,
    cap_num_frames_for_backend,
    format_duration,
)


runner = CliRunner()
@@ -61,8 +68,8 @@ class TestValidateCommand:

        result = runner.invoke(app, ["validate", str(storyboard_file)])
        assert result.exit_code == 0
        assert "valid" in result.output
        assert "Test Storyboard" in result.output
        assert "valid" in result.output
        assert "Test Storyboard" in result.output

    def test_validate_verbose(self):
        """Test validation with verbose flag."""
@@ -82,7 +89,7 @@ class TestListBackendsCommand:
        """Test listing available backends."""
        result = runner.invoke(app, ["list-backends"])
        assert result.exit_code == 0
        assert "wan" in result.output
        assert "wan_t2v_14b" in result.output
        assert "Available Backends" in result.output


@@ -140,7 +147,7 @@ class TestResumeCommand:
        with patch('src.cli.main.CheckpointManager') as mock_mgr:
            mock_checkpoint = MagicMock()
            mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json")
            mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint
            mock_mgr.return_value.get_project.return_value = mock_checkpoint

            # Create storyboard file
            storyboard_file = Path(tmpdir) / "storyboard.json"
@@ -185,3 +192,52 @@ class TestCLIHelp:
        result = runner.invoke(app, ["resume", "--help"])
        assert result.exit_code == 0
        assert "Resume a failed or interrupted project" in result.output


class TestCLIAlignmentHelpers:
    """Test CLI alignment helpers."""

    def test_align_resolution_no_change(self):
        """Resolution already aligned to 16."""
        width, height = align_resolution(1920, 1088)
        assert width == 1920
        assert height == 1088

    def test_align_resolution_adjusts(self):
        """Resolution is rounded up to multiple of 16."""
        width, height = align_resolution(1920, 1080)
        assert width == 1920
        assert height == 1088

    def test_align_num_frames_no_change(self):
        """Frame count already valid."""
        assert align_num_frames(97) == 97  # 97 - 1 is divisible by 4

    def test_align_num_frames_adjusts(self):
        """Frame count is rounded up to valid value."""
        assert align_num_frames(96) == 97

    def test_cap_resolution_for_small_backend(self):
        """Smaller backend caps resolution."""
        width, height = cap_resolution_for_backend("wan_t2v_1_3b", 1920, 1080)
        assert width == 1280
        assert height == 720

    def test_cap_resolution_for_other_backends(self):
        """No cap for other backends."""
        width, height = cap_resolution_for_backend("wan_t2v_14b", 1920, 1080)
        assert width == 1920
        assert height == 1080

    def test_cap_num_frames_for_small_backend(self):
        """Smaller backend caps frame count."""
        assert cap_num_frames_for_backend("wan_t2v_1_3b", 121) == 97

    def test_cap_num_frames_for_other_backends(self):
        """No cap for other backends."""
        assert cap_num_frames_for_backend("wan_t2v_14b", 121) == 121

    def test_format_duration(self):
        """Format duration for display."""
        assert format_duration(5) == "0m05s"
        assert format_duration(65) == "1m05s"