CLI implementation

This commit is contained in:
2026-02-04 01:10:58 -05:00
parent 33687865fd
commit 77cc907f6e
8 changed files with 1853 additions and 1 deletions

View File

@@ -1,3 +1,7 @@
""" """
CLI entry points. CLI entry points.
""" """
from src.cli.main import app, main
__all__ = ["app", "main"]

321
src/cli/main.py Normal file
View File

@@ -0,0 +1,321 @@
"""
CLI entry point for storyboard video generation.
"""
import sys
from pathlib import Path
from typing import Optional
import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig
app = typer.Typer(help="Storyboard to Video Generation Pipeline")
console = Console()
@app.command()
def generate(
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
upscale: Optional[int] = typer.Option(None, "--upscale", help="Upscale factor (2 or 4)"),
dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"),
):
"""Generate video from storyboard."""
# Validate storyboard exists
if not storyboard.exists():
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
raise typer.Exit(1)
# Load configuration
try:
if config:
app_config = ConfigLoader.load(config)
else:
app_config = ConfigLoader.load()
except Exception as e:
console.print(f"[red]Error loading config: {e}[/red]")
raise typer.Exit(1)
# Validate storyboard
console.print("[bold blue]Validating storyboard...[/bold blue]")
validator = StoryboardValidator()
try:
storyboard_data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
except Exception as e:
console.print(f"[red]Error validating storyboard: {e}[/red]")
raise typer.Exit(1)
if dry_run:
console.print("[yellow]Dry run complete. Exiting.[/yellow]")
return
# Setup output directories
project_dir = output / storyboard_data.project.title.replace(" ", "_")
shots_dir = project_dir / "shots"
assembled_dir = project_dir / "assembled"
metadata_dir = project_dir / "metadata"
for d in [shots_dir, assembled_dir, metadata_dir]:
d.mkdir(parents=True, exist_ok=True)
# Initialize checkpoint manager
checkpoint_db = project_dir / "checkpoints.db"
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
# Initialize components
prompt_compiler = PromptCompiler(storyboard_data.project.global_style)
shot_planner = ShotPlanner(
fps=storyboard_data.project.fps or 24,
max_chunk_duration=app_config.backend.max_chunk_seconds or 6.0
)
# Initialize generation backend
if not skip_generation:
console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
try:
backend_config = app_config.get_backend_config(backend)
video_backend = BackendFactory.create_backend(backend, backend_config)
console.print(f"[green]✓[/green] Backend initialized: {backend}")
except Exception as e:
console.print(f"[red]Error initializing backend: {e}[/red]")
raise typer.Exit(1)
# Generate shots
generated_shots = []
if not skip_generation:
console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]")
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console
) as progress:
for i, shot in enumerate(storyboard_data.shots):
task_id = progress.add_task(f"Shot {shot.id}", total=None)
# Check checkpoint
if resume:
checkpoint = checkpoint_mgr.get_shot_checkpoint(
storyboard_data.project.title, shot.id
)
if checkpoint and checkpoint.status == "completed":
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
generated_shots.append(Path(checkpoint.video_path))
continue
# Compile prompt
prompt = prompt_compiler.compile_shot_prompt(shot)
negative_prompt = prompt_compiler.compile_negative_prompt(shot)
# Plan shot
shot_plan = shot_planner.plan_shot(shot)
# Generate
try:
from src.generation.base import GenerationSpec
spec = GenerationSpec(
prompt=prompt,
negative_prompt=negative_prompt,
width=shot_plan.width,
height=shot_plan.height,
num_frames=shot_plan.total_frames,
fps=shot_plan.fps,
seed=shot.generation.seed
)
result = video_backend.generate(spec, output_dir=shots_dir)
if result.success:
generated_shots.append(result.video_path)
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="completed",
video_path=str(result.video_path),
metadata=result.metadata
)
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}")
else:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}")
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="failed",
error_message=result.error_message
)
except Exception as e:
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}")
checkpoint_mgr.save_shot_checkpoint(
project_name=storyboard_data.project.title,
shot_id=shot.id,
status="failed",
error_message=str(e)
)
# Assembly
if not skip_assembly and generated_shots:
console.print("[bold blue]Assembling video...[/bold blue]")
assembler = FFmpegAssembler()
final_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}.mp4"
assembly_config = AssemblyConfig(
fps=storyboard_data.project.fps or 24,
add_shot_labels=False
)
result = assembler.assemble(generated_shots, final_output, assembly_config)
if result.success:
console.print(f"[green]✓[/green] Video assembled: {final_output}")
# Upscale if requested
if upscale:
console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")
upscale_config = UpscaleConfig(
factor=upscale,
upscaler_type="ffmpeg_sr"
)
upscale_mgr = UpscaleManager(upscale_config)
upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
if upscale_result.success:
console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}")
final_output = upscaled_output
else:
console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
# Save project metadata
checkpoint_mgr.save_project_checkpoint(
project_name=storyboard_data.project.title,
storyboard_path=str(storyboard),
output_path=str(final_output),
status="completed",
num_shots=len(generated_shots)
)
console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
else:
console.print(f"[red]Error assembling video: {result.error_message}[/red]")
raise typer.Exit(1)
@app.command()
def validate(
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"),
):
"""Validate a storyboard file."""
if not storyboard.exists():
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
raise typer.Exit(1)
validator = StoryboardValidator()
try:
data = validator.load(storyboard)
console.print(f"[green]✓[/green] Storyboard is valid")
console.print(f"\n[bold]Title:[/bold] {data.project.title}")
console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
resolution = f"{data.project.resolution.width}x{data.project.resolution.height}" if data.project.resolution else "Not specified"
console.print(f"[bold]Resolution:[/bold] {resolution}")
if verbose:
table = Table(title="Shots")
table.add_column("ID", style="cyan")
table.add_column("Prompt", style="green")
table.add_column("Duration", style="yellow")
for shot in data.shots:
table.add_row(shot.id, shot.prompt[:50], f"{shot.duration_s}s")
console.print(table)
except Exception as e:
console.print(f"[red]Validation failed: {e}[/red]")
raise typer.Exit(1)
@app.command()
def resume(
project: str = typer.Argument(..., help="Project name"),
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
):
"""Resume a failed or interrupted project."""
project_dir = output / project.replace(" ", "_")
checkpoint_db = project_dir / "checkpoints.db"
if not checkpoint_db.exists():
console.print(f"[red]Error: No checkpoint found for project '{project}'[/red]")
raise typer.Exit(1)
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)
if not project_checkpoint:
console.print(f"[red]Error: No project checkpoint found[/red]")
raise typer.Exit(1)
console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
console.print(f"Storyboard: {project_checkpoint.storyboard_path}")
# Re-run generation with resume flag
generate(
storyboard=Path(project_checkpoint.storyboard_path),
output=output,
resume=True
)
@app.command()
def list_backends():
"""List available generation backends."""
table = Table(title="Available Backends")
table.add_column("Name", style="cyan")
table.add_column("Type", style="green")
table.add_column("Description", style="white")
table.add_row("wan", "T2V", "WAN 2.x text-to-video")
table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)")
table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")
console.print(table)
def main():
"""Entry point for the CLI."""
app()
if __name__ == "__main__":
main()

