Adjust the resolution based on available VRAM; add elapsed-time reporting.

2026-02-04 03:06:31 -05:00
parent 1252518832
commit c33c1d2f36
7 changed files with 347 additions and 174 deletions

View File

@@ -11,13 +11,13 @@ The owner (user) is not an ML expert. The system must:
---
## 1) High-Level Goal
Build a local pipeline that converts **text-only storyboards** into **15–30 second videos** by:
Build a local pipeline that converts text-only storyboards into 15-30 second videos by:
1) converting storyboard -> shot plan
2) generating shot clips (T2V or I2V when possible)
3) assembling clips into a final MP4
4) upscaling to 2K/4K if desired
This is a **shot-based** system, not “one prompt makes a whole movie”.
This is a shot-based system, not "one prompt makes a whole movie".
---
@@ -38,9 +38,9 @@ Design must be stable under 12GB VRAM using:
---
## 3) Output Targets (Realistic)
- Native generation: 720p–1080p (preferred)
- Native generation: 720p-1080p (preferred)
- Final delivery: 1080p required; 2K/4K via upscaling
- Duration: 15–30s per video (may be segmented)
- Duration: 15-30s per video (may be segmented)
- FPS: 24 default
- Output: MP4 (H.264/H.265)
@@ -50,13 +50,15 @@ Design must be stable under 12GB VRAM using:
User has CUDA Toolkit 13.1 installed. Current PyTorch builds generally ship with and target CUDA 12.x runtimes.
We must NOT assume PyTorch will build/run against local CUDA 13.1 toolkit.
**Plan:**
- Use **PyTorch prebuilt binaries that bundle CUDA runtime** (e.g., cu121 / cu124).
Plan:
- Use PyTorch prebuilt binaries that bundle CUDA runtime (cu121/cu124/cu128).
- Rely on NVIDIA driver compatibility rather than local CUDA toolkit version.
- Avoid compiling custom CUDA extensions unless necessary.
Implementation notes:
- Prefer installing PyTorch via conda or pip using official CUDA 12.x builds.
- For RTX 5070 (sm_120), use CUDA 12.8 wheels via pip:
`pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128`
- Prefer conda for Python, ffmpeg, and general deps; use pip for torch if sm_120 support is required.
- If xFormers causes build issues, use PyTorch SDPA and disable xFormers.
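The decoupling above (pick wheels by GPU architecture, never by the local CUDA toolkit) can be sketched as a tiny selector. `torch_wheel_index` and its compute-capability thresholds are illustrative assumptions drawn from these notes, not a PyTorch API:

```python
def torch_wheel_index(major: int, minor: int) -> str:
    """Pick a PyTorch wheel index URL from GPU compute capability.

    Illustrative mapping only: sm_120 (RTX 50xx, compute capability 12.0)
    needs the cu128 wheels, while earlier architectures run fine on the
    bundled cu124 runtime. The local CUDA 13.1 toolkit is never consulted.
    """
    if (major, minor) >= (12, 0):  # sm_120 and newer
        return "https://download.pytorch.org/whl/cu128"
    return "https://download.pytorch.org/whl/cu124"


# e.g. an RTX 5070 reports compute capability (12, 0)
print(torch_wheel_index(12, 0))  # .../whl/cu128
print(torch_wheel_index(8, 9))   # .../whl/cu124 (RTX 40xx class)
```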
---
@@ -64,12 +66,13 @@ Implementation notes:
## 5) Approved Stack (Do Not Deviate)
### Core
- Python 3.10 or 3.11 (conda env)
- PyTorch (CUDA 12.x build: cu121 or cu124)
- PyTorch (CUDA 12.x build, cu121/cu124/cu128)
- diffusers + transformers + accelerate + safetensors
- ffmpeg for assembly
- opencv-python for frame IO (if needed)
- pydantic for config/schema validation
- rich / loguru for logs
- ftfy for text normalization (required by WAN)
### Testing
- pytest
@@ -124,7 +127,7 @@ We will later build a utility script:
- For each shot: generate clip
- Support:
- seed control
- chunking (e.g., generate 4–6 seconds then continue)
- chunking (e.g., generate 4-6 seconds then continue)
- optional init frame handoff between shots
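As a back-of-envelope check on the chunk sizes above, assuming the WAN-style constraint (used later in the CLI) that `(num_frames - 1)` must be divisible by 4:

```python
def frames_for_chunk(seconds: float, fps: int = 24) -> int:
    """Frame count for one chunk, rounded up so (n - 1) % 4 == 0."""
    n = round(seconds * fps)
    if (n - 1) % 4 == 0:
        return n
    # round (n - 1) up to the next multiple of 4, then add the 1 back
    return ((n - 1 + 3) // 4) * 4 + 1


# 4 s at 24 fps -> 96 raw frames, aligned up to 97
print(frames_for_chunk(4))  # 97
# 6 s at 24 fps -> 144 raw frames, aligned up to 145
print(frames_for_chunk(6))  # 145
```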
### D) Assembly
@@ -149,7 +152,7 @@ For each shot and final render, save:
- timing + VRAM notes if possible
Every run produces a folder:
- outputs/<project>/<timestamp>/
- outputs/<project>/
- shots/
- assembled/
- metadata/
@@ -166,7 +169,7 @@ Every run produces a folder:
- assembly command lines are correct
- metadata is generated correctly
Do not require visual quality assertions. Test structure and determinism.
Do not require visual quality assertions. Test structure and determinism.
---
@@ -201,7 +204,7 @@ Required:
---
## 13) Definition of Done
A feature is “done” only if:
A feature is "done" only if:
- implemented
- tests added/updated
- docs updated

View File

@@ -6,7 +6,7 @@ backends:
wan_t2v_14b:
name: "Wan 2.1 T2V 14B"
class: "generation.backends.wan.WanBackend"
model_id: "Wan-AI/Wan2.1-T2V-14B"
model_id: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vram_gb: 12
dtype: "fp16"
enable_vae_slicing: true
@@ -20,7 +20,7 @@ backends:
wan_t2v_1_3b:
name: "Wan 2.1 T2V 1.3B"
class: "generation.backends.wan.WanBackend"
model_id: "Wan-AI/Wan2.1-T2V-1.3B"
model_id: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vram_gb: 8
dtype: "fp16"
enable_vae_slicing: true

View File

@@ -1,8 +1,6 @@
# environment.yml
name: storyboard-video
channels:
- pytorch
- nvidia
- conda-forge
- defaults
dependencies:
@@ -12,13 +10,13 @@ dependencies:
# FFmpeg in env is easiest on Windows if you rely on conda-forge
- ffmpeg=7.1
# PyTorch with CUDA runtime bundled (NOT using local CUDA 13.1 toolkit)
# Using compatible versions for PyTorch 2.5.1 with CUDA 12.4
- pytorch=2.5.1
- pytorch-cuda=12.4
- torchvision=0.20.1
- torchaudio=2.5.1
# pip deps (mirrors requirements.txt)
- pip:
# RTX 50xx (sm_120) requires CUDA 12.8 wheels
- --index-url https://download.pytorch.org/whl/cu128
- torch
- torchvision
- torchaudio
- -r requirements.txt
- diffusers==0.36.0
- ftfy==6.3.1

View File

@@ -1,10 +1,11 @@
# requirements.txt
# Install PyTorch separately (see environment.yml and docs).
diffusers==0.32.2
diffusers==0.36.0
transformers==4.48.3
accelerate==1.3.0
safetensors==0.5.2
ftfy==6.3.1
pydantic==2.10.6
pyyaml==6.0.2

View File

@@ -2,34 +2,77 @@
CLI entry point for storyboard video generation.
"""
import sys
import os
from pathlib import Path
from typing import Optional
from typing import Optional, Dict
import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
import time
from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.core.checkpoint import CheckpointManager, ShotStatus, ProjectStatus
from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig, UpscalerType, UpscaleFactor
if "KMP_DUPLICATE_LIB_OK" not in os.environ:
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
if "PYTORCH_ALLOC_CONF" not in os.environ:
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
app = typer.Typer(help="Storyboard to Video Generation Pipeline")
console = Console()
def align_resolution(width: int, height: int, multiple: int = 16) -> tuple[int, int]:
"""Align resolution to model-required multiples."""
aligned_width = ((width + multiple - 1) // multiple) * multiple
aligned_height = ((height + multiple - 1) // multiple) * multiple
return aligned_width, aligned_height
def align_num_frames(num_frames: int) -> int:
"""Align frame count so (num_frames - 1) is divisible by 4."""
if (num_frames - 1) % 4 == 0:
return num_frames
return ((num_frames - 1 + 3) // 4) * 4 + 1
def cap_resolution_for_backend(backend_id: str, width: int, height: int) -> tuple[int, int]:
"""Apply conservative resolution caps for smaller backends."""
if backend_id == "wan_t2v_1_3b":
max_width = 1280
max_height = 720
if width > max_width or height > max_height:
return max_width, max_height
return width, height
def cap_num_frames_for_backend(backend_id: str, num_frames: int) -> int:
"""Apply conservative frame caps for smaller backends."""
if backend_id == "wan_t2v_1_3b":
max_frames = 97 # (97 - 1) divisible by 4, ~4s at 24fps
if num_frames > max_frames:
return max_frames
return num_frames
def format_duration(seconds: float) -> str:
"""Format seconds into XmYYs."""
total = int(seconds)
minutes = total // 60
secs = total % 60
return f"{minutes}m{secs:02d}s"
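Taken together, these helpers form a small resolution policy: first cap to the backend's safe envelope, then round up to the model multiple. A standalone trace of that order of operations (with the 1.3B cap values copied from the functions above):

```python
def plan_resolution(backend_id: str, width: int, height: int,
                    multiple: int = 16) -> tuple[int, int]:
    """Cap to backend-safe bounds, then round up to the model multiple."""
    if backend_id == "wan_t2v_1_3b" and (width > 1280 or height > 720):
        width, height = 1280, 720  # conservative VRAM cap for the 1.3B model
    aligned = lambda v: ((v + multiple - 1) // multiple) * multiple
    return aligned(width), aligned(height)


print(plan_resolution("wan_t2v_1_3b", 1920, 1080))  # (1280, 720)
print(plan_resolution("wan_t2v_14b", 1920, 1080))   # (1920, 1088)
```

Capping before aligning matters: aligning 1920x1080 first would produce 1920x1088, which the 1.3B cap would then collapse to exactly 1280x720 anyway, but in general cap-then-align guarantees the result never exceeds the cap by more than one alignment step.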
@app.command()
def generate(
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
backend: Optional[str] = typer.Option(None, "--backend", "-b", help="Generation backend (wan_t2v_14b, wan_t2v_1_3b, svd)"),
resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
@@ -38,9 +81,11 @@ def generate(
):
"""Generate video from storyboard."""
start_time = time.time()
# Validate storyboard exists
if not storyboard.exists():
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1)
# Load configuration
@@ -51,20 +96,31 @@ def generate(
app_config = ConfigLoader.load()
except Exception as e:
console.print(f"[red]Error loading config: {e}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1)
# Normalize backend selection
backend_aliases: Dict[str, str] = {
"wan": "wan_t2v_14b",
"wan-1.3b": "wan_t2v_1_3b",
"wan_1_3b": "wan_t2v_1_3b",
}
backend_id = backend_aliases.get(backend, backend) if backend else app_config.active_backend
# Validate storyboard
console.print("[bold blue]Validating storyboard...[/bold blue]")
validator = StoryboardValidator()
try:
storyboard_data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
console.print(f"[green]OK[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
except Exception as e:
console.print(f"[red]Error validating storyboard: {e}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1)
if dry_run:
console.print("[yellow]Dry run complete. Exiting.[/yellow]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
return
# Setup output directories
@@ -79,32 +135,47 @@ def generate(
# Initialize checkpoint manager
checkpoint_db = project_dir / "checkpoints.db"
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
project_id = storyboard_data.project.title
if checkpoint_mgr.get_project(project_id) is None:
checkpoint_mgr.create_project(
project_id=project_id,
storyboard_path=str(storyboard),
output_dir=str(project_dir),
backend_name=backend_id
)
# Initialize components
prompt_compiler = PromptCompiler(storyboard_data)
# Get backend config for shot planner
try:
backend_config = app_config.get_backend(backend)
max_chunk_duration = backend_config.max_chunk_seconds or 6.0
backend_config = app_config.get_backend(backend_id)
max_chunk_seconds = backend_config.max_chunk_seconds or 6
except Exception:
max_chunk_duration = 6.0
max_chunk_seconds = 6
shot_planner = ShotPlanner(
fps=storyboard_data.project.fps or 24,
max_chunk_duration=max_chunk_duration
max_chunk_seconds=max_chunk_seconds
)
# Initialize generation backend
if not skip_generation:
console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
console.print(f"[bold blue]Initializing {backend_id} backend...[/bold blue]")
try:
backend_config = app_config.get_backend(backend)
video_backend = BackendFactory.create_backend(backend, backend_config)
console.print(f"[green]✓[/green] Backend initialized: {backend}")
backend_config = app_config.get_backend(backend_id)
backend_factory_name = "wan" if backend_id.startswith("wan") else backend_id
backend_config_dict = backend_config.__dict__.copy()
backend_config_dict["model_cache_dir"] = app_config.model_cache_dir
video_backend = BackendFactory.create(backend_factory_name, backend_config_dict)
video_backend.load()
console.print(f"[green]OK[/green] Backend initialized: {backend_id}")
except Exception as e:
console.print(f"[red]Error initializing backend: {e}[/red]")
raise typer.Exit(1)
import traceback
console.print(f"[red]Error initializing backend: {repr(e)}[/red]")
console.print(traceback.format_exc())
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1)
# Generate shots
generated_shots = []
@@ -118,69 +189,108 @@ def generate(
console=console
) as progress:
for i, shot in enumerate(storyboard_data.shots):
for shot in storyboard_data.shots:
task_id = progress.add_task(f"Shot {shot.id}", total=None)
# Ensure shot checkpoint exists
if checkpoint_mgr.get_shot(project_id, shot.id) is None:
checkpoint_mgr.create_shot(project_id, shot.id, ShotStatus.PENDING)
# Check checkpoint
if resume:
checkpoint = checkpoint_mgr.get_shot_checkpoint(
storyboard_data.project.title, shot.id
)
if checkpoint and checkpoint.status == "completed":
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
generated_shots.append(Path(checkpoint.video_path))
checkpoint = checkpoint_mgr.get_shot(project_id, shot.id)
if checkpoint and checkpoint.status == ShotStatus.COMPLETED and checkpoint.output_path:
progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id} (cached)")
generated_shots.append(Path(checkpoint.output_path))
continue
# Compile prompt
prompt = prompt_compiler.compile_shot_prompt(shot)
negative_prompt = prompt_compiler.compile_negative_prompt(shot)
compiled = prompt_compiler.compile_shot(shot)
prompt = compiled.positive
negative_prompt = compiled.negative
# Plan shot
shot_plan = shot_planner.plan_shot(shot)
num_frames = align_num_frames(shot_plan.total_frames)
if num_frames != shot_plan.total_frames:
console.print(
f"[yellow]Adjusting frames to {num_frames} "
f"(from {shot_plan.total_frames}) to meet model requirements.[/yellow]"
)
capped_frames = cap_num_frames_for_backend(backend_id, num_frames)
if capped_frames != num_frames:
console.print(
f"[yellow]Capping frames to {capped_frames} "
f"(from {num_frames}) for {backend_id} VRAM safety.[/yellow]"
)
num_frames = capped_frames
# Generate
try:
from src.generation.base import GenerationSpec
output_path = shots_dir / f"{shot.id}.mp4"
width = storyboard_data.project.resolution.width
height = storyboard_data.project.resolution.height
capped_width, capped_height = cap_resolution_for_backend(backend_id, width, height)
if capped_width != width or capped_height != height:
console.print(
f"[yellow]Capping resolution to {capped_width}x{capped_height} "
f"(from {width}x{height}) for {backend_id} VRAM safety.[/yellow]"
)
aligned_width, aligned_height = align_resolution(capped_width, capped_height)
if aligned_width != width or aligned_height != height:
console.print(
f"[yellow]Adjusting resolution to {aligned_width}x{aligned_height} "
f"(from {width}x{height}) to meet model requirements.[/yellow]"
)
spec = GenerationSpec(
prompt=prompt,
negative_prompt=negative_prompt,
width=shot_plan.width,
height=shot_plan.height,
num_frames=shot_plan.total_frames,
width=aligned_width,
height=aligned_height,
num_frames=num_frames,
fps=shot_plan.fps,
seed=shot.generation.seed
seed=shot.generation.seed,
steps=shot.generation.steps,
cfg_scale=shot.generation.cfg_scale,
output_path=output_path
)
result = video_backend.generate(spec, output_dir=shots_dir)
checkpoint_mgr.update_shot(project_id, shot.id, status=ShotStatus.IN_PROGRESS)
result = video_backend.generate(spec)
if result.success:
generated_shots.append(result.video_path)
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="completed",
video_path=str(result.video_path),
generated_shots.append(result.output_path)
checkpoint_mgr.update_shot(
project_id,
shot.id,
status=ShotStatus.COMPLETED,
output_path=str(result.output_path),
metadata=result.metadata
)
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}")
progress.update(task_id, description=f"[green]OK[/green] Shot {shot.id}")
else:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}")
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="failed",
progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {result.error_message}")
checkpoint_mgr.update_shot(
project_id,
shot.id,
status=ShotStatus.FAILED,
error_message=result.error_message
)
except Exception as e:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}")
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="failed",
progress.update(task_id, description=f"[red]FAILED[/red] Shot {shot.id}: {e}")
checkpoint_mgr.update_shot(
project_id,
shot.id,
status=ShotStatus.FAILED,
error_message=str(e)
)
try:
video_backend.unload()
except Exception:
pass
# Assembly
if not skip_assembly and generated_shots:
@@ -197,40 +307,39 @@ def generate(
result = assembler.assemble(generated_shots, final_output, assembly_config)
if result.success:
console.print(f"[green]✓[/green] Video assembled: {final_output}")
console.print(f"[green]OK[/green] Video assembled: {final_output}")
# Upscale if requested
if upscale:
console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")
upscale_config = UpscaleConfig(
factor=upscale,
upscaler_type="ffmpeg_sr"
)
upscale_mgr = UpscaleManager(upscale_config)
upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
if upscale_result.success:
console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}")
final_output = upscaled_output
if upscale not in (2, 4):
console.print("[yellow]Warning: Upscale factor must be 2 or 4. Skipping upscaling.[/yellow]")
else:
console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
upscale_config = UpscaleConfig(
factor=UpscaleFactor.X2 if upscale == 2 else UpscaleFactor.X4,
upscaler_type=UpscalerType.FFMPEG_SR
)
upscale_mgr = UpscaleManager(upscale_config)
# Save project metadata
checkpoint_mgr.save_project_checkpoint(
project_name=storyboard_data.project.title,
storyboard_path=str(storyboard),
output_path=str(final_output),
status="completed",
num_shots=len(generated_shots)
)
upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
if upscale_result.success:
console.print(f"[green]OK[/green] Upscaled video: {upscaled_output}")
final_output = upscaled_output
else:
console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
checkpoint_mgr.update_project_status(project_id, ProjectStatus.COMPLETED)
console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
else:
console.print(f"[red]Error assembling video: {result.error_message}[/red]")
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
raise typer.Exit(1)
else:
console.print(f"[yellow]Elapsed: {format_duration(time.time() - start_time)}[/yellow]")
@app.command()
@@ -249,7 +358,7 @@ def validate(
try:
data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard is valid")
console.print("[green]OK[/green] Storyboard is valid")
console.print(f"\n[bold]Title:[/bold] {data.project.title}")
console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
@@ -287,10 +396,10 @@ def resume(
raise typer.Exit(1)
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)
project_checkpoint = checkpoint_mgr.get_project(project)
if not project_checkpoint:
console.print(f"[red]Error: No project checkpoint found[/red]")
console.print("[red]Error: No project checkpoint found[/red]")
raise typer.Exit(1)
console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
@@ -313,8 +422,8 @@ def list_backends():
table.add_column("Type", style="green")
table.add_column("Description", style="white")
table.add_row("wan", "T2V", "WAN 2.x text-to-video")
table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)")
table.add_row("wan_t2v_14b", "T2V", "WAN 2.x text-to-video (14B)")
table.add_row("wan_t2v_1_3b", "T2V", "WAN 1.3B (faster, lower quality)")
table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")
console.print(table)

View File

@@ -79,7 +79,8 @@ class WanBackend(BaseVideoBackend):
return
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers import WanPipeline
from diffusers.models import AutoencoderKLWan
print(f"Loading WAN model: {self.model_id}")
@@ -92,13 +93,12 @@ class WanBackend(BaseVideoBackend):
# Load VAE
print("Loading VAE...")
vae_id = self.model_id.replace("-T2V-", "-VAE-") if "-T2V-" in self.model_id else self.model_id
self.vae = AutoencoderKLWan.from_pretrained(
vae_id,
subfolder="vae" if "-T2V-" in self.model_id else None,
self.model_id,
subfolder="vae",
torch_dtype=self.dtype,
cache_dir=str(cache_dir)
cache_dir=str(cache_dir),
low_cpu_mem_usage=True
)
# Load pipeline
@@ -107,24 +107,30 @@ class WanBackend(BaseVideoBackend):
self.model_id,
vae=self.vae,
torch_dtype=self.dtype,
cache_dir=str(cache_dir)
cache_dir=str(cache_dir),
low_cpu_mem_usage=True
)
# Move to GPU
if torch.cuda.is_available():
print("Moving to CUDA...")
self.pipeline = self.pipeline.to("cuda")
if not torch.cuda.is_available():
raise RuntimeError(
"CUDA is not available. This backend requires an NVIDIA GPU. "
"Install a CUDA-enabled PyTorch build (cu121/cu124) and ensure the "
"correct conda environment is active."
)
# Enable memory optimizations
if self.enable_vae_slicing:
print("Enabling VAE slicing...")
self.pipeline.vae.enable_slicing()
total_vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
print(f"Moving to CUDA... (GPU VRAM: {total_vram:.1f} GB)")
self.pipeline = self.pipeline.to("cuda")
if self.enable_vae_tiling:
print("Enabling VAE tiling...")
self.pipeline.vae.enable_tiling()
else:
print("WARNING: CUDA not available, using CPU (will be very slow)")
# Enable memory optimizations
if self.enable_vae_slicing:
print("Enabling VAE slicing...")
self.pipeline.vae.enable_slicing()
if self.enable_vae_tiling:
print("Enabling VAE tiling...")
self.pipeline.vae.enable_tiling()
self._is_loaded = True
print("WAN model loaded successfully")

View File

@@ -8,7 +8,14 @@ from pathlib import Path
from unittest.mock import patch, MagicMock
from typer.testing import CliRunner
from src.cli.main import app
from src.cli.main import (
app,
align_resolution,
align_num_frames,
cap_resolution_for_backend,
cap_num_frames_for_backend,
format_duration,
)
runner = CliRunner()
@@ -61,8 +68,8 @@ class TestValidateCommand:
result = runner.invoke(app, ["validate", str(storyboard_file)])
assert result.exit_code == 0
assert "valid" in result.output
assert "Test Storyboard" in result.output
assert "valid" in result.output
assert "Test Storyboard" in result.output
def test_validate_verbose(self):
"""Test validation with verbose flag."""
@@ -82,7 +89,7 @@ class TestListBackendsCommand:
"""Test listing available backends."""
result = runner.invoke(app, ["list-backends"])
assert result.exit_code == 0
assert "wan" in result.output
assert "wan_t2v_14b" in result.output
assert "Available Backends" in result.output
@@ -140,7 +147,7 @@ class TestResumeCommand:
with patch('src.cli.main.CheckpointManager') as mock_mgr:
mock_checkpoint = MagicMock()
mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json")
mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint
mock_mgr.return_value.get_project.return_value = mock_checkpoint
# Create storyboard file
storyboard_file = Path(tmpdir) / "storyboard.json"
@@ -185,3 +192,52 @@ class TestCLIHelp:
result = runner.invoke(app, ["resume", "--help"])
assert result.exit_code == 0
assert "Resume a failed or interrupted project" in result.output
class TestCLIAlignmentHelpers:
"""Test CLI alignment helpers."""
def test_align_resolution_no_change(self):
"""Resolution already aligned to 16."""
width, height = align_resolution(1920, 1088)
assert width == 1920
assert height == 1088
def test_align_resolution_adjusts(self):
"""Resolution is rounded up to multiple of 16."""
width, height = align_resolution(1920, 1080)
assert width == 1920
assert height == 1088
def test_align_num_frames_no_change(self):
"""Frame count already valid."""
assert align_num_frames(97) == 97 # 97 - 1 is divisible by 4
def test_align_num_frames_adjusts(self):
"""Frame count is rounded up to valid value."""
assert align_num_frames(96) == 97
def test_cap_resolution_for_small_backend(self):
"""Smaller backend caps resolution."""
width, height = cap_resolution_for_backend("wan_t2v_1_3b", 1920, 1080)
assert width == 1280
assert height == 720
def test_cap_resolution_for_other_backends(self):
"""No cap for other backends."""
width, height = cap_resolution_for_backend("wan_t2v_14b", 1920, 1080)
assert width == 1920
assert height == 1080
def test_cap_num_frames_for_small_backend(self):
"""Smaller backend caps frame count."""
assert cap_num_frames_for_backend("wan_t2v_1_3b", 121) == 97
def test_cap_num_frames_for_other_backends(self):
"""No cap for other backends."""
assert cap_num_frames_for_backend("wan_t2v_14b", 121) == 121
def test_format_duration(self):
"""Format duration for display."""
assert format_duration(5) == "0m05s"
assert format_duration(65) == "1m05s"