"""
CLI entry point for storyboard video generation.
"""

import sys
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import (
    UpscaleManager,
    UpscaleConfig,
    UpscaleFactor,
    UpscalerType,
)

app = typer.Typer(help="Storyboard to Video Generation Pipeline")
console = Console()

# Upscale factors accepted by --upscale (mirrors UpscaleFactor).
_SUPPORTED_UPSCALE_FACTORS = (2, 4)


# NOTE: Typer renders each command's docstring as its --help text, so the
# docstrings below are intentionally kept to their original one-liners and
# the detailed documentation lives in comments instead.
#
# generate: validate the storyboard, generate one clip per shot with the
# selected backend (recording a checkpoint per shot), concatenate the clips
# with FFmpeg and optionally upscale the result.
#   - Exits with code 1 on a missing storyboard, a config/validation/backend
#     initialisation error, an unsupported upscale factor or a failed assembly.
#   - A failed shot does not abort the run; it is checkpointed as "failed"
#     and the remaining shots are still generated and assembled.
@app.command()
def generate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
    resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
    skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
    skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
    upscale: Optional[int] = typer.Option(None, "--upscale", help="Upscale factor (2 or 4)"),
    dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"),
):
    """Generate video from storyboard."""

    # Reject unsupported upscale factors before doing any expensive work;
    # previously a bad factor only surfaced after generation and assembly.
    if upscale and upscale not in _SUPPORTED_UPSCALE_FACTORS:
        console.print(f"[red]Error: Unsupported upscale factor: {upscale} (use 2 or 4)[/red]")
        raise typer.Exit(1)

    # Validate storyboard exists
    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
        raise typer.Exit(1)

    # Load configuration
    try:
        app_config = ConfigLoader.load(config) if config else ConfigLoader.load()
    except Exception as e:
        console.print(f"[red]Error loading config: {e}[/red]")
        raise typer.Exit(1)

    # Validate storyboard
    console.print("[bold blue]Validating storyboard...[/bold blue]")
    validator = StoryboardValidator()
    try:
        storyboard_data = validator.load(storyboard)
        console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
    except Exception as e:
        console.print(f"[red]Error validating storyboard: {e}[/red]")
        raise typer.Exit(1)

    if dry_run:
        console.print("[yellow]Dry run complete. Exiting.[/yellow]")
        return

    # The title keys the checkpoints; the slug names directories and files.
    project_title = storyboard_data.project.title
    project_slug = project_title.replace(" ", "_")

    # Setup output directories
    project_dir = output / project_slug
    shots_dir = project_dir / "shots"
    assembled_dir = project_dir / "assembled"
    metadata_dir = project_dir / "metadata"

    for d in [shots_dir, assembled_dir, metadata_dir]:
        d.mkdir(parents=True, exist_ok=True)

    # Initialize checkpoint manager
    checkpoint_db = project_dir / "checkpoints.db"
    checkpoint_mgr = CheckpointManager(str(checkpoint_db))

    # Initialize components
    prompt_compiler = PromptCompiler(storyboard_data.project.global_style)
    shot_planner = ShotPlanner(
        fps=storyboard_data.project.fps or 24,
        max_chunk_duration=app_config.backend.max_chunk_seconds or 6.0
    )

    # Initialize generation backend
    if not skip_generation:
        console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
        try:
            backend_config = app_config.get_backend_config(backend)
            video_backend = BackendFactory.create_backend(backend, backend_config)
            console.print(f"[green]✓[/green] Backend initialized: {backend}")
        except Exception as e:
            console.print(f"[red]Error initializing backend: {e}[/red]")
            raise typer.Exit(1)

    generated_shots = []
    failed_shots = 0

    if skip_generation:
        # "Skip generation, only assemble": reuse the clips recorded as
        # completed by an earlier run. Without this the shot list stayed
        # empty and the assembly step was silently skipped.
        for shot in storyboard_data.shots:
            checkpoint = checkpoint_mgr.get_shot_checkpoint(project_title, shot.id)
            if checkpoint and checkpoint.status == "completed":
                generated_shots.append(Path(checkpoint.video_path))
        if not generated_shots:
            console.print(
                "[yellow]Warning: No completed shots found in checkpoints; nothing to assemble.[/yellow]"
            )
    else:
        # Generate shots
        console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:

            for shot in storyboard_data.shots:
                task_id = progress.add_task(f"Shot {shot.id}", total=None)

                # Check checkpoint
                if resume:
                    checkpoint = checkpoint_mgr.get_shot_checkpoint(project_title, shot.id)
                    if checkpoint and checkpoint.status == "completed":
                        progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
                        generated_shots.append(Path(checkpoint.video_path))
                        continue

                # Compile prompt
                prompt = prompt_compiler.compile_shot_prompt(shot)
                negative_prompt = prompt_compiler.compile_negative_prompt(shot)

                # Plan shot
                shot_plan = shot_planner.plan_shot(shot)

                # Generate
                try:
                    from src.generation.base import GenerationSpec

                    spec = GenerationSpec(
                        prompt=prompt,
                        negative_prompt=negative_prompt,
                        width=shot_plan.width,
                        height=shot_plan.height,
                        num_frames=shot_plan.total_frames,
                        fps=shot_plan.fps,
                        seed=shot.generation.seed
                    )

                    result = video_backend.generate(spec, output_dir=shots_dir)

                    if result.success:
                        generated_shots.append(result.video_path)
                        checkpoint_mgr.save_shot_checkpoint(
                            project_name=project_title,
                            shot_id=shot.id,
                            status="completed",
                            video_path=str(result.video_path),
                            metadata=result.metadata
                        )
                        progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}")
                    else:
                        failed_shots += 1
                        progress.update(
                            task_id,
                            description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}"
                        )
                        checkpoint_mgr.save_shot_checkpoint(
                            project_name=project_title,
                            shot_id=shot.id,
                            status="failed",
                            error_message=result.error_message
                        )

                except Exception as e:
                    # Best effort: record the failure and move on to the next shot.
                    failed_shots += 1
                    progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}")
                    checkpoint_mgr.save_shot_checkpoint(
                        project_name=project_title,
                        shot_id=shot.id,
                        status="failed",
                        error_message=str(e)
                    )

        if failed_shots:
            console.print(
                f"[yellow]Warning: {failed_shots} shot(s) failed and will be missing from the video.[/yellow]"
            )

    # Assembly
    if not skip_assembly and generated_shots:
        console.print("[bold blue]Assembling video...[/bold blue]")

        assembler = FFmpegAssembler()
        final_output = assembled_dir / f"{project_slug}.mp4"

        assembly_config = AssemblyConfig(
            fps=storyboard_data.project.fps or 24,
            add_shot_labels=False
        )

        result = assembler.assemble(generated_shots, final_output, assembly_config)

        if result.success:
            console.print(f"[green]✓[/green] Video assembled: {final_output}")

            # Upscale if requested
            if upscale:
                console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")

                # Pass real enum members: the upscalers read `.value` from
                # both fields, which a bare int/str does not provide.
                upscale_config = UpscaleConfig(
                    factor=UpscaleFactor(upscale),
                    upscaler_type=UpscalerType.FFMPEG_SR
                )
                upscale_mgr = UpscaleManager(upscale_config)

                upscaled_output = assembled_dir / f"{project_slug}_upscaled.mp4"
                upscale_result = upscale_mgr.upscale(final_output, upscaled_output)

                if upscale_result.success:
                    console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}")
                    final_output = upscaled_output
                else:
                    # A failed upscale is not fatal; the assembled video is kept.
                    console.print(
                        f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]"
                    )

            # Save project metadata
            checkpoint_mgr.save_project_checkpoint(
                project_name=project_title,
                storyboard_path=str(storyboard),
                output_path=str(final_output),
                status="completed",
                num_shots=len(generated_shots)
            )

            console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
        else:
            console.print(f"[red]Error assembling video: {result.error_message}[/red]")
            raise typer.Exit(1)