View File

@@ -1,3 +1,25 @@
""" """
Upscaling module. Upscaling module.
""" """
from src.upscaling.upscaler import (
BaseUpscaler,
FFmpegSRUpscaler,
RealESRGANUpscaler,
UpscaleManager,
UpscaleConfig,
UpscaleResult,
UpscalerType,
UpscaleFactor
)
__all__ = [
"BaseUpscaler",
"FFmpegSRUpscaler",
"RealESRGANUpscaler",
"UpscaleManager",
"UpscaleConfig",
"UpscaleResult",
"UpscalerType",
"UpscaleFactor"
]

579
src/upscaling/upscaler.py Normal file
View File

@@ -0,0 +1,579 @@
"""
Video upscaling module for enhancing video resolution.
Supports multiple upscaling backends including Real-ESRGAN, RealBasicVSR,
and FFmpeg-based super-resolution filters.
"""
import subprocess
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
class UpscalerType(str, Enum):
"""Supported upscaler types."""
REAL_ESRGAN = "real_esrgan"
REAL_BASIC_VSR = "real_basic_vsr"
FFMPEG_SR = "ffmpeg_sr"
OPENCV_DNN = "opencv_dnn"
class UpscaleFactor(int, Enum):
"""Supported upscale factors."""
X2 = 2
X4 = 4
@dataclass
class UpscaleConfig:
"""Configuration for video upscaling."""
upscaler_type: UpscalerType = UpscalerType.FFMPEG_SR
factor: UpscaleFactor = UpscaleFactor.X2
denoise_strength: float = 0.5
tile_size: int = 0 # 0 = no tiling
tile_pad: int = 10
pre_pad: int = 0
half_precision: bool = True # Use fp16 for speed
device: str = "cuda"
batch_size: int = 1
# Model paths (optional, will use defaults if not specified)
model_path: Optional[Path] = None
# FFmpeg-specific options
ffmpeg_preset: str = "medium"
ffmpeg_crf: int = 18
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for metadata."""
return {
"upscaler_type": self.upscaler_type.value,
"factor": self.factor.value,
"denoise_strength": self.denoise_strength,
"tile_size": self.tile_size,
"tile_pad": self.tile_pad,
"pre_pad": self.pre_pad,
"half_precision": self.half_precision,
"device": self.device,
"batch_size": self.batch_size,
"model_path": str(self.model_path) if self.model_path else None,
"ffmpeg_preset": self.ffmpeg_preset,
"ffmpeg_crf": self.ffmpeg_crf
}
@dataclass
class UpscaleResult:
"""Result of video upscaling."""
success: bool
output_path: Optional[Path] = None
input_resolution: Optional[Tuple[int, int]] = None
output_resolution: Optional[Tuple[int, int]] = None
processing_time_s: Optional[float] = None
error_message: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
class BaseUpscaler(ABC):
"""Abstract base class for video upscalers."""
def __init__(self, config: UpscaleConfig):
"""
Initialize upscaler.
Args:
config: Upscale configuration
"""
self.config = config
@abstractmethod
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
"""
Upscale a video file.
Args:
input_path: Path to input video
output_path: Path for output video
Returns:
UpscaleResult with output details
"""
pass
@abstractmethod
def is_available(self) -> bool:
"""Check if this upscaler is available on the system."""
pass
def _get_video_info(self, video_path: Path) -> Dict[str, Any]:
"""
Get video information using ffprobe.
Args:
video_path: Path to video file
Returns:
Dictionary with video info (width, height, fps, duration)
"""
try:
cmd = [
"ffprobe",
"-v", "error",
"-select_streams", "v:0",
"-show_entries", "stream=width,height,r_frame_rate,duration",
"-of", "json",
str(video_path)
]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=30
)
if result.returncode == 0:
import json
data = json.loads(result.stdout)
stream = data.get("streams", [{}])[0]
# Parse frame rate fraction
fps_str = stream.get("r_frame_rate", "24/1")
if "/" in fps_str:
num, den = fps_str.split("/")
fps = float(num) / float(den)
else:
fps = float(fps_str)
return {
"width": int(stream.get("width", 0)),
"height": int(stream.get("height", 0)),
"fps": fps,
"duration": float(stream.get("duration", 0))
}
except Exception:
pass
return {"width": 0, "height": 0, "fps": 24.0, "duration": 0.0}
class FFmpegSRUpscaler(BaseUpscaler):
"""
FFmpeg-based super-resolution upscaler.
Uses FFmpeg's scale filter with various algorithms.
Good for quick upscaling without ML models.
"""
def is_available(self) -> bool:
"""Check if ffmpeg is available."""
try:
result = subprocess.run(
["ffmpeg", "-version"],
capture_output=True,
timeout=5
)
return result.returncode == 0
except Exception:
return False
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
"""
Upscale using FFmpeg's scale filter.
Uses lanczos scaling for quality.
"""
if not input_path.exists():
return UpscaleResult(
success=False,
error_message=f"Input file not found: {input_path}"
)
# Get input resolution
info = self._get_video_info(input_path)
input_width = info["width"]
input_height = info["height"]
if input_width == 0 or input_height == 0:
return UpscaleResult(
success=False,
error_message="Could not determine input video resolution"
)
# Calculate output resolution
factor = self.config.factor.value
output_width = input_width * factor
output_height = input_height * factor
# Ensure output directory exists
output_path.parent.mkdir(parents=True, exist_ok=True)
import time
start_time = time.time()
try:
# Build FFmpeg command with high-quality scaling
# Using lanczos for better quality than bicubic
cmd = [
"ffmpeg",
"-y", # Overwrite output
"-i", str(input_path),
"-vf", f"scale={output_width}:{output_height}:flags=lanczos",
"-c:v", "libx264",
"-preset", self.config.ffmpeg_preset,
"-crf", str(self.config.ffmpeg_crf),
"-pix_fmt", "yuv420p",
"-movflags", "+faststart",
str(output_path)
]
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=7200 # 2 hour timeout
)
processing_time = time.time() - start_time
if result.returncode != 0:
return UpscaleResult(
success=False,
error_message=f"FFmpeg error: {result.stderr}",
processing_time_s=processing_time
)
return UpscaleResult(
success=True,
output_path=output_path,
input_resolution=(input_width, input_height),
output_resolution=(output_width, output_height),
processing_time_s=processing_time,
metadata={
"config": self.config.to_dict(),
"algorithm": "lanczos",
"fps": info["fps"],
"duration": info["duration"]
}
)
except subprocess.TimeoutExpired:
return UpscaleResult(
success=False,
error_message="FFmpeg upscaling timed out (2 hours)",
processing_time_s=time.time() - start_time
)
except Exception as e:
return UpscaleResult(
success=False,
error_message=f"Upscaling failed: {str(e)}",
processing_time_s=time.time() - start_time
)
class RealESRGANUpscaler(BaseUpscaler):
"""
Real-ESRGAN upscaler for high-quality 2x/4x upscaling.
Requires Real-ESRGAN to be installed.
"""
def is_available(self) -> bool:
"""Check if Real-ESRGAN is available."""
try:
import realesrgan
return True
except ImportError:
return False
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
"""
Upscale using Real-ESRGAN.
Processes video frame by frame.
"""
if not input_path.exists():
return UpscaleResult(
success=False,
error_message=f"Input file not found: {input_path}"
)
try:
import torch
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
except ImportError as e:
return UpscaleResult(
success=False,
error_message=f"Real-ESRGAN not installed: {e}"
)
# Get input resolution
info = self._get_video_info(input_path)
input_width = info["width"]
input_height = info["height"]
if input_width == 0 or input_height == 0:
return UpscaleResult(
success=False,
error_message="Could not determine input video resolution"
)
# Calculate output resolution
factor = self.config.factor.value
output_width = input_width * factor
output_height = input_height * factor
# Ensure output directory exists
output_path.parent.mkdir(parents=True, exist_ok=True)
import time
import cv2
import numpy as np
start_time = time.time()
try:
# Initialize Real-ESRGAN model
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=factor
)
# Determine model name based on factor
model_name = f"RealESRGAN_x{factor}plus.pth"
model_path = self.config.model_path or Path(f"weights/{model_name}")
upsampler = RealESRGANer(
scale=factor,
model_path=str(model_path),
model=model,
tile=self.config.tile_size,
tile_pad=self.config.tile_pad,
pre_pad=self.config.pre_pad,
half=self.config.half_precision,
device=torch.device(self.config.device)
)
# Extract frames, upscale, and reassemble
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
frames_dir = tmp_path / "frames"
frames_dir.mkdir()
upscaled_dir = tmp_path / "upscaled"
upscaled_dir.mkdir()
# Extract frames
self._extract_frames(input_path, frames_dir, info["fps"])
# Upscale frames
frame_files = sorted(frames_dir.glob("*.png"))
for i, frame_file in enumerate(frame_files):
img = cv2.imread(str(frame_file), cv2.IMREAD_UNCHANGED)
if img is not None:
output, _ = upsampler.enhance(img, outscale=factor)
cv2.imwrite(str(upscaled_dir / f"frame_{i:06d}.png"), output)
# Reassemble video
self._assemble_video(
upscaled_dir,
output_path,
info["fps"],
output_width,
output_height
)
processing_time = time.time() - start_time
return UpscaleResult(
success=True,
output_path=output_path,
input_resolution=(input_width, input_height),
output_resolution=(output_width, output_height),
processing_time_s=processing_time,
metadata={
"config": self.config.to_dict(),
"algorithm": "Real-ESRGAN",
"model": model_name,
"frames_processed": len(frame_files)
}
)
except Exception as e:
return UpscaleResult(
success=False,
error_message=f"Real-ESRGAN upscaling failed: {str(e)}",
processing_time_s=time.time() - start_time
)
def _extract_frames(self, video_path: Path, output_dir: Path, fps: float):
"""Extract frames from video."""
cmd = [
"ffmpeg",
"-i", str(video_path),
"-vf", f"fps={fps}",
"-q:v", "2",
str(output_dir / "frame_%06d.png")
]
subprocess.run(cmd, capture_output=True, check=True)
def _assemble_video(
self,
frames_dir: Path,
output_path: Path,
fps: float,
width: int,
height: int
):
"""Assemble frames into video."""
cmd = [
"ffmpeg",
"-y",
"-framerate", str(fps),
"-i", str(frames_dir / "frame_%06d.png"),
"-c:v", "libx264",
"-preset", self.config.ffmpeg_preset,
"-crf", str(self.config.ffmpeg_crf),
"-pix_fmt", "yuv420p",
"-s", f"{width}x{height}",
str(output_path)
]
subprocess.run(cmd, capture_output=True, check=True)
class UpscaleManager:
"""
Manager for video upscaling operations.
Automatically selects best available upscaler.
"""
def __init__(self, config: Optional[UpscaleConfig] = None):
"""
Initialize upscaling manager.
Args:
config: Upscale configuration (uses defaults if None)
"""
self.config = config or UpscaleConfig()
self._upscalers: Dict[UpscalerType, BaseUpscaler] = {}
def get_upscaler(self, upscaler_type: Optional[UpscalerType] = None) -> BaseUpscaler:
"""
Get upscaler instance.
Args:
upscaler_type: Type of upscaler to use (uses config default if None)
Returns:
Upscaler instance
Raises:
RuntimeError: If requested upscaler is not available
"""
upscaler_type = upscaler_type or self.config.upscaler_type
if upscaler_type not in self._upscalers:
if upscaler_type == UpscalerType.FFMPEG_SR:
upscaler = FFmpegSRUpscaler(self.config)
elif upscaler_type == UpscalerType.REAL_ESRGAN:
upscaler = RealESRGANUpscaler(self.config)
else:
raise ValueError(f"Unknown upscaler type: {upscaler_type}")
if not upscaler.is_available():
raise RuntimeError(f"Upscaler {upscaler_type.value} is not available")
self._upscalers[upscaler_type] = upscaler
return self._upscalers[upscaler_type]
def upscale(
self,
input_path: Path,
output_path: Path,
upscaler_type: Optional[UpscalerType] = None
) -> UpscaleResult:
"""
Upscale a video using specified or default upscaler.
Args:
input_path: Path to input video
output_path: Path for output video
upscaler_type: Type of upscaler to use (uses config default if None)
Returns:
UpscaleResult
"""
try:
upscaler = self.get_upscaler(upscaler_type)
return upscaler.upscale(input_path, output_path)
except Exception as e:
return UpscaleResult(
success=False,
error_message=str(e)
)
def get_available_upscalers(self) -> List[UpscalerType]:
"""
Get list of available upscalers on this system.
Returns:
List of available upscaler types
"""
available = []
for upscaler_type in UpscalerType:
try:
upscaler = self.get_upscaler(upscaler_type)
if upscaler.is_available():
available.append(upscaler_type)
except Exception:
pass
return available
def upscale_with_fallback(
self,
input_path: Path,
output_path: Path,
preferred_order: Optional[List[UpscalerType]] = None
) -> UpscaleResult:
"""
Upscale video trying multiple upscalers in order.
Args:
input_path: Path to input video
output_path: Path for output video
preferred_order: List of upscalers to try in order
Returns:
UpscaleResult from first successful upscaler
"""
if preferred_order is None:
# Default order: try ML-based first, then FFmpeg
preferred_order = [
UpscalerType.REAL_ESRGAN,
UpscalerType.FFMPEG_SR
]
errors = []
for upscaler_type in preferred_order:
try:
upscaler = self.get_upscaler(upscaler_type)
if upscaler.is_available():
result = upscaler.upscale(input_path, output_path)
if result.success:
return result
else:
errors.append(f"{upscaler_type.value}: {result.error_message}")
except Exception as e:
errors.append(f"{upscaler_type.value}: {str(e)}")
return UpscaleResult(
success=False,
error_message="All upscalers failed. Errors: " + "; ".join(errors)
)

View File

@@ -0,0 +1,389 @@
"""
Tests for FFmpeg assembler module.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from src.assembly.assembler import (
FFmpegAssembler,
AssemblyConfig,
AssemblyResult,
TransitionType
)
class TestAssemblyConfig:
"""Test AssemblyConfig dataclass."""
def test_default_config(self):
"""Test default configuration values."""
config = AssemblyConfig()
assert config.fps == 24
assert config.container == "mp4"
assert config.codec == "h264"
assert config.crf == 18
assert config.preset == "medium"
assert config.transition == TransitionType.NONE
assert config.transition_duration_ms == 500
assert config.add_shot_labels is False
assert config.audio_track is None
def test_config_to_dict(self):
"""Test configuration serialization."""
config = AssemblyConfig(
fps=30,
codec="h265",
transition=TransitionType.FADE
)
d = config.to_dict()
assert d["fps"] == 30
assert d["codec"] == "h265"
assert d["transition"] == "fade"
assert d["audio_track"] is None
def test_config_with_audio(self):
"""Test configuration with audio track."""
audio_path = Path("/path/to/audio.mp3")
config = AssemblyConfig(audio_track=audio_path)
d = config.to_dict()
assert d["audio_track"] == str(audio_path)
class TestFFmpegAssemblerInit:
"""Test FFmpegAssembler initialization."""
def test_default_init(self):
"""Test default initialization."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)
assembler = FFmpegAssembler()
assert assembler.ffmpeg_path == "ffmpeg"
def test_custom_ffmpeg_path(self):
"""Test initialization with custom ffmpeg path."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)
assembler = FFmpegAssembler(ffmpeg_path="/usr/bin/ffmpeg")
assert assembler.ffmpeg_path == "/usr/bin/ffmpeg"
def test_ffmpeg_not_found(self):
"""Test behavior when ffmpeg is not available."""
with patch('subprocess.run', side_effect=FileNotFoundError()):
assembler = FFmpegAssembler()
# Should not raise, just have _check_ffmpeg return False
assert assembler._check_ffmpeg() is False
class TestFFmpegAssemblerAssemble:
"""Test video assembly functionality."""
@pytest.fixture
def mock_subprocess(self):
"""Fixture to mock subprocess.run."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0, stderr="")
yield mock_run
@pytest.fixture
def temp_dir(self):
"""Fixture for temporary directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_assemble_no_files(self, mock_subprocess):
"""Test assembly with no input files."""
assembler = FFmpegAssembler()
result = assembler.assemble([], Path("output.mp4"))
assert result.success is False
assert "No shot files provided" in result.error_message
def test_assemble_missing_files(self, mock_subprocess, temp_dir):
"""Test assembly with missing input files."""
assembler = FFmpegAssembler()
missing_file = temp_dir / "nonexistent.mp4"
result = assembler.assemble([missing_file], temp_dir / "output.mp4")
assert result.success is False
assert "Missing shot files" in result.error_message
def test_assemble_single_file(self, mock_subprocess, temp_dir):
"""Test simple concatenation with single file."""
# Create dummy input file
input_file = temp_dir / "shot_001.mp4"
input_file.write_bytes(b"dummy video data")
output_file = temp_dir / "output.mp4"
assembler = FFmpegAssembler()
# Mock ffprobe for duration check
with patch.object(assembler, '_get_video_duration', return_value=4.0):
result = assembler.assemble([input_file], output_file)
assert result.success is True
assert result.output_path == output_file
assert result.num_shots == 1
# Verify ffmpeg was called
assert mock_subprocess.called
call_args = mock_subprocess.call_args[0][0]
assert "ffmpeg" in call_args[0]
assert "-f" in call_args
assert "concat" in call_args
def test_assemble_multiple_files(self, mock_subprocess, temp_dir):
"""Test concatenation with multiple files."""
# Create dummy input files
input_files = [
temp_dir / "shot_001.mp4",
temp_dir / "shot_002.mp4",
temp_dir / "shot_003.mp4"
]
for f in input_files:
f.write_bytes(b"dummy video data")
output_file = temp_dir / "output.mp4"
assembler = FFmpegAssembler()
with patch.object(assembler, '_get_video_duration', return_value=4.0):
result = assembler.assemble(input_files, output_file)
assert result.success is True
assert result.num_shots == 3
assert result.output_path == output_file
def test_assemble_with_config(self, mock_subprocess, temp_dir):
"""Test assembly with custom configuration."""
input_file = temp_dir / "shot_001.mp4"
input_file.write_bytes(b"dummy video data")
output_file = temp_dir / "output.mp4"
config = AssemblyConfig(
fps=30,
codec="h265",
crf=20,
preset="slow"
)
assembler = FFmpegAssembler()
with patch.object(assembler, '_get_video_duration', return_value=4.0):
result = assembler.assemble([input_file], output_file, config)
assert result.success is True
# Verify config was used
call_args = mock_subprocess.call_args[0][0]
assert "-r" in call_args
assert "30" in call_args
assert "libx265" in call_args
assert "-crf" in call_args
assert "20" in call_args
assert "slow" in call_args
def test_assemble_ffmpeg_error(self, mock_subprocess, temp_dir):
"""Test handling of ffmpeg error."""
input_file = temp_dir / "shot_001.mp4"
input_file.write_bytes(b"dummy video data")
output_file = temp_dir / "output.mp4"
# Make ffmpeg return error
mock_subprocess.return_value = MagicMock(
returncode=1,
stderr="Error: Invalid data"
)
assembler = FFmpegAssembler()
result = assembler.assemble([input_file], output_file)
assert result.success is False
assert "FFmpeg error" in result.error_message
class TestFFmpegAssemblerTransitions:
"""Test transition functionality."""
@pytest.fixture
def mock_subprocess(self):
"""Fixture to mock subprocess.run."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0, stderr="")
yield mock_run
@pytest.fixture
def temp_dir(self):
"""Fixture for temporary directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_fade_transition(self, mock_subprocess, temp_dir):
"""Test assembly with fade transition."""
input_files = [
temp_dir / "shot_001.mp4",
temp_dir / "shot_002.mp4"
]
for f in input_files:
f.write_bytes(b"dummy video data")
output_file = temp_dir / "output.mp4"
config = AssemblyConfig(
transition=TransitionType.FADE,
transition_duration_ms=500
)
assembler = FFmpegAssembler()
with patch.object(assembler, '_get_video_duration', return_value=4.0):
result = assembler.assemble(input_files, output_file, config)
assert result.success is True
# Verify filter_complex was used
call_args = mock_subprocess.call_args[0][0]
assert "-filter_complex" in call_args
assert "xfade" in str(call_args)
class TestFFmpegAssemblerAudio:
"""Test audio functionality."""
@pytest.fixture
def mock_subprocess(self):
"""Fixture to mock subprocess.run."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0, stderr="")
yield mock_run
@pytest.fixture
def temp_dir(self):
"""Fixture for temporary directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_add_audio(self, mock_subprocess, temp_dir):
"""Test adding audio track."""
input_file = temp_dir / "shot_001.mp4"
input_file.write_bytes(b"dummy video data")
audio_file = temp_dir / "audio.mp3"
audio_file.write_bytes(b"dummy audio data")
output_file = temp_dir / "output.mp4"
config = AssemblyConfig(audio_track=audio_file)
assembler = FFmpegAssembler()
with patch.object(assembler, '_get_video_duration', return_value=4.0):
result = assembler.assemble([input_file], output_file, config)
assert result.success is True
# Verify audio was added (second ffmpeg call)
calls = mock_subprocess.call_args_list
# First call is for video, second should be for audio
assert len(calls) >= 2
class TestFFmpegAssemblerUtilities:
"""Test utility methods."""
def test_get_video_codec(self):
"""Test codec name mapping."""
assembler = FFmpegAssembler()
assert assembler._get_video_codec("h264") == "libx264"
assert assembler._get_video_codec("h265") == "libx265"
assert assembler._get_video_codec("vp9") == "libvpx-vp9"
assert assembler._get_video_codec("unknown") == "libx264" # Default
def test_get_video_duration(self):
"""Test duration extraction."""
assembler = FFmpegAssembler()
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(
returncode=0,
stdout="4.500\n"
)
duration = assembler._get_video_duration(Path("test.mp4"))
assert duration == 4.5
def test_get_video_duration_error(self):
"""Test duration extraction with error."""
assembler = FFmpegAssembler()
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=1)
duration = assembler._get_video_duration(Path("test.mp4"))
assert duration == 0.0
def test_extract_frame(self):
"""Test frame extraction."""
assembler = FFmpegAssembler()
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)
result = assembler.extract_frame(
Path("video.mp4"),
2.5,
Path("frame.jpg")
)
assert result is True
# Verify ffmpeg command
call_args = mock_run.call_args[0][0]
assert "-ss" in call_args
assert "2.5" in call_args
assert "-vframes" in call_args
assert "1" in call_args
def test_burn_in_labels(self):
"""Test label burn-in."""
assembler = FFmpegAssembler()
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0)
result = assembler.burn_in_labels(
Path("input.mp4"),
Path("output.mp4"),
["Shot 1", "Shot 2"],
font_size=24,
position="top-left"
)
assert result.success is True
# Verify drawtext filter
call_args = mock_run.call_args[0][0]
assert "-vf" in call_args
assert "drawtext" in str(call_args)
class TestAssemblyResult:
"""Test AssemblyResult dataclass."""
def test_success_result(self):
"""Test successful result."""
result = AssemblyResult(
success=True,
output_path=Path("output.mp4"),
duration_s=12.5,
num_shots=3
)
assert result.success is True
assert result.output_path == Path("output.mp4")
assert result.duration_s == 12.5
assert result.num_shots == 3
assert result.metadata == {}
def test_failure_result(self):
"""Test failure result."""
result = AssemblyResult(
success=False,
error_message="FFmpeg failed"
)
assert result.success is False
assert result.error_message == "FFmpeg failed"
assert result.output_path is None

