diff --git a/AGENTS.MD b/AGENTS.MD index e33e6d8..a7671b9 100644 --- a/AGENTS.MD +++ b/AGENTS.MD @@ -11,13 +11,13 @@ The owner (user) is not an ML expert. The system must: --- ## 1) High-Level Goal -Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by: +Build a local pipeline that converts text-only storyboards into 15-30 second videos by: 1) converting storyboard -> shot plan 2) generating shot clips (T2V or I2V when possible) 3) assembling clips into a final MP4 4) upscaling to 2K/4K if desired -This is a **shot-based** system, not “one prompt makes a whole movie”. +This is a shot-based system, not "one prompt makes a whole movie". --- @@ -38,9 +38,9 @@ Design must be stable under 12GB VRAM using: --- ## 3) Output Targets (Realistic) -- Native generation: 720p–1080p (preferred) +- Native generation: 720p-1080p (preferred) - Final delivery: 1080p required; 2K/4K via upscaling -- Duration: 15–30s per video (may be segmented) +- Duration: 15-30s per video (may be segmented) - FPS: 24 default - Output: MP4 (H.264/H.265) @@ -50,13 +50,15 @@ Design must be stable under 12GB VRAM using: User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes. We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit. -**Plan:** -- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124). +Plan: +- Use PyTorch prebuilt binaries that bundle CUDA runtime (cu121/cu124/cu128). - Rely on NVIDIA driver compatibility rather than local CUDA toolkit version. - Avoid compiling custom CUDA extensions unless necessary. Implementation notes: -- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds. 
+- For RTX 5070 (sm_120), use CUDA 12.8 wheels via pip: + `pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128` +- Prefer conda for Python, ffmpeg, and general deps; use pip for torch if sm_120 support is required. - If xFormers causes build issues, use PyTorch SDPA and disable xFormers. --- @@ -64,12 +66,13 @@ Implementation notes: ## 5) Approved Stack (Do Not Deviate) ### Core - Python 3.10 or 3.11 (conda env) -- PyTorch (CUDA 12.x build: cu121 or cu124) +- PyTorch (CUDA 12.x build, cu121/cu124/cu128) - diffusers + transformers + accelerate + safetensors - ffmpeg for assembly - opencv-python for frame IO (if needed) - pydantic for config/schema validation - rich / loguru for logs +- ftfy for text normalization (required by WAN) ### Testing - pytest @@ -124,7 +127,7 @@ We will later build a utility script: - For each shot: generate clip - Support: - seed control - - chunking (e.g., generate 4–6 seconds then continue) + - chunking (e.g., generate 4-6 seconds then continue) - optional init frame handoff between shots ### D) Assembly @@ -149,7 +152,7 @@ For each shot and final render, save: - timing + VRAM notes if possible Every run produces a folder: -- outputs/// +- outputs// - shots/ - assembled/ - metadata/ @@ -166,7 +169,7 @@ Every run produces a folder: - assembly command lines are correct - metadata is generated correctly -Do not require “visual quality” assertions. Test structure and determinism. +Do not require visual quality assertions. Test structure and determinism. --- @@ -201,10 +204,10 @@ Required: --- ## 13) Definition of Done -A feature is “done” only if: +A feature is "done" only if: - implemented - tests added/updated - docs updated - reproducible install instructions remain valid -End of file. +End of file. 
\ No newline at end of file diff --git a/config/models.yaml b/config/models.yaml index 698d0f0..f034335 100644 --- a/config/models.yaml +++ b/config/models.yaml @@ -6,7 +6,7 @@ backends: wan_t2v_14b: name: "Wan 2.1 T2V 14B" class: "generation.backends.wan.WanBackend" - model_id: "Wan-AI/Wan2.1-T2V-14B" + model_id: "Wan-AI/Wan2.1-T2V-14B-Diffusers" vram_gb: 12 dtype: "fp16" enable_vae_slicing: true @@ -20,7 +20,7 @@ backends: wan_t2v_1_3b: name: "Wan 2.1 T2V 1.3B" class: "generation.backends.wan.WanBackend" - model_id: "Wan-AI/Wan2.1-T2V-1.3B" + model_id: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" vram_gb: 8 dtype: "fp16" enable_vae_slicing: true diff --git a/environment.yml b/environment.yml index d0a2907..b755aa1 100644 --- a/environment.yml +++ b/environment.yml @@ -1,8 +1,6 @@ # environment.yml name: storyboard-video channels: - - pytorch - - nvidia - conda-forge - defaults dependencies: @@ -12,13 +10,13 @@ dependencies: # FFmpeg in env is easiest on Windows if you rely on conda-forge - ffmpeg=7.1 - # PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit) - # Using compatible versions for PyTorch 2.5.1 with CUDA 12.4 - - pytorch=2.5.1 - - pytorch-cuda=12.4 - - torchvision=0.20.1 - - torchaudio=2.5.1 - # pip deps (mirrors requirements.txt) - pip: + # RTX 50xx (sm_120) requires CUDA 12.8 wheels + - --extra-index-url https://download.pytorch.org/whl/cu128 + - torch + - torchvision + - torchaudio - -r requirements.txt + - diffusers==0.36.0 + - ftfy==6.3.1 diff --git a/requirements.txt b/requirements.txt index b97ffe0..94d1754 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ # requirements.txt # Install PyTorch separately (see environment.yml and docs). 
-diffusers==0.32.2 +diffusers==0.36.0 transformers==4.48.3 accelerate==1.3.0 safetensors==0.5.2 +ftfy==6.3.1 pydantic==2.10.6 pyyaml==6.0.2 diff --git a/src/cli/main.py b/src/cli/main.py index 2d88911..14a0037 100644 --- a/src/cli/main.py +++ b/src/cli/main.py @@ -2,34 +2,77 @@ CLI entry point for storyboard video generation. """ -import sys +import os + from pathlib import Path -from typing import Optional +from typing import Optional, Dict import typer from rich.console import Console from rich.progress import Progress, SpinnerColumn, TextColumn from rich.table import Table +import time from src.storyboard.loader import StoryboardValidator from src.storyboard.prompt_compiler import PromptCompiler from src.storyboard.shot_planner import ShotPlanner from src.core.config import ConfigLoader -from src.core.checkpoint import CheckpointManager +from src.core.checkpoint import CheckpointManager, ShotStatus, ProjectStatus from src.generation.base import BackendFactory from src.assembly.assembler import FFmpegAssembler, AssemblyConfig -from src.upscaling.upscaler import UpscaleManager, UpscaleConfig +from src.upscaling.upscaler import UpscaleManager, UpscaleConfig, UpscalerType, UpscaleFactor + +if "KMP_DUPLICATE_LIB_OK" not in os.environ: + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +if "PYTORCH_ALLOC_CONF" not in os.environ: + os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" app = typer.Typer(help="Storyboard to Video Generation Pipeline") console = Console() +def align_resolution(width: int, height: int, multiple: int = 16) -> tuple[int, int]: + """Align resolution to model-required multiples.""" + aligned_width = ((width + multiple - 1) // multiple) * multiple + aligned_height = ((height + multiple - 1) // multiple) * multiple + return aligned_width, aligned_height + +def align_num_frames(num_frames: int) -> int: + """Align frame count so (num_frames - 1) is divisible by 4.""" + if (num_frames - 1) % 4 == 0: + return num_frames + return ((num_frames - 1 
+ 3) // 4) * 4 + 1 + +def cap_resolution_for_backend(backend_id: str, width: int, height: int) -> tuple[int, int]: + """Apply conservative resolution caps for smaller backends.""" + if backend_id == "wan_t2v_1_3b": + max_width = 1280 + max_height = 720 + if width > max_width or height > max_height: + return max_width, max_height + return width, height + +def cap_num_frames_for_backend(backend_id: str, num_frames: int) -> int: + """Apply conservative frame caps for smaller backends.""" + if backend_id == "wan_t2v_1_3b": + max_frames = 97 # (97 - 1) divisible by 4, ~4s at 24fps + if num_frames > max_frames: + return max_frames + return num_frames + +def format_duration(seconds: float) -> str: + """Format seconds into XmYYs.""" + total = int(seconds) + minutes = total // 60 + secs = total % 60 + return f"{minutes}m{secs:02d}s" + @app.command() def generate( storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"), output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"), config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"), - backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"), + backend: Optional[str] = typer.Option(None, "--backend", "-b", help="Generation backend (wan_t2v_14b, wan_t2v_1_3b, svd)"), resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"), skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"), skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"), @@ -37,12 +80,14 @@ def generate( dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"), ): """Generate video from storyboard.""" - + + start_time = time.time() # Validate storyboard exists if not storyboard.exists(): console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]") + 
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") raise typer.Exit(1) - + # Load configuration try: if config: @@ -51,186 +96,250 @@ def generate( app_config = ConfigLoader.load() except Exception as e: console.print(f"[red]Error loading config: {e}[/red]") + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") raise typer.Exit(1) - + + # Normalize backend selection + backend_aliases: Dict[str, str] = { + "wan": "wan_t2v_14b", + "wan-1.3b": "wan_t2v_1_3b", + "wan_1_3b": "wan_t2v_1_3b", + } + backend_id = backend_aliases.get(backend, backend) if backend else app_config.active_backend + # Validate storyboard console.print("[bold blue]Validating storyboard...[/bold blue]") validator = StoryboardValidator() try: storyboard_data = validator.load(storyboard) - console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots") + console.print(f"[green]OK[/green] Storyboard validated: {len(storyboard_data.shots)} shots") except Exception as e: console.print(f"[red]Error validating storyboard: {e}[/red]") + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") raise typer.Exit(1) - + if dry_run: console.print("[yellow]Dry run complete. 
Exiting.[/yellow]") + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") return - + # Setup output directories project_dir = output / storyboard_data.project.title.replace(" ", "_") shots_dir = project_dir / "shots" assembled_dir = project_dir / "assembled" metadata_dir = project_dir / "metadata" - + for d in [shots_dir, assembled_dir, metadata_dir]: d.mkdir(parents=True, exist_ok=True) - + # Initialize checkpoint manager checkpoint_db = project_dir / "checkpoints.db" checkpoint_mgr = CheckpointManager(str(checkpoint_db)) - + project_id = storyboard_data.project.title + if checkpoint_mgr.get_project(project_id) is None: + checkpoint_mgr.create_project( + project_id=project_id, + storyboard_path=str(storyboard), + output_dir=str(project_dir), + backend_name=backend_id + ) + # Initialize components prompt_compiler = PromptCompiler(storyboard_data) - + # Get backend config for shot planner try: - backend_config = app_config.get_backend(backend) - max_chunk_duration = backend_config.max_chunk_seconds or 6.0 + backend_config = app_config.get_backend(backend_id) + max_chunk_seconds = backend_config.max_chunk_seconds or 6 except Exception: - max_chunk_duration = 6.0 - + max_chunk_seconds = 6 + shot_planner = ShotPlanner( fps=storyboard_data.project.fps or 24, - max_chunk_duration=max_chunk_duration + max_chunk_seconds=max_chunk_seconds ) - + # Initialize generation backend if not skip_generation: - console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]") + console.print(f"[bold blue]Initializing {backend_id} backend...[/bold blue]") try: - backend_config = app_config.get_backend(backend) - video_backend = BackendFactory.create_backend(backend, backend_config) - console.print(f"[green]✓[/green] Backend initialized: {backend}") + backend_config = app_config.get_backend(backend_id) + backend_factory_name = "wan" if backend_id.startswith("wan") else backend_id + backend_config_dict = backend_config.__dict__.copy() + 
backend_config_dict["model_cache_dir"] = app_config.model_cache_dir + video_backend = BackendFactory.create(backend_factory_name, backend_config_dict) + video_backend.load() + console.print(f"[green]OK[/green] Backend initialized: {backend_id}") except Exception as e: - console.print(f"[red]Error initializing backend: {e}[/red]") - raise typer.Exit(1) - + import traceback + console.print(f"[red]Error initializing backend: {repr(e)}[/red]") + console.print(traceback.format_exc()) + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") + raise typer.Exit(1) + # Generate shots generated_shots = [] - + if not skip_generation: console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]") - + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), console=console ) as progress: - - for i, shot in enumerate(storyboard_data.shots): + + for shot in storyboard_data.shots: task_id = progress.add_task(f"Shot {shot.id}", total=None) - + + # Ensure shot checkpoint exists + if checkpoint_mgr.get_shot(project_id, shot.id) is None: + checkpoint_mgr.create_shot(project_id, shot.id, ShotStatus.PENDING) + # Check checkpoint if resume: - checkpoint = checkpoint_mgr.get_shot_checkpoint( - storyboard_data.project.title, shot.id - ) - if checkpoint and checkpoint.status == "completed": - progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)") - generated_shots.append(Path(checkpoint.video_path)) + checkpoint = checkpoint_mgr.get_shot(project_id, shot.id) + if checkpoint and checkpoint.status == ShotStatus.COMPLETED and checkpoint.output_path: + progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id} (cached)") + generated_shots.append(Path(checkpoint.output_path)) continue - + # Compile prompt - prompt = prompt_compiler.compile_shot_prompt(shot) - negative_prompt = prompt_compiler.compile_negative_prompt(shot) - + compiled = 
prompt_compiler.compile_shot(shot) + prompt = compiled.positive + negative_prompt = compiled.negative + # Plan shot shot_plan = shot_planner.plan_shot(shot) - + num_frames = align_num_frames(shot_plan.total_frames) + if num_frames != shot_plan.total_frames: + console.print( + f"[yellow]Adjusting frames to {num_frames} " + f"(from {shot_plan.total_frames}) to meet model requirements.[/yellow]" + ) + capped_frames = cap_num_frames_for_backend(backend_id, num_frames) + if capped_frames != num_frames: + console.print( + f"[yellow]Capping frames to {capped_frames} " + f"(from {num_frames}) for {backend_id} VRAM safety.[/yellow]" + ) + num_frames = capped_frames + # Generate try: from src.generation.base import GenerationSpec - + + output_path = shots_dir / f"{shot.id}.mp4" + width = storyboard_data.project.resolution.width + height = storyboard_data.project.resolution.height + capped_width, capped_height = cap_resolution_for_backend(backend_id, width, height) + if capped_width != width or capped_height != height: + console.print( + f"[yellow]Capping resolution to {capped_width}x{capped_height} " + f"(from {width}x{height}) for {backend_id} VRAM safety.[/yellow]" + ) + aligned_width, aligned_height = align_resolution(capped_width, capped_height) + if aligned_width != width or aligned_height != height: + console.print( + f"[yellow]Adjusting resolution to {aligned_width}x{aligned_height} " + f"(from {width}x{height}) to meet model requirements.[/yellow]" + ) spec = GenerationSpec( prompt=prompt, negative_prompt=negative_prompt, - width=shot_plan.width, - height=shot_plan.height, - num_frames=shot_plan.total_frames, + width=aligned_width, + height=aligned_height, + num_frames=num_frames, fps=shot_plan.fps, - seed=shot.generation.seed + seed=shot.generation.seed, + steps=shot.generation.steps, + cfg_scale=shot.generation.cfg_scale, + output_path=output_path ) - - result = video_backend.generate(spec, output_dir=shots_dir) - + + checkpoint_mgr.update_shot(project_id, shot.id, 
status=ShotStatus.IN_PROGRESS) + result = video_backend.generate(spec) + if result.success: - generated_shots.append(result.video_path) - checkpoint_mgr.save_shot_checkpoint( - project_name=storyboard_data.project.title, - shot_id=shot.id, - status="completed", - video_path=str(result.video_path), + generated_shots.append(result.output_path) + checkpoint_mgr.update_shot( + project_id, + shot.id, + status=ShotStatus.COMPLETED, + output_path=str(result.output_path), metadata=result.metadata ) - progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}") + progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id}") else: - progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}") - checkpoint_mgr.save_shot_checkpoint( - project_name=storyboard_data.project.title, - shot_id=shot.id, - status="failed", + progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {result.error_message}") + checkpoint_mgr.update_shot( + project_id, + shot.id, + status=ShotStatus.FAILED, error_message=result.error_message ) - + except Exception as e: - progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}") - checkpoint_mgr.save_shot_checkpoint( - project_name=storyboard_data.project.title, - shot_id=shot.id, - status="failed", + progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {e}") + checkpoint_mgr.update_shot( + project_id, + shot.id, + status=ShotStatus.FAILED, error_message=str(e) ) - + try: + video_backend.unload() + except Exception: + pass + # Assembly if not skip_assembly and generated_shots: console.print("[bold blue]Assembling video...[/bold blue]") - + assembler = FFmpegAssembler() final_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}.mp4" - + assembly_config = AssemblyConfig( fps=storyboard_data.project.fps or 24, add_shot_labels=False ) - + result = assembler.assemble(generated_shots, final_output, assembly_config) - + if 
result.success: - console.print(f"[green]✓[/green] Video assembled: {final_output}") - + console.print(f"[green]OK[/green] Video assembled: {final_output}") + # Upscale if requested if upscale: console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]") - - upscale_config = UpscaleConfig( - factor=upscale, - upscaler_type="ffmpeg_sr" - ) - upscale_mgr = UpscaleManager(upscale_config) - - upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4" - upscale_result = upscale_mgr.upscale(final_output, upscaled_output) - - if upscale_result.success: - console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}") - final_output = upscaled_output + if upscale not in (2, 4): + console.print("[yellow]Warning: Upscale factor must be 2 or 4. Skipping upscaling.[/yellow]") else: - console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]") - - # Save project metadata - checkpoint_mgr.save_project_checkpoint( - project_name=storyboard_data.project.title, - storyboard_path=str(storyboard), - output_path=str(final_output), - status="completed", - num_shots=len(generated_shots) - ) - + upscale_config = UpscaleConfig( + factor=UpscaleFactor.X2 if upscale == 2 else UpscaleFactor.X4, + upscaler_type=UpscalerType.FFMPEG_SR + ) + upscale_mgr = UpscaleManager(upscale_config) + + upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4" + upscale_result = upscale_mgr.upscale(final_output, upscaled_output) + + if upscale_result.success: + console.print(f"[green]OK[/green] Upscaled video: {upscaled_output}") + final_output = upscaled_output + else: + console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]") + + checkpoint_mgr.update_project_status(project_id, ProjectStatus.COMPLETED) + console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}") + console.print(f"[yellow]Elapsed: 
{format_duration(time.time() - start_time)}[/yellow]") else: console.print(f"[red]Error assembling video: {result.error_message}[/red]") + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") raise typer.Exit(1) + else: + console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]") @app.command() @@ -239,34 +348,34 @@ def validate( verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"), ): """Validate a storyboard file.""" - + if not storyboard.exists(): console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]") raise typer.Exit(1) - + validator = StoryboardValidator() - + try: data = validator.load(storyboard) - - console.print(f"[green]✓[/green] Storyboard is valid") + + console.print("[green]OK[/green] Storyboard is valid") console.print(f"\n[bold]Title:[/bold] {data.project.title}") console.print(f"[bold]Shots:[/bold] {len(data.shots)}") console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}") resolution = f"{data.project.resolution.width}x{data.project.resolution.height}" if data.project.resolution else "Not specified" console.print(f"[bold]Resolution:[/bold] {resolution}") - + if verbose: table = Table(title="Shots") table.add_column("ID", style="cyan") table.add_column("Prompt", style="green") table.add_column("Duration", style="yellow") - + for shot in data.shots: table.add_row(shot.id, shot.prompt[:50], f"{shot.duration_s}s") - + console.print(table) - + except Exception as e: console.print(f"[red]Validation failed: {e}[/red]") raise typer.Exit(1) @@ -278,24 +387,24 @@ def resume( output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"), ): """Resume a failed or interrupted project.""" - + project_dir = output / project.replace(" ", "_") checkpoint_db = project_dir / "checkpoints.db" - + if not checkpoint_db.exists(): console.print(f"[red]Error: No checkpoint found for project '{project}'[/red]") 
raise typer.Exit(1) - + checkpoint_mgr = CheckpointManager(str(checkpoint_db)) - project_checkpoint = checkpoint_mgr.get_project_checkpoint(project) - + project_checkpoint = checkpoint_mgr.get_project(project) + if not project_checkpoint: - console.print(f"[red]Error: No project checkpoint found[/red]") + console.print("[red]Error: No project checkpoint found[/red]") raise typer.Exit(1) - + console.print(f"[bold blue]Resuming project:[/bold blue] {project}") console.print(f"Storyboard: {project_checkpoint.storyboard_path}") - + # Re-run generation with resume flag generate( storyboard=Path(project_checkpoint.storyboard_path), @@ -307,16 +416,16 @@ def resume( @app.command() def list_backends(): """List available generation backends.""" - + table = Table(title="Available Backends") table.add_column("Name", style="cyan") table.add_column("Type", style="green") table.add_column("Description", style="white") - - table.add_row("wan", "T2V", "WAN 2.x text-to-video") - table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)") + + table.add_row("wan_t2v_14b", "T2V", "WAN 2.x text-to-video (14B)") + table.add_row("wan_t2v_1_3b", "T2V", "WAN 1.3B (faster, lower quality)") table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)") - + console.print(table) diff --git a/src/generation/backends/wan.py b/src/generation/backends/wan.py index ac1cd0f..35c5ad5 100644 --- a/src/generation/backends/wan.py +++ b/src/generation/backends/wan.py @@ -79,7 +79,8 @@ class WanBackend(BaseVideoBackend): return import torch - from diffusers import AutoencoderKLWan, WanPipeline + from diffusers import WanPipeline + from diffusers.models import AutoencoderKLWan print(f"Loading WAN model: {self.model_id}") @@ -92,13 +93,12 @@ class WanBackend(BaseVideoBackend): # Load VAE print("Loading VAE...") - vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id - self.vae = AutoencoderKLWan.from_pretrained( - vae_id, - subfolder="vae" if "-T2V-" in 
self.model_id else None, + self.model_id, + subfolder="vae", torch_dtype=self.dtype, - cache_dir=str(cache_dir) + cache_dir=str(cache_dir), + low_cpu_mem_usage=True ) # Load pipeline @@ -107,24 +107,30 @@ class WanBackend(BaseVideoBackend): self.model_id, vae=self.vae, torch_dtype=self.dtype, - cache_dir=str(cache_dir) + cache_dir=str(cache_dir), + low_cpu_mem_usage=True ) # Move to GPU - if torch.cuda.is_available(): - print("Moving to CUDA...") - self.pipeline = self.pipeline.to("cuda") - - # Enable memory optimizations - if self.enable_vae_slicing: - print("Enabling VAE slicing...") - self.pipeline.vae.enable_slicing() - - if self.enable_vae_tiling: - print("Enabling VAE tiling...") - self.pipeline.vae.enable_tiling() - else: - print("WARNING: CUDA not available, using CPU (will be very slow)") + if not torch.cuda.is_available(): + raise RuntimeError( + "CUDA is not available. This backend requires an NVIDIA GPU. " + "Install a CUDA-enabled PyTorch build (cu121/cu124) and ensure the " + "correct conda environment is active." + ) + + total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3) + print(f"Moving to CUDA... 
(GPU VRAM: {total_vram:.1f} GB)") + self.pipeline = self.pipeline.to("cuda") + + # Enable memory optimizations + if self.enable_vae_slicing: + print("Enabling VAE slicing...") + self.pipeline.vae.enable_slicing() + + if self.enable_vae_tiling: + print("Enabling VAE tiling...") + self.pipeline.vae.enable_tiling() self._is_loaded = True print("WAN model loaded successfully") diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 2f6b710..031609d 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -8,7 +8,14 @@ from pathlib import Path from unittest.mock import patch, MagicMock from typer.testing import CliRunner -from src.cli.main import app +from src.cli.main import ( + app, + align_resolution, + align_num_frames, + cap_resolution_for_backend, + cap_num_frames_for_backend, + format_duration, +) runner = CliRunner() @@ -61,8 +68,8 @@ class TestValidateCommand: result = runner.invoke(app, ["validate", str(storyboard_file)]) assert result.exit_code == 0 - assert "valid" in result.output - assert "Test Storyboard" in result.output + assert "valid" in result.output + assert "Test Storyboard" in result.output def test_validate_verbose(self): """Test validation with verbose flag.""" @@ -82,7 +89,7 @@ class TestListBackendsCommand: """Test listing available backends.""" result = runner.invoke(app, ["list-backends"]) assert result.exit_code == 0 - assert "wan" in result.output + assert "wan_t2v_14b" in result.output assert "Available Backends" in result.output @@ -140,7 +147,7 @@ class TestResumeCommand: with patch('src.cli.main.CheckpointManager') as mock_mgr: mock_checkpoint = MagicMock() mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json") - mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint + mock_mgr.return_value.get_project.return_value = mock_checkpoint # Create storyboard file storyboard_file = Path(tmpdir) / "storyboard.json" @@ -185,3 +192,52 @@ class TestCLIHelp: result = runner.invoke(app, 
["resume", "--help"]) assert result.exit_code == 0 assert "Resume a failed or interrupted project" in result.output + + +class TestCLIAlignmentHelpers: + """Test CLI alignment helpers.""" + + def test_align_resolution_no_change(self): + """Resolution already aligned to 16.""" + width, height = align_resolution(1920, 1088) + assert width == 1920 + assert height == 1088 + + def test_align_resolution_adjusts(self): + """Resolution is rounded up to multiple of 16.""" + width, height = align_resolution(1920, 1080) + assert width == 1920 + assert height == 1088 + + def test_align_num_frames_no_change(self): + """Frame count already valid.""" + assert align_num_frames(97) == 97 # 97 - 1 is divisible by 4 + + def test_align_num_frames_adjusts(self): + """Frame count is rounded up to valid value.""" + assert align_num_frames(96) == 97 + + def test_cap_resolution_for_small_backend(self): + """Smaller backend caps resolution.""" + width, height = cap_resolution_for_backend("wan_t2v_1_3b", 1920, 1080) + assert width == 1280 + assert height == 720 + + def test_cap_resolution_for_other_backends(self): + """No cap for other backends.""" + width, height = cap_resolution_for_backend("wan_t2v_14b", 1920, 1080) + assert width == 1920 + assert height == 1080 + + def test_cap_num_frames_for_small_backend(self): + """Smaller backend caps frame count.""" + assert cap_num_frames_for_backend("wan_t2v_1_3b", 121) == 97 + + def test_cap_num_frames_for_other_backends(self): + """No cap for other backends.""" + assert cap_num_frames_for_backend("wan_t2v_14b", 121) == 121 + + def test_format_duration(self): + """Format duration for display.""" + assert format_duration(5) == "0m05s" + assert format_duration(65) == "1m05s"