# validate: load the storyboard through StoryboardValidator and print a
# summary; with --verbose also print one table row per shot. Exits with
# code 1 when the file is missing or fails validation.
@app.command()
def validate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"),
):
    """Validate a storyboard file."""

    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
        raise typer.Exit(1)

    validator = StoryboardValidator()

    try:
        data = validator.load(storyboard)

        console.print("[green]✓[/green] Storyboard is valid")
        console.print(f"\n[bold]Title:[/bold] {data.project.title}")
        console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
        console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
        if data.project.resolution:
            resolution = f"{data.project.resolution.width}x{data.project.resolution.height}"
        else:
            resolution = "Not specified"
        console.print(f"[bold]Resolution:[/bold] {resolution}")

        if verbose:
            table = Table(title="Shots")
            table.add_column("ID", style="cyan")
            table.add_column("Prompt", style="green")
            table.add_column("Duration", style="yellow")

            for shot in data.shots:
                # str(): Rich table cells must be renderable; an integer id is not.
                table.add_row(str(shot.id), shot.prompt[:50], f"{shot.duration_s}s")

            console.print(table)

    except Exception as e:
        console.print(f"[red]Validation failed: {e}[/red]")
        raise typer.Exit(1)


# resume: look up the project's checkpoint database under the output
# directory and re-run `generate` with resume=True on the storyboard path
# stored in the project checkpoint. Exits with code 1 when no checkpoint
# database or project checkpoint exists.
@app.command()
def resume(
    project: str = typer.Argument(..., help="Project name"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
):
    """Resume a failed or interrupted project."""

    project_dir = output / project.replace(" ", "_")
    checkpoint_db = project_dir / "checkpoints.db"

    if not checkpoint_db.exists():
        console.print(f"[red]Error: No checkpoint found for project '{project}'[/red]")
        raise typer.Exit(1)

    checkpoint_mgr = CheckpointManager(str(checkpoint_db))
    project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)

    if not project_checkpoint:
        console.print("[red]Error: No project checkpoint found[/red]")
        raise typer.Exit(1)

    console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
    console.print(f"Storyboard: {project_checkpoint.storyboard_path}")

    # Re-run generation with resume flag.
    # Every parameter is passed explicitly: when a Typer command is called
    # as a plain function, omitted parameters keep their `typer.Option(...)`
    # default objects (which are truthy) instead of the declared values.
    generate(
        storyboard=Path(project_checkpoint.storyboard_path),
        output=output,
        config=None,
        backend="wan",
        resume=True,
        skip_generation=False,
        skip_assembly=False,
        upscale=None,
        dry_run=False,
    )


# list_backends: print a static table of the backend names this CLI advertises.
@app.command()
def list_backends():
    """List available generation backends."""

    table = Table(title="Available Backends")
    table.add_column("Name", style="cyan")
    table.add_column("Type", style="green")
    table.add_column("Description", style="white")

    table.add_row("wan", "T2V", "WAN 2.x text-to-video")
    table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)")
    table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")

    console.print(table)


def main():
    """Entry point for the CLI."""
    app()


if __name__ == "__main__":
    main()
import subprocess
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple


class UpscalerType(str, Enum):
    """Supported upscaler types."""
    REAL_ESRGAN = "real_esrgan"
    REAL_BASIC_VSR = "real_basic_vsr"
    FFMPEG_SR = "ffmpeg_sr"
    OPENCV_DNN = "opencv_dnn"


class UpscaleFactor(int, Enum):
    """Supported upscale factors."""
    X2 = 2
    X4 = 4