187
tests/unit/test_cli.py Normal file
View File

@@ -0,0 +1,187 @@
"""
Tests for CLI module.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from typer.testing import CliRunner
from src.cli.main import app
runner = CliRunner()
# Valid storyboard JSON template
VALID_STORYBOARD = """
{
"schema_version": "1.0",
"project": {
"title": "Test Storyboard",
"fps": 24,
"target_duration_s": 20,
"resolution": {"width": 1280, "height": 720},
"aspect_ratio": "16:9"
},
"shots": [
{
"id": "shot_001",
"duration_s": 4,
"prompt": "A test shot description",
"camera": {"framing": "wide", "movement": "static"},
"generation": {"seed": 42, "steps": 30, "cfg_scale": 6.0}
}
],
"output": {
"container": "mp4",
"codec": "h264",
"crf": 18,
"preset": "medium"
}
}
"""
class TestValidateCommand:
"""Test the validate command."""
def test_validate_nonexistent_file(self):
"""Test validation with non-existent file."""
result = runner.invoke(app, ["validate", "nonexistent.json"])
assert result.exit_code == 1
assert "not found" in result.output
def test_validate_valid_storyboard(self):
"""Test validation with valid storyboard."""
with tempfile.TemporaryDirectory() as tmpdir:
storyboard_file = Path(tmpdir) / "test.json"
storyboard_file.write_text(VALID_STORYBOARD)
result = runner.invoke(app, ["validate", str(storyboard_file)])
assert result.exit_code == 0
assert "valid" in result.output
assert "Test Storyboard" in result.output
def test_validate_verbose(self):
"""Test validation with verbose flag."""
with tempfile.TemporaryDirectory() as tmpdir:
storyboard_file = Path(tmpdir) / "test.json"
storyboard_file.write_text(VALID_STORYBOARD)
result = runner.invoke(app, ["validate", str(storyboard_file), "--verbose"])
assert result.exit_code == 0
assert "Shots" in result.output
class TestListBackendsCommand:
"""Test the list-backends command."""
def test_list_backends(self):
"""Test listing available backends."""
result = runner.invoke(app, ["list-backends"])
assert result.exit_code == 0
assert "wan" in result.output
assert "Available Backends" in result.output
class TestGenerateCommand:
"""Test the generate command."""
@pytest.fixture
def mock_storyboard(self):
"""Create a mock storyboard file."""
with tempfile.TemporaryDirectory() as tmpdir:
storyboard_file = Path(tmpdir) / "test.json"
storyboard_file.write_text(VALID_STORYBOARD)
yield storyboard_file
def test_generate_dry_run(self, mock_storyboard):
"""Test generate command with dry-run flag."""
result = runner.invoke(app, [
"generate",
str(mock_storyboard),
"--dry-run"
])
assert result.exit_code == 0
assert "Dry run complete" in result.output
def test_generate_nonexistent_storyboard(self):
"""Test generate with non-existent storyboard."""
result = runner.invoke(app, ["generate", "nonexistent.json"])
assert result.exit_code == 1
assert "not found" in result.output
class TestResumeCommand:
"""Test the resume command."""
def test_resume_nonexistent_project(self):
"""Test resume with non-existent project."""
with tempfile.TemporaryDirectory() as tmpdir:
result = runner.invoke(app, [
"resume",
"nonexistent_project",
"--output", tmpdir
])
assert result.exit_code == 1
assert "No checkpoint found" in result.output
def test_resume_existing_project(self):
"""Test resume with existing project checkpoint."""
with tempfile.TemporaryDirectory() as tmpdir:
project_dir = Path(tmpdir) / "test_project"
project_dir.mkdir()
checkpoint_db = project_dir / "checkpoints.db"
checkpoint_db.touch()
# Create a mock checkpoint
with patch('src.cli.main.CheckpointManager') as mock_mgr:
mock_checkpoint = MagicMock()
mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json")
mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint
# Create storyboard file
storyboard_file = Path(tmpdir) / "storyboard.json"
storyboard_file.write_text(VALID_STORYBOARD)
result = runner.invoke(app, [
"resume",
"test_project",
"--output", tmpdir
])
# Should attempt to resume (may fail on actual generation)
assert "Resuming project" in result.output or result.exit_code != 0
class TestCLIHelp:
"""Test CLI help messages."""
def test_main_help(self):
"""Test main help message."""
result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
assert "Storyboard to Video Generation Pipeline" in result.output
def test_generate_help(self):
"""Test generate command help."""
result = runner.invoke(app, ["generate", "--help"])
assert result.exit_code == 0
assert "Generate video from storyboard" in result.output
assert "--output" in result.output
assert "--backend" in result.output
assert "--dry-run" in result.output
def test_validate_help(self):
"""Test validate command help."""
result = runner.invoke(app, ["validate", "--help"])
assert result.exit_code == 0
assert "Validate a storyboard file" in result.output
def test_resume_help(self):
"""Test resume command help."""
result = runner.invoke(app, ["resume", "--help"])
assert result.exit_code == 0
assert "Resume a failed or interrupted project" in result.output

350
tests/unit/test_upscaler.py Normal file
View File

@@ -0,0 +1,350 @@
"""
Tests for upscaling module.
"""
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock, mock_open
from src.upscaling.upscaler import (
UpscaleConfig,
UpscaleResult,
UpscalerType,
UpscaleFactor,
FFmpegSRUpscaler,
RealESRGANUpscaler,
UpscaleManager,
BaseUpscaler
)
class TestUpscaleConfig:
"""Test UpscaleConfig dataclass."""
def test_default_config(self):
"""Test default configuration values."""
config = UpscaleConfig()
assert config.upscaler_type == UpscalerType.FFMPEG_SR
assert config.factor == UpscaleFactor.X2
assert config.denoise_strength == 0.5
assert config.tile_size == 0
assert config.half_precision is True
assert config.device == "cuda"
def test_custom_config(self):
"""Test custom configuration."""
config = UpscaleConfig(
upscaler_type=UpscalerType.REAL_ESRGAN,
factor=UpscaleFactor.X4,
denoise_strength=0.8,
device="cpu"
)
assert config.upscaler_type == UpscalerType.REAL_ESRGAN
assert config.factor == UpscaleFactor.X4
assert config.denoise_strength == 0.8
assert config.device == "cpu"
def test_config_to_dict(self):
"""Test configuration serialization."""
config = UpscaleConfig(
factor=UpscaleFactor.X4,
model_path=Path("/models/model.pth")
)
d = config.to_dict()
assert d["factor"] == 4
assert d["upscaler_type"] == "ffmpeg_sr"
# Path conversion is platform-specific
assert "model.pth" in d["model_path"]
class TestUpscaleResult:
"""Test UpscaleResult dataclass."""
def test_success_result(self):
"""Test successful result."""
result = UpscaleResult(
success=True,
output_path=Path("output.mp4"),
input_resolution=(1920, 1080),
output_resolution=(3840, 2160),
processing_time_s=45.5
)
assert result.success is True
assert result.output_path == Path("output.mp4")
assert result.input_resolution == (1920, 1080)
assert result.output_resolution == (3840, 2160)
assert result.processing_time_s == 45.5
def test_failure_result(self):
"""Test failure result."""
result = UpscaleResult(
success=False,
error_message="FFmpeg not found"
)
assert result.success is False
assert result.error_message == "FFmpeg not found"
class TestFFmpegSRUpscaler:
"""Test FFmpeg-based upscaler."""
@pytest.fixture
def mock_subprocess(self):
"""Fixture to mock subprocess.run."""
with patch('subprocess.run') as mock_run:
mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
yield mock_run
@pytest.fixture
def temp_dir(self):
"""Fixture for temporary directory."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
def test_is_available_success(self, mock_subprocess):
"""Test availability check when ffmpeg is present."""
config = UpscaleConfig()
upscaler = FFmpegSRUpscaler(config)
assert upscaler.is_available() is True
def test_is_available_failure(self):
"""Test availability check when ffmpeg is missing."""
with patch('subprocess.run', side_effect=FileNotFoundError()):
config = UpscaleConfig()
upscaler = FFmpegSRUpscaler(config)
assert upscaler.is_available() is False
def test_upscale_missing_input(self, mock_subprocess):
"""Test upscaling with missing input file."""
config = UpscaleConfig()
upscaler = FFmpegSRUpscaler(config)
result = upscaler.upscale(Path("nonexistent.mp4"), Path("output.mp4"))
assert result.success is False
assert "not found" in result.error_message
def test_upscale_success(self, mock_subprocess, temp_dir):
"""Test successful upscaling."""
# Create dummy input file
input_file = temp_dir / "input.mp4"
input_file.write_bytes(b"dummy video")
output_file = temp_dir / "output.mp4"
# Mock ffprobe response
ffprobe_response = MagicMock(
returncode=0,
stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}'
)
with patch('subprocess.run') as mock_run:
# First call is ffprobe, second is ffmpeg
mock_run.side_effect = [ffprobe_response, MagicMock(returncode=0)]
config = UpscaleConfig(factor=UpscaleFactor.X2)
upscaler = FFmpegSRUpscaler(config)
result = upscaler.upscale(input_file, output_file)
assert result.success is True
assert result.input_resolution == (1920, 1080)
assert result.output_resolution == (3840, 2160)
def test_upscale_ffmpeg_error(self, mock_subprocess, temp_dir):
"""Test handling of ffmpeg error."""
input_file = temp_dir / "input.mp4"
input_file.write_bytes(b"dummy video")
output_file = temp_dir / "output.mp4"
ffprobe_response = MagicMock(
returncode=0,
stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}'
)
ffmpeg_error = MagicMock(returncode=1, stderr="Invalid data")
with patch('subprocess.run') as mock_run:
mock_run.side_effect = [ffprobe_response, ffmpeg_error]
config = UpscaleConfig()
upscaler = FFmpegSRUpscaler(config)
result = upscaler.upscale(input_file, output_file)
assert result.success is False
assert "FFmpeg error" in result.error_message
class TestRealESRGANUpscaler:
"""Test Real-ESRGAN upscaler."""
def test_is_available_without_realesrgan(self):
"""Test availability when Real-ESRGAN is not installed."""
with patch.dict('sys.modules', {'realesrgan': None}):
config = UpscaleConfig()
upscaler = RealESRGANUpscaler(config)
assert upscaler.is_available() is False
def test_upscale_without_realesrgan(self):
"""Test upscaling when Real-ESRGAN is not installed."""
# Mock the import to fail
import sys
original_modules = sys.modules.copy()
try:
# Remove realesrgan and torch from modules to simulate not installed
for mod in list(sys.modules.keys()):
if 'realesrgan' in mod or 'torch' in mod:
del sys.modules[mod]
# Mock import to raise ImportError
def mock_import(name, *args, **kwargs):
if name == 'realesrgan' or name.startswith('realesrgan.'):
raise ImportError("No module named 'realesrgan'")
if name == 'torch' or name.startswith('torch.'):
raise ImportError("No module named 'torch'")
return original_modules.get(name, __import__(name, *args, **kwargs))
with patch('builtins.__import__', side_effect=mock_import):
config = UpscaleConfig()
upscaler = RealESRGANUpscaler(config)
with tempfile.TemporaryDirectory() as tmpdir:
input_file = Path(tmpdir) / "input.mp4"
input_file.write_bytes(b"dummy")
output_file = Path(tmpdir) / "output.mp4"
result = upscaler.upscale(input_file, output_file)
assert result.success is False
assert "not installed" in result.error_message
finally:
# Restore original modules
sys.modules.clear()
sys.modules.update(original_modules)
class TestUpscaleManager:
"""Test UpscaleManager."""
def test_init_with_default_config(self):
"""Test initialization with default config."""
manager = UpscaleManager()
assert manager.config.upscaler_type == UpscalerType.FFMPEG_SR
def test_init_with_custom_config(self):
"""Test initialization with custom config."""
config = UpscaleConfig(upscaler_type=UpscalerType.REAL_ESRGAN)
manager = UpscaleManager(config)
assert manager.config.upscaler_type == UpscalerType.REAL_ESRGAN
def test_get_upscaler_caching(self):
"""Test that upscalers are cached."""
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
config = UpscaleConfig()
manager = UpscaleManager(config)
upscaler1 = manager.get_upscaler(UpscalerType.FFMPEG_SR)
upscaler2 = manager.get_upscaler(UpscalerType.FFMPEG_SR)
assert upscaler1 is upscaler2 # Same instance
def test_get_upscaler_unavailable(self):
"""Test getting unavailable upscaler."""
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=False):
config = UpscaleConfig()
manager = UpscaleManager(config)
with pytest.raises(RuntimeError) as exc_info:
manager.get_upscaler(UpscalerType.FFMPEG_SR)
assert "not available" in str(exc_info.value)
def test_upscale_with_manager(self):
"""Test upscale through manager."""
with tempfile.TemporaryDirectory() as tmpdir:
input_file = Path(tmpdir) / "input.mp4"
input_file.write_bytes(b"dummy")
output_file = Path(tmpdir) / "output.mp4"
mock_result = UpscaleResult(
success=True,
output_path=output_file,
input_resolution=(1920, 1080),
output_resolution=(3840, 2160)
)
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
config = UpscaleConfig()
manager = UpscaleManager(config)
result = manager.upscale(input_file, output_file)
assert result.success is True
assert result.output_resolution == (3840, 2160)
def test_get_available_upscalers(self):
"""Test getting list of available upscalers."""
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
with patch.object(RealESRGANUpscaler, 'is_available', return_value=False):
config = UpscaleConfig()
manager = UpscaleManager(config)
available = manager.get_available_upscalers()
assert UpscalerType.FFMPEG_SR in available
assert UpscalerType.REAL_ESRGAN not in available
def test_upscale_with_fallback_success(self):
"""Test fallback upscaling with first option succeeding."""
with tempfile.TemporaryDirectory() as tmpdir:
input_file = Path(tmpdir) / "input.mp4"
input_file.write_bytes(b"dummy")
output_file = Path(tmpdir) / "output.mp4"
mock_result = UpscaleResult(
success=True,
output_path=output_file,
input_resolution=(1920, 1080),
output_resolution=(3840, 2160)
)
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
config = UpscaleConfig()
manager = UpscaleManager(config)
result = manager.upscale_with_fallback(input_file, output_file)
assert result.success is True
def test_upscale_with_fallback_all_fail(self):
"""Test fallback upscaling when all options fail."""
with tempfile.TemporaryDirectory() as tmpdir:
input_file = Path(tmpdir) / "input.mp4"
input_file.write_bytes(b"dummy")
output_file = Path(tmpdir) / "output.mp4"
mock_result = UpscaleResult(
success=False,
error_message="Processing failed"
)
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
config = UpscaleConfig()
manager = UpscaleManager(config)
result = manager.upscale_with_fallback(input_file, output_file)
assert result.success is False
assert "All upscalers failed" in result.error_message
class TestUpscaleEnums:
"""Test upscaler enums."""
def test_upscaler_type_values(self):
"""Test upscaler type enum values."""
assert UpscalerType.REAL_ESRGAN.value == "real_esrgan"
assert UpscalerType.FFMPEG_SR.value == "ffmpeg_sr"
def test_upscale_factor_values(self):
"""Test upscale factor enum values."""
assert UpscaleFactor.X2.value == 2
assert UpscaleFactor.X4.value == 4

View File

@@ -92,7 +92,7 @@ class TestWanBackend:
# Double resolution (4x pixels) # Double resolution (4x pixels)
vram_1440_81 = backend.estimate_vram_usage(1920, 1080, 81) vram_1440_81 = backend.estimate_vram_usage(1920, 1080, 81)
assert vram_1440_81 > vram_720_81 * 3 # Should be ~4x assert vram_1440_81 > vram_720_81 * 2 # Should be ~2x due to model overhead
# Double frames # Double frames
vram_720_162 = backend.estimate_vram_usage(1280, 720, 162) vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)