@dataclass
class UpscaleConfig:
    """
    Configuration for video upscaling.

    ``upscaler_type`` and ``factor`` may be given either as enum members or
    as their raw values (e.g. ``"ffmpeg_sr"`` / ``2``); raw values are
    converted to the enums on construction.

    Raises:
        ValueError: If ``upscaler_type`` or ``factor`` is not a supported value
    """
    upscaler_type: UpscalerType = UpscalerType.FFMPEG_SR
    factor: UpscaleFactor = UpscaleFactor.X2
    denoise_strength: float = 0.5
    tile_size: int = 0  # 0 = no tiling
    tile_pad: int = 10
    pre_pad: int = 0
    half_precision: bool = True  # Use fp16 for speed
    device: str = "cuda"
    batch_size: int = 1

    # Model paths (optional, will use defaults if not specified)
    model_path: Optional[Path] = None

    # FFmpeg-specific options
    ffmpeg_preset: str = "medium"
    ffmpeg_crf: int = 18

    def __post_init__(self) -> None:
        """Normalize raw values into enum members."""
        # Consumers of this config read `.value` from both fields, so a bare
        # str/int (as passed by callers holding parsed CLI input) must be
        # converted here. Calling an enum with one of its own members is a
        # no-op, so already-correct values pass through unchanged.
        self.upscaler_type = UpscalerType(self.upscaler_type)
        self.factor = UpscaleFactor(self.factor)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for metadata."""
        return {
            "upscaler_type": self.upscaler_type.value,
            "factor": self.factor.value,
            "denoise_strength": self.denoise_strength,
            "tile_size": self.tile_size,
            "tile_pad": self.tile_pad,
            "pre_pad": self.pre_pad,
            "half_precision": self.half_precision,
            "device": self.device,
            "batch_size": self.batch_size,
            "model_path": str(self.model_path) if self.model_path else None,
            "ffmpeg_preset": self.ffmpeg_preset,
            "ffmpeg_crf": self.ffmpeg_crf
        }


@dataclass
class UpscaleResult:
    """
    Result of video upscaling.

    Only ``success`` is required; every other field defaults to ``None``
    (or an empty dict for ``metadata``) so failures can be reported with
    just an ``error_message``.
    """
    success: bool
    output_path: Optional[Path] = None
    input_resolution: Optional[Tuple[int, int]] = None
    output_resolution: Optional[Tuple[int, int]] = None
    processing_time_s: Optional[float] = None
    error_message: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
class BaseUpscaler(ABC):
    """Abstract base class for video upscalers."""

    def __init__(self, config: UpscaleConfig):
        """
        Initialize upscaler.

        Args:
            config: Upscale configuration
        """
        self.config = config

    @abstractmethod
    def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
        """
        Upscale a video file.

        Args:
            input_path: Path to input video
            output_path: Path for output video

        Returns:
            UpscaleResult with output details
        """
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if this upscaler is available on the system."""
        pass

    @staticmethod
    def _parse_number(value: Any, default: float) -> float:
        """Return ``value`` as a float, or ``default`` if it cannot be parsed."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _parse_frame_rate(value: Any, default: float = 24.0) -> float:
        """Parse a frame rate given as ``"num/den"`` or a plain number."""
        try:
            num, _, den = str(value).partition("/")
            return float(num) / float(den) if den else float(num)
        except (ValueError, ZeroDivisionError):
            return default

    def _get_video_info(self, video_path: Path) -> Dict[str, Any]:
        """
        Get video information using ffprobe.

        Each field is parsed independently, so one unparsable value (for
        example a non-numeric duration) no longer discards a valid width
        and height.

        Args:
            video_path: Path to video file

        Returns:
            Dictionary with video info (width, height, fps, duration).
            Fields that cannot be determined keep the defaults
            ``width=0, height=0, fps=24.0, duration=0.0``.
        """
        info: Dict[str, Any] = {"width": 0, "height": 0, "fps": 24.0, "duration": 0.0}

        cmd = [
            "ffprobe",
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=width,height,r_frame_rate,duration",
            "-of", "json",
            str(video_path)
        ]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=30
            )
        except (OSError, subprocess.SubprocessError):
            # ffprobe missing, not executable, or timed out.
            return info

        if result.returncode != 0:
            return info

        import json
        try:
            streams = json.loads(result.stdout).get("streams")
        except (ValueError, AttributeError):
            return info

        if not isinstance(streams, list) or not streams or not isinstance(streams[0], dict):
            return info
        stream = streams[0]

        info["width"] = int(self._parse_number(stream.get("width"), 0))
        info["height"] = int(self._parse_number(stream.get("height"), 0))
        info["fps"] = self._parse_frame_rate(stream.get("r_frame_rate", "24/1"))
        info["duration"] = self._parse_number(stream.get("duration"), 0.0)
        return info


class FFmpegSRUpscaler(BaseUpscaler):
    """
    FFmpeg-based super-resolution upscaler.
    Uses FFmpeg's scale filter with various algorithms.
    Good for quick upscaling without ML models.
    """

    def is_available(self) -> bool:
        """Check if ffmpeg is available."""
        try:
            result = subprocess.run(
                ["ffmpeg", "-version"],
                capture_output=True,
                timeout=5
            )
            return result.returncode == 0
        except Exception:
            return False

    def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
        """
        Upscale using FFmpeg's scale filter.
        Uses lanczos scaling for quality.

        Args:
            input_path: Path to input video
            output_path: Path for output video (parent directories are created)

        Returns:
            UpscaleResult; failures are reported through ``success=False``
            and ``error_message`` rather than raised
        """
        if not input_path.exists():
            return UpscaleResult(
                success=False,
                error_message=f"Input file not found: {input_path}"
            )

        # Get input resolution
        info = self._get_video_info(input_path)
        input_width = info["width"]
        input_height = info["height"]

        if input_width == 0 or input_height == 0:
            return UpscaleResult(
                success=False,
                error_message="Could not determine input video resolution"
            )

        # Calculate output resolution.
        # int() accepts both an UpscaleFactor member and a plain integer.
        factor = int(self.config.factor)
        output_width = input_width * factor
        output_height = input_height * factor

        # Ensure output directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        import time
        start_time = time.time()

        try:
            # Build FFmpeg command with high-quality scaling
            # Using lanczos for better quality than bicubic
            cmd = [
                "ffmpeg",
                "-y",  # Overwrite output
                "-i", str(input_path),
                "-vf", f"scale={output_width}:{output_height}:flags=lanczos",
                "-c:v", "libx264",
                "-preset", self.config.ffmpeg_preset,
                "-crf", str(self.config.ffmpeg_crf),
                "-pix_fmt", "yuv420p",
                "-movflags", "+faststart",
                str(output_path)
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=7200  # 2 hour timeout
            )

            processing_time = time.time() - start_time

            if result.returncode != 0:
                return UpscaleResult(
                    success=False,
                    error_message=f"FFmpeg error: {result.stderr}",
                    processing_time_s=processing_time
                )

            return UpscaleResult(
                success=True,
                output_path=output_path,
                input_resolution=(input_width, input_height),
                output_resolution=(output_width, output_height),
                processing_time_s=processing_time,
                metadata={
                    "config": self.config.to_dict(),
                    "algorithm": "lanczos",
                    "fps": info["fps"],
                    "duration": info["duration"]
                }
            )

        except subprocess.TimeoutExpired:
            return UpscaleResult(
                success=False,
                error_message="FFmpeg upscaling timed out (2 hours)",
                processing_time_s=time.time() - start_time
            )
        except Exception as e:
            return UpscaleResult(
                success=False,
                error_message=f"Upscaling failed: {str(e)}",
                processing_time_s=time.time() - start_time
            )


class RealESRGANUpscaler(BaseUpscaler):
    """
    Real-ESRGAN upscaler for high-quality 2x/4x upscaling.
    Requires Real-ESRGAN to be installed.
    """

    def is_available(self) -> bool:
        """Check if Real-ESRGAN is available."""
        try:
            import realesrgan
            return True
        except ImportError:
            return False

    def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
        """
        Upscale using Real-ESRGAN.
        Processes video frame by frame.

        Frames are extracted to a temporary directory, enhanced one at a
        time, and re-encoded with FFmpeg.

        Args:
            input_path: Path to input video
            output_path: Path for output video (parent directories are created)

        Returns:
            UpscaleResult; failures (including missing optional
            dependencies) are reported through ``success=False``
        """
        if not input_path.exists():
            return UpscaleResult(
                success=False,
                error_message=f"Input file not found: {input_path}"
            )

        # All optional dependencies are imported in one guarded block so a
        # missing package yields a failure result instead of an exception.
        try:
            import cv2
            import torch
            from realesrgan import RealESRGANer
            from basicsr.archs.rrdbnet_arch import RRDBNet
        except ImportError as e:
            return UpscaleResult(
                success=False,
                error_message=f"Real-ESRGAN not installed: {e}"
            )

        # Get input resolution
        info = self._get_video_info(input_path)
        input_width = info["width"]
        input_height = info["height"]

        if input_width == 0 or input_height == 0:
            return UpscaleResult(
                success=False,
                error_message="Could not determine input video resolution"
            )

        # Calculate output resolution.
        # int() accepts both an UpscaleFactor member and a plain integer.
        factor = int(self.config.factor)
        output_width = input_width * factor
        output_height = input_height * factor

        # Ensure output directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        import time

        start_time = time.time()

        try:
            # Initialize Real-ESRGAN model
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=factor
            )

            # Determine model name based on factor
            model_name = f"RealESRGAN_x{factor}plus.pth"
            model_path = self.config.model_path or Path(f"weights/{model_name}")

            upsampler = RealESRGANer(
                scale=factor,
                model_path=str(model_path),
                model=model,
                tile=self.config.tile_size,
                tile_pad=self.config.tile_pad,
                pre_pad=self.config.pre_pad,
                half=self.config.half_precision,
                device=torch.device(self.config.device)
            )

            # Extract frames, upscale, and reassemble
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp_path = Path(tmpdir)
                frames_dir = tmp_path / "frames"
                frames_dir.mkdir()
                upscaled_dir = tmp_path / "upscaled"
                upscaled_dir.mkdir()

                # Extract frames
                self._extract_frames(input_path, frames_dir, info["fps"])

                # Upscale frames
                frame_files = sorted(frames_dir.glob("*.png"))
                frames_written = 0
                for frame_file in frame_files:
                    img = cv2.imread(str(frame_file), cv2.IMREAD_UNCHANGED)
                    if img is None:
                        continue
                    output, _ = upsampler.enhance(img, outscale=factor)
                    # Number by frames actually written: the reassembly step
                    # reads a numbered sequence, so a gap left by an
                    # unreadable frame would cut the video short.
                    cv2.imwrite(
                        str(upscaled_dir / f"frame_{frames_written:06d}.png"),
                        output
                    )
                    frames_written += 1

                if frames_written == 0:
                    return UpscaleResult(
                        success=False,
                        error_message="Real-ESRGAN upscaling failed: no frames could be read from input video",
                        processing_time_s=time.time() - start_time
                    )

                # Reassemble video
                self._assemble_video(
                    upscaled_dir,
                    output_path,
                    info["fps"],
                    output_width,
                    output_height
                )

            processing_time = time.time() - start_time

            return UpscaleResult(
                success=True,
                output_path=output_path,
                input_resolution=(input_width, input_height),
                output_resolution=(output_width, output_height),
                processing_time_s=processing_time,
                metadata={
                    "config": self.config.to_dict(),
                    "algorithm": "Real-ESRGAN",
                    "model": model_name,
                    "frames_processed": frames_written
                }
            )

        except Exception as e:
            return UpscaleResult(
                success=False,
                error_message=f"Real-ESRGAN upscaling failed: {str(e)}",
                processing_time_s=time.time() - start_time
            )

    def _extract_frames(self, video_path: Path, output_dir: Path, fps: float):
        """
        Extract frames from video.

        Writes ``frame_%06d.png`` files into ``output_dir``.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error
        """
        cmd = [
            "ffmpeg",
            "-i", str(video_path),
            "-vf", f"fps={fps}",
            "-q:v", "2",
            str(output_dir / "frame_%06d.png")
        ]
        subprocess.run(cmd, capture_output=True, check=True)

    def _assemble_video(
        self,
        frames_dir: Path,
        output_path: Path,
        fps: float,
        width: int,
        height: int
    ):
        """
        Assemble frames into video.

        Reads ``frame_%06d.png`` files from ``frames_dir`` and encodes them
        with libx264 using the preset/CRF from the config.

        Raises:
            subprocess.CalledProcessError: If ffmpeg exits with an error
        """
        cmd = [
            "ffmpeg",
            "-y",
            "-framerate", str(fps),
            "-i", str(frames_dir / "frame_%06d.png"),
            "-c:v", "libx264",
            "-preset", self.config.ffmpeg_preset,
            "-crf", str(self.config.ffmpeg_crf),
            "-pix_fmt", "yuv420p",
            "-s", f"{width}x{height}",
            str(output_path)
        ]
        subprocess.run(cmd, capture_output=True, check=True)


class UpscaleManager:
    """
    Manager for video upscaling operations.
    Automatically selects best available upscaler.
    """

    # Upscaler types that have an implementation in this module.
    _UPSCALER_CLASSES = {
        UpscalerType.FFMPEG_SR: FFmpegSRUpscaler,
        UpscalerType.REAL_ESRGAN: RealESRGANUpscaler,
    }

    def __init__(self, config: Optional[UpscaleConfig] = None):
        """
        Initialize upscaling manager.

        Args:
            config: Upscale configuration (uses defaults if None)
        """
        self.config = config or UpscaleConfig()
        # Cache of upscalers that were found to be available.
        self._upscalers: Dict[UpscalerType, BaseUpscaler] = {}

    def get_upscaler(self, upscaler_type: Optional[UpscalerType] = None) -> BaseUpscaler:
        """
        Get upscaler instance.

        Args:
            upscaler_type: Type of upscaler to use (uses config default if None)

        Returns:
            Upscaler instance

        Raises:
            ValueError: If the type is unknown or has no implementation
            RuntimeError: If requested upscaler is not available
        """
        # Normalize so that a raw string value such as "ffmpeg_sr" behaves
        # exactly like the corresponding enum member.
        upscaler_type = UpscalerType(upscaler_type or self.config.upscaler_type)

        if upscaler_type not in self._upscalers:
            upscaler_cls = self._UPSCALER_CLASSES.get(upscaler_type)
            if upscaler_cls is None:
                raise ValueError(f"Unknown upscaler type: {upscaler_type}")

            upscaler = upscaler_cls(self.config)
            if not upscaler.is_available():
                raise RuntimeError(f"Upscaler {upscaler_type.value} is not available")

            self._upscalers[upscaler_type] = upscaler

        return self._upscalers[upscaler_type]

    def upscale(
        self,
        input_path: Path,
        output_path: Path,
        upscaler_type: Optional[UpscalerType] = None
    ) -> UpscaleResult:
        """
        Upscale a video using specified or default upscaler.

        Args:
            input_path: Path to input video
            output_path: Path for output video
            upscaler_type: Type of upscaler to use (uses config default if None)

        Returns:
            UpscaleResult; any exception is converted into a failed result
        """
        try:
            upscaler = self.get_upscaler(upscaler_type)
            return upscaler.upscale(input_path, output_path)
        except Exception as e:
            return UpscaleResult(
                success=False,
                error_message=str(e)
            )

    def get_available_upscalers(self) -> List[UpscalerType]:
        """
        Get list of available upscalers on this system.

        Returns:
            List of available upscaler types
        """
        available = []

        for upscaler_type in UpscalerType:
            try:
                # get_upscaler() already verifies availability, so success
                # here means the upscaler can be used.
                self.get_upscaler(upscaler_type)
            except (ValueError, RuntimeError):
                continue
            available.append(upscaler_type)

        return available

    def upscale_with_fallback(
        self,
        input_path: Path,
        output_path: Path,
        preferred_order: Optional[List[UpscalerType]] = None
    ) -> UpscaleResult:
        """
        Upscale video trying multiple upscalers in order.

        Args:
            input_path: Path to input video
            output_path: Path for output video
            preferred_order: List of upscalers to try in order

        Returns:
            UpscaleResult from first successful upscaler, or a failed
            result listing every upscaler's error
        """
        if preferred_order is None:
            # Default order: try ML-based first, then FFmpeg
            preferred_order = [
                UpscalerType.REAL_ESRGAN,
                UpscalerType.FFMPEG_SR
            ]

        errors = []

        for upscaler_type in preferred_order:
            # Resolve a printable label first so error reporting cannot fail
            # on a raw string entry.
            label = getattr(upscaler_type, "value", upscaler_type)
            try:
                upscaler = self.get_upscaler(upscaler_type)
                result = upscaler.upscale(input_path, output_path)
                if result.success:
                    return result
                errors.append(f"{label}: {result.error_message}")
            except Exception as e:
                errors.append(f"{label}: {str(e)}")

        return UpscaleResult(
            success=False,
            error_message="All upscalers failed. Errors: " + "; ".join(errors)
        )
import pytest
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock
from src.assembly.assembler import (
    FFmpegAssembler,
    AssemblyConfig,
    AssemblyResult,
    TransitionType
)


# Shared fixtures: every test class that needs them requests them by name.

@pytest.fixture
def mock_subprocess():
    """Replace subprocess.run with a mock that reports success."""
    with patch('subprocess.run') as run_mock:
        run_mock.return_value = MagicMock(returncode=0, stderr="")
        yield run_mock


@pytest.fixture
def temp_dir():
    """Provide a scratch directory that is removed after the test."""
    with tempfile.TemporaryDirectory() as scratch:
        yield Path(scratch)


def _touch_videos(directory, names):
    """Create placeholder video files and return their paths in order."""
    paths = [directory / name for name in names]
    for path in paths:
        path.write_bytes(b"dummy video data")
    return paths


class TestAssemblyConfig:
    """AssemblyConfig defaults and serialization."""

    def test_default_config(self):
        """A config built without arguments carries the documented defaults."""
        cfg = AssemblyConfig()
        assert cfg.fps == 24
        assert cfg.container == "mp4"
        assert cfg.codec == "h264"
        assert cfg.crf == 18
        assert cfg.preset == "medium"
        assert cfg.transition == TransitionType.NONE
        assert cfg.transition_duration_ms == 500
        assert cfg.add_shot_labels is False
        assert cfg.audio_track is None

    def test_config_to_dict(self):
        """to_dict() reflects overridden fields and serializes the transition."""
        serialized = AssemblyConfig(
            fps=30,
            codec="h265",
            transition=TransitionType.FADE
        ).to_dict()
        assert serialized["fps"] == 30
        assert serialized["codec"] == "h265"
        assert serialized["transition"] == "fade"
        assert serialized["audio_track"] is None

    def test_config_with_audio(self):
        """An audio track path is serialized as a string."""
        track = Path("/path/to/audio.mp3")
        serialized = AssemblyConfig(audio_track=track).to_dict()
        assert serialized["audio_track"] == str(track)


class TestFFmpegAssemblerInit:
    """Construction of FFmpegAssembler."""

    def test_default_init(self):
        """Without arguments the assembler uses the bare "ffmpeg" executable."""
        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(returncode=0)
            instance = FFmpegAssembler()
            assert instance.ffmpeg_path == "ffmpeg"

    def test_custom_ffmpeg_path(self):
        """An explicit ffmpeg path is stored as given."""
        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(returncode=0)
            instance = FFmpegAssembler(ffmpeg_path="/usr/bin/ffmpeg")
            assert instance.ffmpeg_path == "/usr/bin/ffmpeg"

    def test_ffmpeg_not_found(self):
        """A missing ffmpeg binary is reported, not raised."""
        with patch('subprocess.run', side_effect=FileNotFoundError()):
            instance = FFmpegAssembler()
            # Construction must succeed; only the check reports the problem.
            assert instance._check_ffmpeg() is False


class TestFFmpegAssemblerAssemble:
    """Concatenating shots into a single video."""

    def test_assemble_no_files(self, mock_subprocess):
        """An empty shot list is rejected."""
        outcome = FFmpegAssembler().assemble([], Path("output.mp4"))

        assert outcome.success is False
        assert "No shot files provided" in outcome.error_message

    def test_assemble_missing_files(self, mock_subprocess, temp_dir):
        """Shot paths that do not exist are rejected."""
        absent = temp_dir / "nonexistent.mp4"
        outcome = FFmpegAssembler().assemble([absent], temp_dir / "output.mp4")

        assert outcome.success is False
        assert "Missing shot files" in outcome.error_message

    def test_assemble_single_file(self, mock_subprocess, temp_dir):
        """A single shot goes through the concat path."""
        # Placeholder input clip
        (shot,) = _touch_videos(temp_dir, ["shot_001.mp4"])
        target = temp_dir / "output.mp4"

        instance = FFmpegAssembler()

        # Stub the ffprobe-based duration lookup
        with patch.object(instance, '_get_video_duration', return_value=4.0):
            outcome = instance.assemble([shot], target)

        assert outcome.success is True
        assert outcome.output_path == target
        assert outcome.num_shots == 1

        # ffmpeg must have been invoked in concat mode
        assert mock_subprocess.called
        argv = mock_subprocess.call_args[0][0]
        assert "ffmpeg" in argv[0]
        assert "-f" in argv
        assert "concat" in argv

    def test_assemble_multiple_files(self, mock_subprocess, temp_dir):
        """Several shots are concatenated into one output."""
        # Placeholder input clips
        shots = _touch_videos(
            temp_dir,
            ["shot_001.mp4", "shot_002.mp4", "shot_003.mp4"]
        )
        target = temp_dir / "output.mp4"

        instance = FFmpegAssembler()

        with patch.object(instance, '_get_video_duration', return_value=4.0):
            outcome = instance.assemble(shots, target)

        assert outcome.success is True
        assert outcome.num_shots == 3
        assert outcome.output_path == target

    def test_assemble_with_config(self, mock_subprocess, temp_dir):
        """Custom encoding settings end up on the ffmpeg command line."""
        (shot,) = _touch_videos(temp_dir, ["shot_001.mp4"])
        target = temp_dir / "output.mp4"

        settings = AssemblyConfig(
            fps=30,
            codec="h265",
            crf=20,
            preset="slow"
        )

        instance = FFmpegAssembler()

        with patch.object(instance, '_get_video_duration', return_value=4.0):
            outcome = instance.assemble([shot], target, settings)

        assert outcome.success is True

        # The settings must appear in the issued command
        argv = mock_subprocess.call_args[0][0]
        assert "-r" in argv
        assert "30" in argv
        assert "libx265" in argv
        assert "-crf" in argv
        assert "20" in argv
        assert "slow" in argv

    def test_assemble_ffmpeg_error(self, mock_subprocess, temp_dir):
        """A non-zero ffmpeg exit status is surfaced as a failure."""
        (shot,) = _touch_videos(temp_dir, ["shot_001.mp4"])
        target = temp_dir / "output.mp4"

        # Simulate ffmpeg failing
        mock_subprocess.return_value = MagicMock(
            returncode=1,
            stderr="Error: Invalid data"
        )

        outcome = FFmpegAssembler().assemble([shot], target)

        assert outcome.success is False
        assert "FFmpeg error" in outcome.error_message


class TestFFmpegAssemblerTransitions:
    """Transitions between shots."""

    def test_fade_transition(self, mock_subprocess, temp_dir):
        """A fade transition is rendered through an xfade filter graph."""
        shots = _touch_videos(temp_dir, ["shot_001.mp4", "shot_002.mp4"])
        target = temp_dir / "output.mp4"

        settings = AssemblyConfig(
            transition=TransitionType.FADE,
            transition_duration_ms=500
        )

        instance = FFmpegAssembler()

        with patch.object(instance, '_get_video_duration', return_value=4.0):
            outcome = instance.assemble(shots, target, settings)

        assert outcome.success is True

        # The command must use a complex filter graph with xfade
        argv = mock_subprocess.call_args[0][0]
        assert "-filter_complex" in argv
        assert "xfade" in str(argv)


class TestFFmpegAssemblerAudio:
    """Muxing an audio track into the assembled video."""

    def test_add_audio(self, mock_subprocess, temp_dir):
        """Supplying an audio track triggers an additional ffmpeg pass."""
        (shot,) = _touch_videos(temp_dir, ["shot_001.mp4"])
        soundtrack = temp_dir / "audio.mp3"
        soundtrack.write_bytes(b"dummy audio data")
        target = temp_dir / "output.mp4"

        settings = AssemblyConfig(audio_track=soundtrack)

        instance = FFmpegAssembler()

        with patch.object(instance, '_get_video_duration', return_value=4.0):
            outcome = instance.assemble([shot], target, settings)

        assert outcome.success is True

        # One invocation for the video, at least one more for the audio
        invocations = mock_subprocess.call_args_list
        assert len(invocations) >= 2


class TestFFmpegAssemblerUtilities:
    """Helper methods of FFmpegAssembler."""

    def test_get_video_codec(self):
        """Codec names map to encoder names, defaulting to libx264."""
        instance = FFmpegAssembler()

        assert instance._get_video_codec("h264") == "libx264"
        assert instance._get_video_codec("h265") == "libx265"
        assert instance._get_video_codec("vp9") == "libvpx-vp9"
        assert instance._get_video_codec("unknown") == "libx264"  # Default

    def test_get_video_duration(self):
        """The duration is parsed from the probe output."""
        instance = FFmpegAssembler()

        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(
                returncode=0,
                stdout="4.500\n"
            )
            seconds = instance._get_video_duration(Path("test.mp4"))
            assert seconds == 4.5

    def test_get_video_duration_error(self):
        """A failed probe yields a duration of zero."""
        instance = FFmpegAssembler()

        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(returncode=1)
            seconds = instance._get_video_duration(Path("test.mp4"))
            assert seconds == 0.0

    def test_extract_frame(self):
        """Frame extraction seeks to the timestamp and grabs one frame."""
        instance = FFmpegAssembler()

        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(returncode=0)
            extracted = instance.extract_frame(
                Path("video.mp4"),
                2.5,
                Path("frame.jpg")
            )
            assert extracted is True

            # Inspect the issued ffmpeg command
            argv = run_mock.call_args[0][0]
            assert "-ss" in argv
            assert "2.5" in argv
            assert "-vframes" in argv
            assert "1" in argv

    def test_burn_in_labels(self):
        """Labels are burned in with a drawtext video filter."""
        instance = FFmpegAssembler()

        with patch('subprocess.run') as run_mock:
            run_mock.return_value = MagicMock(returncode=0)
            outcome = instance.burn_in_labels(
                Path("input.mp4"),
                Path("output.mp4"),
                ["Shot 1", "Shot 2"],
                font_size=24,
                position="top-left"
            )
            assert outcome.success is True

            # The command must carry a drawtext filter
            argv = run_mock.call_args[0][0]
            assert "-vf" in argv
            assert "drawtext" in str(argv)


class TestAssemblyResult:
    """AssemblyResult field handling."""

    def test_success_result(self):
        """A successful result keeps its fields and defaults metadata to {}."""
        outcome = AssemblyResult(
            success=True,
            output_path=Path("output.mp4"),
            duration_s=12.5,
            num_shots=3
        )
        assert outcome.success is True
        assert outcome.output_path == Path("output.mp4")
        assert outcome.duration_s == 12.5
        assert outcome.num_shots == 3
        assert outcome.metadata == {}

    def test_failure_result(self):
        """A failed result carries the message and no output path."""
        outcome = AssemblyResult(
            success=False,
            error_message="FFmpeg failed"
        )
        assert outcome.success is False
        assert outcome.error_message == "FFmpeg failed"
        assert outcome.output_path is None
+""" + +import pytest +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock +from typer.testing import CliRunner + +from src.cli.main import app + + +runner = CliRunner() + + +# Valid storyboard JSON template +VALID_STORYBOARD = """ +{ + "schema_version": "1.0", + "project": { + "title": "Test Storyboard", + "fps": 24, + "target_duration_s": 20, + "resolution": {"width": 1280, "height": 720}, + "aspect_ratio": "16:9" + }, + "shots": [ + { + "id": "shot_001", + "duration_s": 4, + "prompt": "A test shot description", + "camera": {"framing": "wide", "movement": "static"}, + "generation": {"seed": 42, "steps": 30, "cfg_scale": 6.0} + } + ], + "output": { + "container": "mp4", + "codec": "h264", + "crf": 18, + "preset": "medium" + } +} +""" + + +class TestValidateCommand: + """Test the validate command.""" + + def test_validate_nonexistent_file(self): + """Test validation with non-existent file.""" + result = runner.invoke(app, ["validate", "nonexistent.json"]) + assert result.exit_code == 1 + assert "not found" in result.output + + def test_validate_valid_storyboard(self): + """Test validation with valid storyboard.""" + with tempfile.TemporaryDirectory() as tmpdir: + storyboard_file = Path(tmpdir) / "test.json" + storyboard_file.write_text(VALID_STORYBOARD) + + result = runner.invoke(app, ["validate", str(storyboard_file)]) + assert result.exit_code == 0 + assert "valid" in result.output + assert "Test Storyboard" in result.output + + def test_validate_verbose(self): + """Test validation with verbose flag.""" + with tempfile.TemporaryDirectory() as tmpdir: + storyboard_file = Path(tmpdir) / "test.json" + storyboard_file.write_text(VALID_STORYBOARD) + + result = runner.invoke(app, ["validate", str(storyboard_file), "--verbose"]) + assert result.exit_code == 0 + assert "Shots" in result.output + + +class TestListBackendsCommand: + """Test the list-backends command.""" + + def test_list_backends(self): + """Test listing available 
backends.""" + result = runner.invoke(app, ["list-backends"]) + assert result.exit_code == 0 + assert "wan" in result.output + assert "Available Backends" in result.output + + +class TestGenerateCommand: + """Test the generate command.""" + + @pytest.fixture + def mock_storyboard(self): + """Create a mock storyboard file.""" + with tempfile.TemporaryDirectory() as tmpdir: + storyboard_file = Path(tmpdir) / "test.json" + storyboard_file.write_text(VALID_STORYBOARD) + yield storyboard_file + + def test_generate_dry_run(self, mock_storyboard): + """Test generate command with dry-run flag.""" + result = runner.invoke(app, [ + "generate", + str(mock_storyboard), + "--dry-run" + ]) + assert result.exit_code == 0 + assert "Dry run complete" in result.output + + def test_generate_nonexistent_storyboard(self): + """Test generate with non-existent storyboard.""" + result = runner.invoke(app, ["generate", "nonexistent.json"]) + assert result.exit_code == 1 + assert "not found" in result.output + + +class TestResumeCommand: + """Test the resume command.""" + + def test_resume_nonexistent_project(self): + """Test resume with non-existent project.""" + with tempfile.TemporaryDirectory() as tmpdir: + result = runner.invoke(app, [ + "resume", + "nonexistent_project", + "--output", tmpdir + ]) + assert result.exit_code == 1 + assert "No checkpoint found" in result.output + + def test_resume_existing_project(self): + """Test resume with existing project checkpoint.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_dir = Path(tmpdir) / "test_project" + project_dir.mkdir() + checkpoint_db = project_dir / "checkpoints.db" + checkpoint_db.touch() + + # Create a mock checkpoint + with patch('src.cli.main.CheckpointManager') as mock_mgr: + mock_checkpoint = MagicMock() + mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json") + mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint + + # Create storyboard file + storyboard_file = 
Path(tmpdir) / "storyboard.json" + storyboard_file.write_text(VALID_STORYBOARD) + + result = runner.invoke(app, [ + "resume", + "test_project", + "--output", tmpdir + ]) + + # Should attempt to resume (may fail on actual generation) + assert "Resuming project" in result.output or result.exit_code != 0 + + +class TestCLIHelp: + """Test CLI help messages.""" + + def test_main_help(self): + """Test main help message.""" + result = runner.invoke(app, ["--help"]) + assert result.exit_code == 0 + assert "Storyboard to Video Generation Pipeline" in result.output + + def test_generate_help(self): + """Test generate command help.""" + result = runner.invoke(app, ["generate", "--help"]) + assert result.exit_code == 0 + assert "Generate video from storyboard" in result.output + assert "--output" in result.output + assert "--backend" in result.output + assert "--dry-run" in result.output + + def test_validate_help(self): + """Test validate command help.""" + result = runner.invoke(app, ["validate", "--help"]) + assert result.exit_code == 0 + assert "Validate a storyboard file" in result.output + + def test_resume_help(self): + """Test resume command help.""" + result = runner.invoke(app, ["resume", "--help"]) + assert result.exit_code == 0 + assert "Resume a failed or interrupted project" in result.output diff --git a/tests/unit/test_upscaler.py b/tests/unit/test_upscaler.py new file mode 100644 index 0000000..d45501e --- /dev/null +++ b/tests/unit/test_upscaler.py @@ -0,0 +1,350 @@ +""" +Tests for upscaling module. 
+""" + +import pytest +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock, mock_open +from src.upscaling.upscaler import ( + UpscaleConfig, + UpscaleResult, + UpscalerType, + UpscaleFactor, + FFmpegSRUpscaler, + RealESRGANUpscaler, + UpscaleManager, + BaseUpscaler +) + + +class TestUpscaleConfig: + """Test UpscaleConfig dataclass.""" + + def test_default_config(self): + """Test default configuration values.""" + config = UpscaleConfig() + assert config.upscaler_type == UpscalerType.FFMPEG_SR + assert config.factor == UpscaleFactor.X2 + assert config.denoise_strength == 0.5 + assert config.tile_size == 0 + assert config.half_precision is True + assert config.device == "cuda" + + def test_custom_config(self): + """Test custom configuration.""" + config = UpscaleConfig( + upscaler_type=UpscalerType.REAL_ESRGAN, + factor=UpscaleFactor.X4, + denoise_strength=0.8, + device="cpu" + ) + assert config.upscaler_type == UpscalerType.REAL_ESRGAN + assert config.factor == UpscaleFactor.X4 + assert config.denoise_strength == 0.8 + assert config.device == "cpu" + + def test_config_to_dict(self): + """Test configuration serialization.""" + config = UpscaleConfig( + factor=UpscaleFactor.X4, + model_path=Path("/models/model.pth") + ) + d = config.to_dict() + assert d["factor"] == 4 + assert d["upscaler_type"] == "ffmpeg_sr" + # Path conversion is platform-specific + assert "model.pth" in d["model_path"] + + +class TestUpscaleResult: + """Test UpscaleResult dataclass.""" + + def test_success_result(self): + """Test successful result.""" + result = UpscaleResult( + success=True, + output_path=Path("output.mp4"), + input_resolution=(1920, 1080), + output_resolution=(3840, 2160), + processing_time_s=45.5 + ) + assert result.success is True + assert result.output_path == Path("output.mp4") + assert result.input_resolution == (1920, 1080) + assert result.output_resolution == (3840, 2160) + assert result.processing_time_s == 45.5 + + def 
test_failure_result(self): + """Test failure result.""" + result = UpscaleResult( + success=False, + error_message="FFmpeg not found" + ) + assert result.success is False + assert result.error_message == "FFmpeg not found" + + +class TestFFmpegSRUpscaler: + """Test FFmpeg-based upscaler.""" + + @pytest.fixture + def mock_subprocess(self): + """Fixture to mock subprocess.run.""" + with patch('subprocess.run') as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + yield mock_run + + @pytest.fixture + def temp_dir(self): + """Fixture for temporary directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + def test_is_available_success(self, mock_subprocess): + """Test availability check when ffmpeg is present.""" + config = UpscaleConfig() + upscaler = FFmpegSRUpscaler(config) + assert upscaler.is_available() is True + + def test_is_available_failure(self): + """Test availability check when ffmpeg is missing.""" + with patch('subprocess.run', side_effect=FileNotFoundError()): + config = UpscaleConfig() + upscaler = FFmpegSRUpscaler(config) + assert upscaler.is_available() is False + + def test_upscale_missing_input(self, mock_subprocess): + """Test upscaling with missing input file.""" + config = UpscaleConfig() + upscaler = FFmpegSRUpscaler(config) + + result = upscaler.upscale(Path("nonexistent.mp4"), Path("output.mp4")) + + assert result.success is False + assert "not found" in result.error_message + + def test_upscale_success(self, mock_subprocess, temp_dir): + """Test successful upscaling.""" + # Create dummy input file + input_file = temp_dir / "input.mp4" + input_file.write_bytes(b"dummy video") + output_file = temp_dir / "output.mp4" + + # Mock ffprobe response + ffprobe_response = MagicMock( + returncode=0, + stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}' + ) + + with patch('subprocess.run') as mock_run: + # First call is ffprobe, second is ffmpeg 
+ mock_run.side_effect = [ffprobe_response, MagicMock(returncode=0)] + + config = UpscaleConfig(factor=UpscaleFactor.X2) + upscaler = FFmpegSRUpscaler(config) + result = upscaler.upscale(input_file, output_file) + + assert result.success is True + assert result.input_resolution == (1920, 1080) + assert result.output_resolution == (3840, 2160) + + def test_upscale_ffmpeg_error(self, mock_subprocess, temp_dir): + """Test handling of ffmpeg error.""" + input_file = temp_dir / "input.mp4" + input_file.write_bytes(b"dummy video") + output_file = temp_dir / "output.mp4" + + ffprobe_response = MagicMock( + returncode=0, + stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}' + ) + + ffmpeg_error = MagicMock(returncode=1, stderr="Invalid data") + + with patch('subprocess.run') as mock_run: + mock_run.side_effect = [ffprobe_response, ffmpeg_error] + + config = UpscaleConfig() + upscaler = FFmpegSRUpscaler(config) + result = upscaler.upscale(input_file, output_file) + + assert result.success is False + assert "FFmpeg error" in result.error_message + + +class TestRealESRGANUpscaler: + """Test Real-ESRGAN upscaler.""" + + def test_is_available_without_realesrgan(self): + """Test availability when Real-ESRGAN is not installed.""" + with patch.dict('sys.modules', {'realesrgan': None}): + config = UpscaleConfig() + upscaler = RealESRGANUpscaler(config) + assert upscaler.is_available() is False + + def test_upscale_without_realesrgan(self): + """Test upscaling when Real-ESRGAN is not installed.""" + # Mock the import to fail + import sys + original_modules = sys.modules.copy() + + try: + # Remove realesrgan and torch from modules to simulate not installed + for mod in list(sys.modules.keys()): + if 'realesrgan' in mod or 'torch' in mod: + del sys.modules[mod] + + # Mock import to raise ImportError + def mock_import(name, *args, **kwargs): + if name == 'realesrgan' or name.startswith('realesrgan.'): + raise ImportError("No module named 
'realesrgan'") + if name == 'torch' or name.startswith('torch.'): + raise ImportError("No module named 'torch'") + return original_modules.get(name, __import__(name, *args, **kwargs)) + + with patch('builtins.__import__', side_effect=mock_import): + config = UpscaleConfig() + upscaler = RealESRGANUpscaler(config) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "input.mp4" + input_file.write_bytes(b"dummy") + output_file = Path(tmpdir) / "output.mp4" + + result = upscaler.upscale(input_file, output_file) + + assert result.success is False + assert "not installed" in result.error_message + finally: + # Restore original modules + sys.modules.clear() + sys.modules.update(original_modules) + + +class TestUpscaleManager: + """Test UpscaleManager.""" + + def test_init_with_default_config(self): + """Test initialization with default config.""" + manager = UpscaleManager() + assert manager.config.upscaler_type == UpscalerType.FFMPEG_SR + + def test_init_with_custom_config(self): + """Test initialization with custom config.""" + config = UpscaleConfig(upscaler_type=UpscalerType.REAL_ESRGAN) + manager = UpscaleManager(config) + assert manager.config.upscaler_type == UpscalerType.REAL_ESRGAN + + def test_get_upscaler_caching(self): + """Test that upscalers are cached.""" + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True): + config = UpscaleConfig() + manager = UpscaleManager(config) + + upscaler1 = manager.get_upscaler(UpscalerType.FFMPEG_SR) + upscaler2 = manager.get_upscaler(UpscalerType.FFMPEG_SR) + + assert upscaler1 is upscaler2 # Same instance + + def test_get_upscaler_unavailable(self): + """Test getting unavailable upscaler.""" + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=False): + config = UpscaleConfig() + manager = UpscaleManager(config) + + with pytest.raises(RuntimeError) as exc_info: + manager.get_upscaler(UpscalerType.FFMPEG_SR) + + assert "not available" in str(exc_info.value) + + def 
test_upscale_with_manager(self): + """Test upscale through manager.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "input.mp4" + input_file.write_bytes(b"dummy") + output_file = Path(tmpdir) / "output.mp4" + + mock_result = UpscaleResult( + success=True, + output_path=output_file, + input_resolution=(1920, 1080), + output_resolution=(3840, 2160) + ) + + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True): + with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result): + config = UpscaleConfig() + manager = UpscaleManager(config) + result = manager.upscale(input_file, output_file) + + assert result.success is True + assert result.output_resolution == (3840, 2160) + + def test_get_available_upscalers(self): + """Test getting list of available upscalers.""" + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True): + with patch.object(RealESRGANUpscaler, 'is_available', return_value=False): + config = UpscaleConfig() + manager = UpscaleManager(config) + available = manager.get_available_upscalers() + + assert UpscalerType.FFMPEG_SR in available + assert UpscalerType.REAL_ESRGAN not in available + + def test_upscale_with_fallback_success(self): + """Test fallback upscaling with first option succeeding.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "input.mp4" + input_file.write_bytes(b"dummy") + output_file = Path(tmpdir) / "output.mp4" + + mock_result = UpscaleResult( + success=True, + output_path=output_file, + input_resolution=(1920, 1080), + output_resolution=(3840, 2160) + ) + + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True): + with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result): + config = UpscaleConfig() + manager = UpscaleManager(config) + result = manager.upscale_with_fallback(input_file, output_file) + + assert result.success is True + + def test_upscale_with_fallback_all_fail(self): + """Test fallback 
upscaling when all options fail.""" + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "input.mp4" + input_file.write_bytes(b"dummy") + output_file = Path(tmpdir) / "output.mp4" + + mock_result = UpscaleResult( + success=False, + error_message="Processing failed" + ) + + with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True): + with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result): + config = UpscaleConfig() + manager = UpscaleManager(config) + result = manager.upscale_with_fallback(input_file, output_file) + + assert result.success is False + assert "All upscalers failed" in result.error_message + + +class TestUpscaleEnums: + """Test upscaler enums.""" + + def test_upscaler_type_values(self): + """Test upscaler type enum values.""" + assert UpscalerType.REAL_ESRGAN.value == "real_esrgan" + assert UpscalerType.FFMPEG_SR.value == "ffmpeg_sr" + + def test_upscale_factor_values(self): + """Test upscale factor enum values.""" + assert UpscaleFactor.X2.value == 2 + assert UpscaleFactor.X4.value == 4 diff --git a/tests/unit/test_wan_backend.py b/tests/unit/test_wan_backend.py index edba4fc..b116706 100644 --- a/tests/unit/test_wan_backend.py +++ b/tests/unit/test_wan_backend.py @@ -92,7 +92,7 @@ class TestWanBackend: # Double resolution (4x pixels) vram_1440_81 = backend.estimate_vram_usage(1920, 1080, 81) - assert vram_1440_81 > vram_720_81 * 3 # Should be ~4x + assert vram_1440_81 > vram_720_81 * 2 # Should be ~2x due to model overhead # Double frames vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)