CLI implementation
This commit is contained in:
@@ -1,3 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
CLI entry points.
|
CLI entry points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from src.cli.main import app, main
|
||||||
|
|
||||||
|
__all__ = ["app", "main"]
|
||||||
|
|||||||
321
src/cli/main.py
Normal file
321
src/cli/main.py
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
"""
|
||||||
|
CLI entry point for storyboard video generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
from src.storyboard.loader import StoryboardValidator
|
||||||
|
from src.storyboard.prompt_compiler import PromptCompiler
|
||||||
|
from src.storyboard.shot_planner import ShotPlanner
|
||||||
|
from src.core.config import ConfigLoader
|
||||||
|
from src.core.checkpoint import CheckpointManager
|
||||||
|
from src.generation.base import BackendFactory
|
||||||
|
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
|
||||||
|
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig
|
||||||
|
|
||||||
|
app = typer.Typer(help="Storyboard to Video Generation Pipeline")
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def generate(
|
||||||
|
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
|
||||||
|
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
|
||||||
|
config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||||
|
backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
|
||||||
|
resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
|
||||||
|
skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
|
||||||
|
skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
|
||||||
|
upscale: Optional[int] = typer.Option(None, "--upscale", help="Upscale factor (2 or 4)"),
|
||||||
|
dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"),
|
||||||
|
):
|
||||||
|
"""Generate video from storyboard."""
|
||||||
|
|
||||||
|
# Validate storyboard exists
|
||||||
|
if not storyboard.exists():
|
||||||
|
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
try:
|
||||||
|
if config:
|
||||||
|
app_config = ConfigLoader.load(config)
|
||||||
|
else:
|
||||||
|
app_config = ConfigLoader.load()
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Error loading config: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Validate storyboard
|
||||||
|
console.print("[bold blue]Validating storyboard...[/bold blue]")
|
||||||
|
validator = StoryboardValidator()
|
||||||
|
try:
|
||||||
|
storyboard_data = validator.load(storyboard)
|
||||||
|
console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Error validating storyboard: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
console.print("[yellow]Dry run complete. Exiting.[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Setup output directories
|
||||||
|
project_dir = output / storyboard_data.project.title.replace(" ", "_")
|
||||||
|
shots_dir = project_dir / "shots"
|
||||||
|
assembled_dir = project_dir / "assembled"
|
||||||
|
metadata_dir = project_dir / "metadata"
|
||||||
|
|
||||||
|
for d in [shots_dir, assembled_dir, metadata_dir]:
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize checkpoint manager
|
||||||
|
checkpoint_db = project_dir / "checkpoints.db"
|
||||||
|
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
prompt_compiler = PromptCompiler(storyboard_data.project.global_style)
|
||||||
|
shot_planner = ShotPlanner(
|
||||||
|
fps=storyboard_data.project.fps or 24,
|
||||||
|
max_chunk_duration=app_config.backend.max_chunk_seconds or 6.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize generation backend
|
||||||
|
if not skip_generation:
|
||||||
|
console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
|
||||||
|
try:
|
||||||
|
backend_config = app_config.get_backend_config(backend)
|
||||||
|
video_backend = BackendFactory.create_backend(backend, backend_config)
|
||||||
|
console.print(f"[green]✓[/green] Backend initialized: {backend}")
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Error initializing backend: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Generate shots
|
||||||
|
generated_shots = []
|
||||||
|
|
||||||
|
if not skip_generation:
|
||||||
|
console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]")
|
||||||
|
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
console=console
|
||||||
|
) as progress:
|
||||||
|
|
||||||
|
for i, shot in enumerate(storyboard_data.shots):
|
||||||
|
task_id = progress.add_task(f"Shot {shot.id}", total=None)
|
||||||
|
|
||||||
|
# Check checkpoint
|
||||||
|
if resume:
|
||||||
|
checkpoint = checkpoint_mgr.get_shot_checkpoint(
|
||||||
|
storyboard_data.project.title, shot.id
|
||||||
|
)
|
||||||
|
if checkpoint and checkpoint.status == "completed":
|
||||||
|
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id} (cached)")
|
||||||
|
generated_shots.append(Path(checkpoint.video_path))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Compile prompt
|
||||||
|
prompt = prompt_compiler.compile_shot_prompt(shot)
|
||||||
|
negative_prompt = prompt_compiler.compile_negative_prompt(shot)
|
||||||
|
|
||||||
|
# Plan shot
|
||||||
|
shot_plan = shot_planner.plan_shot(shot)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
try:
|
||||||
|
from src.generation.base import GenerationSpec
|
||||||
|
|
||||||
|
spec = GenerationSpec(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
width=shot_plan.width,
|
||||||
|
height=shot_plan.height,
|
||||||
|
num_frames=shot_plan.total_frames,
|
||||||
|
fps=shot_plan.fps,
|
||||||
|
seed=shot.generation.seed
|
||||||
|
)
|
||||||
|
|
||||||
|
result = video_backend.generate(spec, output_dir=shots_dir)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
generated_shots.append(result.video_path)
|
||||||
|
checkpoint_mgr.save_shot_checkpoint(
|
||||||
|
project_name=storyboard_data.project.title,
|
||||||
|
shot_id=shot.id,
|
||||||
|
status="completed",
|
||||||
|
video_path=str(result.video_path),
|
||||||
|
metadata=result.metadata
|
||||||
|
)
|
||||||
|
progress.update(task_id, description=f"[green]✓[/green] Shot {shot.id}")
|
||||||
|
else:
|
||||||
|
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {result.error_message}")
|
||||||
|
checkpoint_mgr.save_shot_checkpoint(
|
||||||
|
project_name=storyboard_data.project.title,
|
||||||
|
shot_id=shot.id,
|
||||||
|
status="failed",
|
||||||
|
error_message=result.error_message
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
progress.update(task_id, description=f"[red]✗[/red] Shot {shot.id}: {e}")
|
||||||
|
checkpoint_mgr.save_shot_checkpoint(
|
||||||
|
project_name=storyboard_data.project.title,
|
||||||
|
shot_id=shot.id,
|
||||||
|
status="failed",
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assembly
|
||||||
|
if not skip_assembly and generated_shots:
|
||||||
|
console.print("[bold blue]Assembling video...[/bold blue]")
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
final_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}.mp4"
|
||||||
|
|
||||||
|
assembly_config = AssemblyConfig(
|
||||||
|
fps=storyboard_data.project.fps or 24,
|
||||||
|
add_shot_labels=False
|
||||||
|
)
|
||||||
|
|
||||||
|
result = assembler.assemble(generated_shots, final_output, assembly_config)
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
console.print(f"[green]✓[/green] Video assembled: {final_output}")
|
||||||
|
|
||||||
|
# Upscale if requested
|
||||||
|
if upscale:
|
||||||
|
console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")
|
||||||
|
|
||||||
|
upscale_config = UpscaleConfig(
|
||||||
|
factor=upscale,
|
||||||
|
upscaler_type="ffmpeg_sr"
|
||||||
|
)
|
||||||
|
upscale_mgr = UpscaleManager(upscale_config)
|
||||||
|
|
||||||
|
upscaled_output = assembled_dir / f"{storyboard_data.project.title.replace(' ', '_')}_upscaled.mp4"
|
||||||
|
upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
|
||||||
|
|
||||||
|
if upscale_result.success:
|
||||||
|
console.print(f"[green]✓[/green] Upscaled video: {upscaled_output}")
|
||||||
|
final_output = upscaled_output
|
||||||
|
else:
|
||||||
|
console.print(f"[yellow]Warning: Upscaling failed: {upscale_result.error_message}[/yellow]")
|
||||||
|
|
||||||
|
# Save project metadata
|
||||||
|
checkpoint_mgr.save_project_checkpoint(
|
||||||
|
project_name=storyboard_data.project.title,
|
||||||
|
storyboard_path=str(storyboard),
|
||||||
|
output_path=str(final_output),
|
||||||
|
status="completed",
|
||||||
|
num_shots=len(generated_shots)
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print(f"\n[bold green]Success![/bold green] Video saved to: {final_output}")
|
||||||
|
else:
|
||||||
|
console.print(f"[red]Error assembling video: {result.error_message}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def validate(
|
||||||
|
storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
|
||||||
|
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"),
|
||||||
|
):
|
||||||
|
"""Validate a storyboard file."""
|
||||||
|
|
||||||
|
if not storyboard.exists():
|
||||||
|
console.print(f"[red]Error: Storyboard file not found: {storyboard}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
validator = StoryboardValidator()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = validator.load(storyboard)
|
||||||
|
|
||||||
|
console.print(f"[green]✓[/green] Storyboard is valid")
|
||||||
|
console.print(f"\n[bold]Title:[/bold] {data.project.title}")
|
||||||
|
console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
|
||||||
|
console.print(f"[bold]FPS:[/bold] {data.project.fps or 'Not specified'}")
|
||||||
|
resolution = f"{data.project.resolution.width}x{data.project.resolution.height}" if data.project.resolution else "Not specified"
|
||||||
|
console.print(f"[bold]Resolution:[/bold] {resolution}")
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
table = Table(title="Shots")
|
||||||
|
table.add_column("ID", style="cyan")
|
||||||
|
table.add_column("Prompt", style="green")
|
||||||
|
table.add_column("Duration", style="yellow")
|
||||||
|
|
||||||
|
for shot in data.shots:
|
||||||
|
table.add_row(shot.id, shot.prompt[:50], f"{shot.duration_s}s")
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]Validation failed: {e}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def resume(
|
||||||
|
project: str = typer.Argument(..., help="Project name"),
|
||||||
|
output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
|
||||||
|
):
|
||||||
|
"""Resume a failed or interrupted project."""
|
||||||
|
|
||||||
|
project_dir = output / project.replace(" ", "_")
|
||||||
|
checkpoint_db = project_dir / "checkpoints.db"
|
||||||
|
|
||||||
|
if not checkpoint_db.exists():
|
||||||
|
console.print(f"[red]Error: No checkpoint found for project '{project}'[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
checkpoint_mgr = CheckpointManager(str(checkpoint_db))
|
||||||
|
project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)
|
||||||
|
|
||||||
|
if not project_checkpoint:
|
||||||
|
console.print(f"[red]Error: No project checkpoint found[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[bold blue]Resuming project:[/bold blue] {project}")
|
||||||
|
console.print(f"Storyboard: {project_checkpoint.storyboard_path}")
|
||||||
|
|
||||||
|
# Re-run generation with resume flag
|
||||||
|
generate(
|
||||||
|
storyboard=Path(project_checkpoint.storyboard_path),
|
||||||
|
output=output,
|
||||||
|
resume=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def list_backends():
|
||||||
|
"""List available generation backends."""
|
||||||
|
|
||||||
|
table = Table(title="Available Backends")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Type", style="green")
|
||||||
|
table.add_column("Description", style="white")
|
||||||
|
|
||||||
|
table.add_row("wan", "T2V", "WAN 2.x text-to-video")
|
||||||
|
table.add_row("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)")
|
||||||
|
table.add_row("svd", "I2V", "Stable Video Diffusion (fallback)")
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Entry point for the CLI."""
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,3 +1,25 @@
|
|||||||
"""
|
"""
|
||||||
Upscaling module.
|
Upscaling module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from src.upscaling.upscaler import (
|
||||||
|
BaseUpscaler,
|
||||||
|
FFmpegSRUpscaler,
|
||||||
|
RealESRGANUpscaler,
|
||||||
|
UpscaleManager,
|
||||||
|
UpscaleConfig,
|
||||||
|
UpscaleResult,
|
||||||
|
UpscalerType,
|
||||||
|
UpscaleFactor
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseUpscaler",
|
||||||
|
"FFmpegSRUpscaler",
|
||||||
|
"RealESRGANUpscaler",
|
||||||
|
"UpscaleManager",
|
||||||
|
"UpscaleConfig",
|
||||||
|
"UpscaleResult",
|
||||||
|
"UpscalerType",
|
||||||
|
"UpscaleFactor"
|
||||||
|
]
|
||||||
|
|||||||
579
src/upscaling/upscaler.py
Normal file
579
src/upscaling/upscaler.py
Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
"""
|
||||||
|
Video upscaling module for enhancing video resolution.
|
||||||
|
Supports multiple upscaling backends including Real-ESRGAN, RealBasicVSR,
|
||||||
|
and FFmpeg-based super-resolution filters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class UpscalerType(str, Enum):
|
||||||
|
"""Supported upscaler types."""
|
||||||
|
REAL_ESRGAN = "real_esrgan"
|
||||||
|
REAL_BASIC_VSR = "real_basic_vsr"
|
||||||
|
FFMPEG_SR = "ffmpeg_sr"
|
||||||
|
OPENCV_DNN = "opencv_dnn"
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleFactor(int, Enum):
|
||||||
|
"""Supported upscale factors."""
|
||||||
|
X2 = 2
|
||||||
|
X4 = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpscaleConfig:
|
||||||
|
"""Configuration for video upscaling."""
|
||||||
|
upscaler_type: UpscalerType = UpscalerType.FFMPEG_SR
|
||||||
|
factor: UpscaleFactor = UpscaleFactor.X2
|
||||||
|
denoise_strength: float = 0.5
|
||||||
|
tile_size: int = 0 # 0 = no tiling
|
||||||
|
tile_pad: int = 10
|
||||||
|
pre_pad: int = 0
|
||||||
|
half_precision: bool = True # Use fp16 for speed
|
||||||
|
device: str = "cuda"
|
||||||
|
batch_size: int = 1
|
||||||
|
|
||||||
|
# Model paths (optional, will use defaults if not specified)
|
||||||
|
model_path: Optional[Path] = None
|
||||||
|
|
||||||
|
# FFmpeg-specific options
|
||||||
|
ffmpeg_preset: str = "medium"
|
||||||
|
ffmpeg_crf: int = 18
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert to dictionary for metadata."""
|
||||||
|
return {
|
||||||
|
"upscaler_type": self.upscaler_type.value,
|
||||||
|
"factor": self.factor.value,
|
||||||
|
"denoise_strength": self.denoise_strength,
|
||||||
|
"tile_size": self.tile_size,
|
||||||
|
"tile_pad": self.tile_pad,
|
||||||
|
"pre_pad": self.pre_pad,
|
||||||
|
"half_precision": self.half_precision,
|
||||||
|
"device": self.device,
|
||||||
|
"batch_size": self.batch_size,
|
||||||
|
"model_path": str(self.model_path) if self.model_path else None,
|
||||||
|
"ffmpeg_preset": self.ffmpeg_preset,
|
||||||
|
"ffmpeg_crf": self.ffmpeg_crf
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpscaleResult:
|
||||||
|
"""Result of video upscaling."""
|
||||||
|
success: bool
|
||||||
|
output_path: Optional[Path] = None
|
||||||
|
input_resolution: Optional[Tuple[int, int]] = None
|
||||||
|
output_resolution: Optional[Tuple[int, int]] = None
|
||||||
|
processing_time_s: Optional[float] = None
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseUpscaler(ABC):
|
||||||
|
"""Abstract base class for video upscalers."""
|
||||||
|
|
||||||
|
def __init__(self, config: UpscaleConfig):
|
||||||
|
"""
|
||||||
|
Initialize upscaler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Upscale configuration
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale a video file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to input video
|
||||||
|
output_path: Path for output video
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UpscaleResult with output details
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""Check if this upscaler is available on the system."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_video_info(self, video_path: Path) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get video information using ffprobe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to video file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with video info (width, height, fps, duration)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cmd = [
|
||||||
|
"ffprobe",
|
||||||
|
"-v", "error",
|
||||||
|
"-select_streams", "v:0",
|
||||||
|
"-show_entries", "stream=width,height,r_frame_rate,duration",
|
||||||
|
"-of", "json",
|
||||||
|
str(video_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
import json
|
||||||
|
data = json.loads(result.stdout)
|
||||||
|
stream = data.get("streams", [{}])[0]
|
||||||
|
|
||||||
|
# Parse frame rate fraction
|
||||||
|
fps_str = stream.get("r_frame_rate", "24/1")
|
||||||
|
if "/" in fps_str:
|
||||||
|
num, den = fps_str.split("/")
|
||||||
|
fps = float(num) / float(den)
|
||||||
|
else:
|
||||||
|
fps = float(fps_str)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"width": int(stream.get("width", 0)),
|
||||||
|
"height": int(stream.get("height", 0)),
|
||||||
|
"fps": fps,
|
||||||
|
"duration": float(stream.get("duration", 0))
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {"width": 0, "height": 0, "fps": 24.0, "duration": 0.0}
|
||||||
|
|
||||||
|
|
||||||
|
class FFmpegSRUpscaler(BaseUpscaler):
|
||||||
|
"""
|
||||||
|
FFmpeg-based super-resolution upscaler.
|
||||||
|
Uses FFmpeg's scale filter with various algorithms.
|
||||||
|
Good for quick upscaling without ML models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""Check if ffmpeg is available."""
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["ffmpeg", "-version"],
|
||||||
|
capture_output=True,
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
return result.returncode == 0
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale using FFmpeg's scale filter.
|
||||||
|
Uses lanczos scaling for quality.
|
||||||
|
"""
|
||||||
|
if not input_path.exists():
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Input file not found: {input_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get input resolution
|
||||||
|
info = self._get_video_info(input_path)
|
||||||
|
input_width = info["width"]
|
||||||
|
input_height = info["height"]
|
||||||
|
|
||||||
|
if input_width == 0 or input_height == 0:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Could not determine input video resolution"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate output resolution
|
||||||
|
factor = self.config.factor.value
|
||||||
|
output_width = input_width * factor
|
||||||
|
output_height = input_height * factor
|
||||||
|
|
||||||
|
# Ensure output directory exists
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
import time
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Build FFmpeg command with high-quality scaling
|
||||||
|
# Using lanczos for better quality than bicubic
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y", # Overwrite output
|
||||||
|
"-i", str(input_path),
|
||||||
|
"-vf", f"scale={output_width}:{output_height}:flags=lanczos",
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", self.config.ffmpeg_preset,
|
||||||
|
"-crf", str(self.config.ffmpeg_crf),
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
"-movflags", "+faststart",
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=7200 # 2 hour timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"FFmpeg error: {result.stderr}",
|
||||||
|
processing_time_s=processing_time
|
||||||
|
)
|
||||||
|
|
||||||
|
return UpscaleResult(
|
||||||
|
success=True,
|
||||||
|
output_path=output_path,
|
||||||
|
input_resolution=(input_width, input_height),
|
||||||
|
output_resolution=(output_width, output_height),
|
||||||
|
processing_time_s=processing_time,
|
||||||
|
metadata={
|
||||||
|
"config": self.config.to_dict(),
|
||||||
|
"algorithm": "lanczos",
|
||||||
|
"fps": info["fps"],
|
||||||
|
"duration": info["duration"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="FFmpeg upscaling timed out (2 hours)",
|
||||||
|
processing_time_s=time.time() - start_time
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Upscaling failed: {str(e)}",
|
||||||
|
processing_time_s=time.time() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RealESRGANUpscaler(BaseUpscaler):
|
||||||
|
"""
|
||||||
|
Real-ESRGAN upscaler for high-quality 2x/4x upscaling.
|
||||||
|
Requires Real-ESRGAN to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
"""Check if Real-ESRGAN is available."""
|
||||||
|
try:
|
||||||
|
import realesrgan
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def upscale(self, input_path: Path, output_path: Path) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale using Real-ESRGAN.
|
||||||
|
Processes video frame by frame.
|
||||||
|
"""
|
||||||
|
if not input_path.exists():
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Input file not found: {input_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
except ImportError as e:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Real-ESRGAN not installed: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get input resolution
|
||||||
|
info = self._get_video_info(input_path)
|
||||||
|
input_width = info["width"]
|
||||||
|
input_height = info["height"]
|
||||||
|
|
||||||
|
if input_width == 0 or input_height == 0:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Could not determine input video resolution"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate output resolution
|
||||||
|
factor = self.config.factor.value
|
||||||
|
output_width = input_width * factor
|
||||||
|
output_height = input_height * factor
|
||||||
|
|
||||||
|
# Ensure output directory exists
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
import time
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize Real-ESRGAN model
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=factor
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine model name based on factor
|
||||||
|
model_name = f"RealESRGAN_x{factor}plus.pth"
|
||||||
|
model_path = self.config.model_path or Path(f"weights/{model_name}")
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=factor,
|
||||||
|
model_path=str(model_path),
|
||||||
|
model=model,
|
||||||
|
tile=self.config.tile_size,
|
||||||
|
tile_pad=self.config.tile_pad,
|
||||||
|
pre_pad=self.config.pre_pad,
|
||||||
|
half=self.config.half_precision,
|
||||||
|
device=torch.device(self.config.device)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract frames, upscale, and reassemble
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
tmp_path = Path(tmpdir)
|
||||||
|
frames_dir = tmp_path / "frames"
|
||||||
|
frames_dir.mkdir()
|
||||||
|
upscaled_dir = tmp_path / "upscaled"
|
||||||
|
upscaled_dir.mkdir()
|
||||||
|
|
||||||
|
# Extract frames
|
||||||
|
self._extract_frames(input_path, frames_dir, info["fps"])
|
||||||
|
|
||||||
|
# Upscale frames
|
||||||
|
frame_files = sorted(frames_dir.glob("*.png"))
|
||||||
|
for i, frame_file in enumerate(frame_files):
|
||||||
|
img = cv2.imread(str(frame_file), cv2.IMREAD_UNCHANGED)
|
||||||
|
if img is not None:
|
||||||
|
output, _ = upsampler.enhance(img, outscale=factor)
|
||||||
|
cv2.imwrite(str(upscaled_dir / f"frame_{i:06d}.png"), output)
|
||||||
|
|
||||||
|
# Reassemble video
|
||||||
|
self._assemble_video(
|
||||||
|
upscaled_dir,
|
||||||
|
output_path,
|
||||||
|
info["fps"],
|
||||||
|
output_width,
|
||||||
|
output_height
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
return UpscaleResult(
|
||||||
|
success=True,
|
||||||
|
output_path=output_path,
|
||||||
|
input_resolution=(input_width, input_height),
|
||||||
|
output_resolution=(output_width, output_height),
|
||||||
|
processing_time_s=processing_time,
|
||||||
|
metadata={
|
||||||
|
"config": self.config.to_dict(),
|
||||||
|
"algorithm": "Real-ESRGAN",
|
||||||
|
"model": model_name,
|
||||||
|
"frames_processed": len(frame_files)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=f"Real-ESRGAN upscaling failed: {str(e)}",
|
||||||
|
processing_time_s=time.time() - start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_frames(self, video_path: Path, output_dir: Path, fps: float):
|
||||||
|
"""Extract frames from video."""
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-vf", f"fps={fps}",
|
||||||
|
"-q:v", "2",
|
||||||
|
str(output_dir / "frame_%06d.png")
|
||||||
|
]
|
||||||
|
subprocess.run(cmd, capture_output=True, check=True)
|
||||||
|
|
||||||
|
def _assemble_video(
|
||||||
|
self,
|
||||||
|
frames_dir: Path,
|
||||||
|
output_path: Path,
|
||||||
|
fps: float,
|
||||||
|
width: int,
|
||||||
|
height: int
|
||||||
|
):
|
||||||
|
"""Assemble frames into video."""
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-y",
|
||||||
|
"-framerate", str(fps),
|
||||||
|
"-i", str(frames_dir / "frame_%06d.png"),
|
||||||
|
"-c:v", "libx264",
|
||||||
|
"-preset", self.config.ffmpeg_preset,
|
||||||
|
"-crf", str(self.config.ffmpeg_crf),
|
||||||
|
"-pix_fmt", "yuv420p",
|
||||||
|
"-s", f"{width}x{height}",
|
||||||
|
str(output_path)
|
||||||
|
]
|
||||||
|
subprocess.run(cmd, capture_output=True, check=True)
|
||||||
|
|
||||||
|
|
||||||
|
class UpscaleManager:
|
||||||
|
"""
|
||||||
|
Manager for video upscaling operations.
|
||||||
|
Automatically selects best available upscaler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[UpscaleConfig] = None):
|
||||||
|
"""
|
||||||
|
Initialize upscaling manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Upscale configuration (uses defaults if None)
|
||||||
|
"""
|
||||||
|
self.config = config or UpscaleConfig()
|
||||||
|
self._upscalers: Dict[UpscalerType, BaseUpscaler] = {}
|
||||||
|
|
||||||
|
def get_upscaler(self, upscaler_type: Optional[UpscalerType] = None) -> BaseUpscaler:
|
||||||
|
"""
|
||||||
|
Get upscaler instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
upscaler_type: Type of upscaler to use (uses config default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Upscaler instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If requested upscaler is not available
|
||||||
|
"""
|
||||||
|
upscaler_type = upscaler_type or self.config.upscaler_type
|
||||||
|
|
||||||
|
if upscaler_type not in self._upscalers:
|
||||||
|
if upscaler_type == UpscalerType.FFMPEG_SR:
|
||||||
|
upscaler = FFmpegSRUpscaler(self.config)
|
||||||
|
elif upscaler_type == UpscalerType.REAL_ESRGAN:
|
||||||
|
upscaler = RealESRGANUpscaler(self.config)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown upscaler type: {upscaler_type}")
|
||||||
|
|
||||||
|
if not upscaler.is_available():
|
||||||
|
raise RuntimeError(f"Upscaler {upscaler_type.value} is not available")
|
||||||
|
|
||||||
|
self._upscalers[upscaler_type] = upscaler
|
||||||
|
|
||||||
|
return self._upscalers[upscaler_type]
|
||||||
|
|
||||||
|
def upscale(
|
||||||
|
self,
|
||||||
|
input_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
upscaler_type: Optional[UpscalerType] = None
|
||||||
|
) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale a video using specified or default upscaler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to input video
|
||||||
|
output_path: Path for output video
|
||||||
|
upscaler_type: Type of upscaler to use (uses config default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UpscaleResult
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
upscaler = self.get_upscaler(upscaler_type)
|
||||||
|
return upscaler.upscale(input_path, output_path)
|
||||||
|
except Exception as e:
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message=str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_available_upscalers(self) -> List[UpscalerType]:
|
||||||
|
"""
|
||||||
|
Get list of available upscalers on this system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available upscaler types
|
||||||
|
"""
|
||||||
|
available = []
|
||||||
|
|
||||||
|
for upscaler_type in UpscalerType:
|
||||||
|
try:
|
||||||
|
upscaler = self.get_upscaler(upscaler_type)
|
||||||
|
if upscaler.is_available():
|
||||||
|
available.append(upscaler_type)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return available
|
||||||
|
|
||||||
|
def upscale_with_fallback(
|
||||||
|
self,
|
||||||
|
input_path: Path,
|
||||||
|
output_path: Path,
|
||||||
|
preferred_order: Optional[List[UpscalerType]] = None
|
||||||
|
) -> UpscaleResult:
|
||||||
|
"""
|
||||||
|
Upscale video trying multiple upscalers in order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_path: Path to input video
|
||||||
|
output_path: Path for output video
|
||||||
|
preferred_order: List of upscalers to try in order
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UpscaleResult from first successful upscaler
|
||||||
|
"""
|
||||||
|
if preferred_order is None:
|
||||||
|
# Default order: try ML-based first, then FFmpeg
|
||||||
|
preferred_order = [
|
||||||
|
UpscalerType.REAL_ESRGAN,
|
||||||
|
UpscalerType.FFMPEG_SR
|
||||||
|
]
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for upscaler_type in preferred_order:
|
||||||
|
try:
|
||||||
|
upscaler = self.get_upscaler(upscaler_type)
|
||||||
|
if upscaler.is_available():
|
||||||
|
result = upscaler.upscale(input_path, output_path)
|
||||||
|
if result.success:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
errors.append(f"{upscaler_type.value}: {result.error_message}")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"{upscaler_type.value}: {str(e)}")
|
||||||
|
|
||||||
|
return UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="All upscalers failed. Errors: " + "; ".join(errors)
|
||||||
|
)
|
||||||
389
tests/unit/test_assembler.py
Normal file
389
tests/unit/test_assembler.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
"""
|
||||||
|
Tests for FFmpeg assembler module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from src.assembly.assembler import (
|
||||||
|
FFmpegAssembler,
|
||||||
|
AssemblyConfig,
|
||||||
|
AssemblyResult,
|
||||||
|
TransitionType
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyConfig:
|
||||||
|
"""Test AssemblyConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_config(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = AssemblyConfig()
|
||||||
|
assert config.fps == 24
|
||||||
|
assert config.container == "mp4"
|
||||||
|
assert config.codec == "h264"
|
||||||
|
assert config.crf == 18
|
||||||
|
assert config.preset == "medium"
|
||||||
|
assert config.transition == TransitionType.NONE
|
||||||
|
assert config.transition_duration_ms == 500
|
||||||
|
assert config.add_shot_labels is False
|
||||||
|
assert config.audio_track is None
|
||||||
|
|
||||||
|
def test_config_to_dict(self):
|
||||||
|
"""Test configuration serialization."""
|
||||||
|
config = AssemblyConfig(
|
||||||
|
fps=30,
|
||||||
|
codec="h265",
|
||||||
|
transition=TransitionType.FADE
|
||||||
|
)
|
||||||
|
d = config.to_dict()
|
||||||
|
assert d["fps"] == 30
|
||||||
|
assert d["codec"] == "h265"
|
||||||
|
assert d["transition"] == "fade"
|
||||||
|
assert d["audio_track"] is None
|
||||||
|
|
||||||
|
def test_config_with_audio(self):
|
||||||
|
"""Test configuration with audio track."""
|
||||||
|
audio_path = Path("/path/to/audio.mp3")
|
||||||
|
config = AssemblyConfig(audio_track=audio_path)
|
||||||
|
d = config.to_dict()
|
||||||
|
assert d["audio_track"] == str(audio_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegAssemblerInit:
|
||||||
|
"""Test FFmpegAssembler initialization."""
|
||||||
|
|
||||||
|
def test_default_init(self):
|
||||||
|
"""Test default initialization."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
assert assembler.ffmpeg_path == "ffmpeg"
|
||||||
|
|
||||||
|
def test_custom_ffmpeg_path(self):
|
||||||
|
"""Test initialization with custom ffmpeg path."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
assembler = FFmpegAssembler(ffmpeg_path="/usr/bin/ffmpeg")
|
||||||
|
assert assembler.ffmpeg_path == "/usr/bin/ffmpeg"
|
||||||
|
|
||||||
|
def test_ffmpeg_not_found(self):
|
||||||
|
"""Test behavior when ffmpeg is not available."""
|
||||||
|
with patch('subprocess.run', side_effect=FileNotFoundError()):
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
# Should not raise, just have _check_ffmpeg return False
|
||||||
|
assert assembler._check_ffmpeg() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegAssemblerAssemble:
|
||||||
|
"""Test video assembly functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_subprocess(self):
|
||||||
|
"""Fixture to mock subprocess.run."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stderr="")
|
||||||
|
yield mock_run
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir(self):
|
||||||
|
"""Fixture for temporary directory."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
def test_assemble_no_files(self, mock_subprocess):
|
||||||
|
"""Test assembly with no input files."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
result = assembler.assemble([], Path("output.mp4"))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "No shot files provided" in result.error_message
|
||||||
|
|
||||||
|
def test_assemble_missing_files(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test assembly with missing input files."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
missing_file = temp_dir / "nonexistent.mp4"
|
||||||
|
result = assembler.assemble([missing_file], temp_dir / "output.mp4")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Missing shot files" in result.error_message
|
||||||
|
|
||||||
|
def test_assemble_single_file(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test simple concatenation with single file."""
|
||||||
|
# Create dummy input file
|
||||||
|
input_file = temp_dir / "shot_001.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video data")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
# Mock ffprobe for duration check
|
||||||
|
with patch.object(assembler, '_get_video_duration', return_value=4.0):
|
||||||
|
result = assembler.assemble([input_file], output_file)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.output_path == output_file
|
||||||
|
assert result.num_shots == 1
|
||||||
|
|
||||||
|
# Verify ffmpeg was called
|
||||||
|
assert mock_subprocess.called
|
||||||
|
call_args = mock_subprocess.call_args[0][0]
|
||||||
|
assert "ffmpeg" in call_args[0]
|
||||||
|
assert "-f" in call_args
|
||||||
|
assert "concat" in call_args
|
||||||
|
|
||||||
|
def test_assemble_multiple_files(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test concatenation with multiple files."""
|
||||||
|
# Create dummy input files
|
||||||
|
input_files = [
|
||||||
|
temp_dir / "shot_001.mp4",
|
||||||
|
temp_dir / "shot_002.mp4",
|
||||||
|
temp_dir / "shot_003.mp4"
|
||||||
|
]
|
||||||
|
for f in input_files:
|
||||||
|
f.write_bytes(b"dummy video data")
|
||||||
|
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch.object(assembler, '_get_video_duration', return_value=4.0):
|
||||||
|
result = assembler.assemble(input_files, output_file)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.num_shots == 3
|
||||||
|
assert result.output_path == output_file
|
||||||
|
|
||||||
|
def test_assemble_with_config(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test assembly with custom configuration."""
|
||||||
|
input_file = temp_dir / "shot_001.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video data")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
config = AssemblyConfig(
|
||||||
|
fps=30,
|
||||||
|
codec="h265",
|
||||||
|
crf=20,
|
||||||
|
preset="slow"
|
||||||
|
)
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch.object(assembler, '_get_video_duration', return_value=4.0):
|
||||||
|
result = assembler.assemble([input_file], output_file, config)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# Verify config was used
|
||||||
|
call_args = mock_subprocess.call_args[0][0]
|
||||||
|
assert "-r" in call_args
|
||||||
|
assert "30" in call_args
|
||||||
|
assert "libx265" in call_args
|
||||||
|
assert "-crf" in call_args
|
||||||
|
assert "20" in call_args
|
||||||
|
assert "slow" in call_args
|
||||||
|
|
||||||
|
def test_assemble_ffmpeg_error(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test handling of ffmpeg error."""
|
||||||
|
input_file = temp_dir / "shot_001.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video data")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
# Make ffmpeg return error
|
||||||
|
mock_subprocess.return_value = MagicMock(
|
||||||
|
returncode=1,
|
||||||
|
stderr="Error: Invalid data"
|
||||||
|
)
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
result = assembler.assemble([input_file], output_file)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "FFmpeg error" in result.error_message
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegAssemblerTransitions:
|
||||||
|
"""Test transition functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_subprocess(self):
|
||||||
|
"""Fixture to mock subprocess.run."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stderr="")
|
||||||
|
yield mock_run
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir(self):
|
||||||
|
"""Fixture for temporary directory."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
def test_fade_transition(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test assembly with fade transition."""
|
||||||
|
input_files = [
|
||||||
|
temp_dir / "shot_001.mp4",
|
||||||
|
temp_dir / "shot_002.mp4"
|
||||||
|
]
|
||||||
|
for f in input_files:
|
||||||
|
f.write_bytes(b"dummy video data")
|
||||||
|
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
config = AssemblyConfig(
|
||||||
|
transition=TransitionType.FADE,
|
||||||
|
transition_duration_ms=500
|
||||||
|
)
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch.object(assembler, '_get_video_duration', return_value=4.0):
|
||||||
|
result = assembler.assemble(input_files, output_file, config)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# Verify filter_complex was used
|
||||||
|
call_args = mock_subprocess.call_args[0][0]
|
||||||
|
assert "-filter_complex" in call_args
|
||||||
|
assert "xfade" in str(call_args)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegAssemblerAudio:
|
||||||
|
"""Test audio functionality."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_subprocess(self):
|
||||||
|
"""Fixture to mock subprocess.run."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stderr="")
|
||||||
|
yield mock_run
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir(self):
|
||||||
|
"""Fixture for temporary directory."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
def test_add_audio(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test adding audio track."""
|
||||||
|
input_file = temp_dir / "shot_001.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video data")
|
||||||
|
audio_file = temp_dir / "audio.mp3"
|
||||||
|
audio_file.write_bytes(b"dummy audio data")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
config = AssemblyConfig(audio_track=audio_file)
|
||||||
|
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch.object(assembler, '_get_video_duration', return_value=4.0):
|
||||||
|
result = assembler.assemble([input_file], output_file, config)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# Verify audio was added (second ffmpeg call)
|
||||||
|
calls = mock_subprocess.call_args_list
|
||||||
|
# First call is for video, second should be for audio
|
||||||
|
assert len(calls) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegAssemblerUtilities:
|
||||||
|
"""Test utility methods."""
|
||||||
|
|
||||||
|
def test_get_video_codec(self):
|
||||||
|
"""Test codec name mapping."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
assert assembler._get_video_codec("h264") == "libx264"
|
||||||
|
assert assembler._get_video_codec("h265") == "libx265"
|
||||||
|
assert assembler._get_video_codec("vp9") == "libvpx-vp9"
|
||||||
|
assert assembler._get_video_codec("unknown") == "libx264" # Default
|
||||||
|
|
||||||
|
def test_get_video_duration(self):
|
||||||
|
"""Test duration extraction."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(
|
||||||
|
returncode=0,
|
||||||
|
stdout="4.500\n"
|
||||||
|
)
|
||||||
|
duration = assembler._get_video_duration(Path("test.mp4"))
|
||||||
|
assert duration == 4.5
|
||||||
|
|
||||||
|
def test_get_video_duration_error(self):
|
||||||
|
"""Test duration extraction with error."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=1)
|
||||||
|
duration = assembler._get_video_duration(Path("test.mp4"))
|
||||||
|
assert duration == 0.0
|
||||||
|
|
||||||
|
def test_extract_frame(self):
|
||||||
|
"""Test frame extraction."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
result = assembler.extract_frame(
|
||||||
|
Path("video.mp4"),
|
||||||
|
2.5,
|
||||||
|
Path("frame.jpg")
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify ffmpeg command
|
||||||
|
call_args = mock_run.call_args[0][0]
|
||||||
|
assert "-ss" in call_args
|
||||||
|
assert "2.5" in call_args
|
||||||
|
assert "-vframes" in call_args
|
||||||
|
assert "1" in call_args
|
||||||
|
|
||||||
|
def test_burn_in_labels(self):
|
||||||
|
"""Test label burn-in."""
|
||||||
|
assembler = FFmpegAssembler()
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0)
|
||||||
|
result = assembler.burn_in_labels(
|
||||||
|
Path("input.mp4"),
|
||||||
|
Path("output.mp4"),
|
||||||
|
["Shot 1", "Shot 2"],
|
||||||
|
font_size=24,
|
||||||
|
position="top-left"
|
||||||
|
)
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
# Verify drawtext filter
|
||||||
|
call_args = mock_run.call_args[0][0]
|
||||||
|
assert "-vf" in call_args
|
||||||
|
assert "drawtext" in str(call_args)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyResult:
|
||||||
|
"""Test AssemblyResult dataclass."""
|
||||||
|
|
||||||
|
def test_success_result(self):
|
||||||
|
"""Test successful result."""
|
||||||
|
result = AssemblyResult(
|
||||||
|
success=True,
|
||||||
|
output_path=Path("output.mp4"),
|
||||||
|
duration_s=12.5,
|
||||||
|
num_shots=3
|
||||||
|
)
|
||||||
|
assert result.success is True
|
||||||
|
assert result.output_path == Path("output.mp4")
|
||||||
|
assert result.duration_s == 12.5
|
||||||
|
assert result.num_shots == 3
|
||||||
|
assert result.metadata == {}
|
||||||
|
|
||||||
|
def test_failure_result(self):
|
||||||
|
"""Test failure result."""
|
||||||
|
result = AssemblyResult(
|
||||||
|
success=False,
|
||||||
|
error_message="FFmpeg failed"
|
||||||
|
)
|
||||||
|
assert result.success is False
|
||||||
|
assert result.error_message == "FFmpeg failed"
|
||||||
|
assert result.output_path is None
|
||||||
187
tests/unit/test_cli.py
Normal file
187
tests/unit/test_cli.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
"""
|
||||||
|
Tests for CLI module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from src.cli.main import app
|
||||||
|
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
# Valid storyboard JSON template
|
||||||
|
VALID_STORYBOARD = """
|
||||||
|
{
|
||||||
|
"schema_version": "1.0",
|
||||||
|
"project": {
|
||||||
|
"title": "Test Storyboard",
|
||||||
|
"fps": 24,
|
||||||
|
"target_duration_s": 20,
|
||||||
|
"resolution": {"width": 1280, "height": 720},
|
||||||
|
"aspect_ratio": "16:9"
|
||||||
|
},
|
||||||
|
"shots": [
|
||||||
|
{
|
||||||
|
"id": "shot_001",
|
||||||
|
"duration_s": 4,
|
||||||
|
"prompt": "A test shot description",
|
||||||
|
"camera": {"framing": "wide", "movement": "static"},
|
||||||
|
"generation": {"seed": 42, "steps": 30, "cfg_scale": 6.0}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"output": {
|
||||||
|
"container": "mp4",
|
||||||
|
"codec": "h264",
|
||||||
|
"crf": 18,
|
||||||
|
"preset": "medium"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCommand:
|
||||||
|
"""Test the validate command."""
|
||||||
|
|
||||||
|
def test_validate_nonexistent_file(self):
|
||||||
|
"""Test validation with non-existent file."""
|
||||||
|
result = runner.invoke(app, ["validate", "nonexistent.json"])
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "not found" in result.output
|
||||||
|
|
||||||
|
def test_validate_valid_storyboard(self):
|
||||||
|
"""Test validation with valid storyboard."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storyboard_file = Path(tmpdir) / "test.json"
|
||||||
|
storyboard_file.write_text(VALID_STORYBOARD)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["validate", str(storyboard_file)])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "valid" in result.output
|
||||||
|
assert "Test Storyboard" in result.output
|
||||||
|
|
||||||
|
def test_validate_verbose(self):
|
||||||
|
"""Test validation with verbose flag."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storyboard_file = Path(tmpdir) / "test.json"
|
||||||
|
storyboard_file.write_text(VALID_STORYBOARD)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["validate", str(storyboard_file), "--verbose"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Shots" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestListBackendsCommand:
|
||||||
|
"""Test the list-backends command."""
|
||||||
|
|
||||||
|
def test_list_backends(self):
|
||||||
|
"""Test listing available backends."""
|
||||||
|
result = runner.invoke(app, ["list-backends"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "wan" in result.output
|
||||||
|
assert "Available Backends" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateCommand:
|
||||||
|
"""Test the generate command."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_storyboard(self):
|
||||||
|
"""Create a mock storyboard file."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
storyboard_file = Path(tmpdir) / "test.json"
|
||||||
|
storyboard_file.write_text(VALID_STORYBOARD)
|
||||||
|
yield storyboard_file
|
||||||
|
|
||||||
|
def test_generate_dry_run(self, mock_storyboard):
|
||||||
|
"""Test generate command with dry-run flag."""
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"generate",
|
||||||
|
str(mock_storyboard),
|
||||||
|
"--dry-run"
|
||||||
|
])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Dry run complete" in result.output
|
||||||
|
|
||||||
|
def test_generate_nonexistent_storyboard(self):
|
||||||
|
"""Test generate with non-existent storyboard."""
|
||||||
|
result = runner.invoke(app, ["generate", "nonexistent.json"])
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "not found" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
class TestResumeCommand:
|
||||||
|
"""Test the resume command."""
|
||||||
|
|
||||||
|
def test_resume_nonexistent_project(self):
|
||||||
|
"""Test resume with non-existent project."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"resume",
|
||||||
|
"nonexistent_project",
|
||||||
|
"--output", tmpdir
|
||||||
|
])
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "No checkpoint found" in result.output
|
||||||
|
|
||||||
|
def test_resume_existing_project(self):
|
||||||
|
"""Test resume with existing project checkpoint."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
project_dir = Path(tmpdir) / "test_project"
|
||||||
|
project_dir.mkdir()
|
||||||
|
checkpoint_db = project_dir / "checkpoints.db"
|
||||||
|
checkpoint_db.touch()
|
||||||
|
|
||||||
|
# Create a mock checkpoint
|
||||||
|
with patch('src.cli.main.CheckpointManager') as mock_mgr:
|
||||||
|
mock_checkpoint = MagicMock()
|
||||||
|
mock_checkpoint.storyboard_path = str(Path(tmpdir) / "storyboard.json")
|
||||||
|
mock_mgr.return_value.get_project_checkpoint.return_value = mock_checkpoint
|
||||||
|
|
||||||
|
# Create storyboard file
|
||||||
|
storyboard_file = Path(tmpdir) / "storyboard.json"
|
||||||
|
storyboard_file.write_text(VALID_STORYBOARD)
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"resume",
|
||||||
|
"test_project",
|
||||||
|
"--output", tmpdir
|
||||||
|
])
|
||||||
|
|
||||||
|
# Should attempt to resume (may fail on actual generation)
|
||||||
|
assert "Resuming project" in result.output or result.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLIHelp:
|
||||||
|
"""Test CLI help messages."""
|
||||||
|
|
||||||
|
def test_main_help(self):
|
||||||
|
"""Test main help message."""
|
||||||
|
result = runner.invoke(app, ["--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Storyboard to Video Generation Pipeline" in result.output
|
||||||
|
|
||||||
|
def test_generate_help(self):
|
||||||
|
"""Test generate command help."""
|
||||||
|
result = runner.invoke(app, ["generate", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Generate video from storyboard" in result.output
|
||||||
|
assert "--output" in result.output
|
||||||
|
assert "--backend" in result.output
|
||||||
|
assert "--dry-run" in result.output
|
||||||
|
|
||||||
|
def test_validate_help(self):
|
||||||
|
"""Test validate command help."""
|
||||||
|
result = runner.invoke(app, ["validate", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Validate a storyboard file" in result.output
|
||||||
|
|
||||||
|
def test_resume_help(self):
|
||||||
|
"""Test resume command help."""
|
||||||
|
result = runner.invoke(app, ["resume", "--help"])
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Resume a failed or interrupted project" in result.output
|
||||||
350
tests/unit/test_upscaler.py
Normal file
350
tests/unit/test_upscaler.py
Normal file
@@ -0,0 +1,350 @@
|
|||||||
|
"""
|
||||||
|
Tests for upscaling module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock, mock_open
|
||||||
|
from src.upscaling.upscaler import (
|
||||||
|
UpscaleConfig,
|
||||||
|
UpscaleResult,
|
||||||
|
UpscalerType,
|
||||||
|
UpscaleFactor,
|
||||||
|
FFmpegSRUpscaler,
|
||||||
|
RealESRGANUpscaler,
|
||||||
|
UpscaleManager,
|
||||||
|
BaseUpscaler
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpscaleConfig:
|
||||||
|
"""Test UpscaleConfig dataclass."""
|
||||||
|
|
||||||
|
def test_default_config(self):
|
||||||
|
"""Test default configuration values."""
|
||||||
|
config = UpscaleConfig()
|
||||||
|
assert config.upscaler_type == UpscalerType.FFMPEG_SR
|
||||||
|
assert config.factor == UpscaleFactor.X2
|
||||||
|
assert config.denoise_strength == 0.5
|
||||||
|
assert config.tile_size == 0
|
||||||
|
assert config.half_precision is True
|
||||||
|
assert config.device == "cuda"
|
||||||
|
|
||||||
|
def test_custom_config(self):
|
||||||
|
"""Test custom configuration."""
|
||||||
|
config = UpscaleConfig(
|
||||||
|
upscaler_type=UpscalerType.REAL_ESRGAN,
|
||||||
|
factor=UpscaleFactor.X4,
|
||||||
|
denoise_strength=0.8,
|
||||||
|
device="cpu"
|
||||||
|
)
|
||||||
|
assert config.upscaler_type == UpscalerType.REAL_ESRGAN
|
||||||
|
assert config.factor == UpscaleFactor.X4
|
||||||
|
assert config.denoise_strength == 0.8
|
||||||
|
assert config.device == "cpu"
|
||||||
|
|
||||||
|
def test_config_to_dict(self):
|
||||||
|
"""Test configuration serialization."""
|
||||||
|
config = UpscaleConfig(
|
||||||
|
factor=UpscaleFactor.X4,
|
||||||
|
model_path=Path("/models/model.pth")
|
||||||
|
)
|
||||||
|
d = config.to_dict()
|
||||||
|
assert d["factor"] == 4
|
||||||
|
assert d["upscaler_type"] == "ffmpeg_sr"
|
||||||
|
# Path conversion is platform-specific
|
||||||
|
assert "model.pth" in d["model_path"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpscaleResult:
|
||||||
|
"""Test UpscaleResult dataclass."""
|
||||||
|
|
||||||
|
def test_success_result(self):
|
||||||
|
"""Test successful result."""
|
||||||
|
result = UpscaleResult(
|
||||||
|
success=True,
|
||||||
|
output_path=Path("output.mp4"),
|
||||||
|
input_resolution=(1920, 1080),
|
||||||
|
output_resolution=(3840, 2160),
|
||||||
|
processing_time_s=45.5
|
||||||
|
)
|
||||||
|
assert result.success is True
|
||||||
|
assert result.output_path == Path("output.mp4")
|
||||||
|
assert result.input_resolution == (1920, 1080)
|
||||||
|
assert result.output_resolution == (3840, 2160)
|
||||||
|
assert result.processing_time_s == 45.5
|
||||||
|
|
||||||
|
def test_failure_result(self):
|
||||||
|
"""Test failure result."""
|
||||||
|
result = UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="FFmpeg not found"
|
||||||
|
)
|
||||||
|
assert result.success is False
|
||||||
|
assert result.error_message == "FFmpeg not found"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFmpegSRUpscaler:
|
||||||
|
"""Test FFmpeg-based upscaler."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_subprocess(self):
|
||||||
|
"""Fixture to mock subprocess.run."""
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||||
|
yield mock_run
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir(self):
|
||||||
|
"""Fixture for temporary directory."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
def test_is_available_success(self, mock_subprocess):
|
||||||
|
"""Test availability check when ffmpeg is present."""
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = FFmpegSRUpscaler(config)
|
||||||
|
assert upscaler.is_available() is True
|
||||||
|
|
||||||
|
def test_is_available_failure(self):
|
||||||
|
"""Test availability check when ffmpeg is missing."""
|
||||||
|
with patch('subprocess.run', side_effect=FileNotFoundError()):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = FFmpegSRUpscaler(config)
|
||||||
|
assert upscaler.is_available() is False
|
||||||
|
|
||||||
|
def test_upscale_missing_input(self, mock_subprocess):
|
||||||
|
"""Test upscaling with missing input file."""
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = FFmpegSRUpscaler(config)
|
||||||
|
|
||||||
|
result = upscaler.upscale(Path("nonexistent.mp4"), Path("output.mp4"))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "not found" in result.error_message
|
||||||
|
|
||||||
|
def test_upscale_success(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test successful upscaling."""
|
||||||
|
# Create dummy input file
|
||||||
|
input_file = temp_dir / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
# Mock ffprobe response
|
||||||
|
ffprobe_response = MagicMock(
|
||||||
|
returncode=0,
|
||||||
|
stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
# First call is ffprobe, second is ffmpeg
|
||||||
|
mock_run.side_effect = [ffprobe_response, MagicMock(returncode=0)]
|
||||||
|
|
||||||
|
config = UpscaleConfig(factor=UpscaleFactor.X2)
|
||||||
|
upscaler = FFmpegSRUpscaler(config)
|
||||||
|
result = upscaler.upscale(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.input_resolution == (1920, 1080)
|
||||||
|
assert result.output_resolution == (3840, 2160)
|
||||||
|
|
||||||
|
def test_upscale_ffmpeg_error(self, mock_subprocess, temp_dir):
|
||||||
|
"""Test handling of ffmpeg error."""
|
||||||
|
input_file = temp_dir / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy video")
|
||||||
|
output_file = temp_dir / "output.mp4"
|
||||||
|
|
||||||
|
ffprobe_response = MagicMock(
|
||||||
|
returncode=0,
|
||||||
|
stdout='{"streams": [{"width": 1920, "height": 1080, "r_frame_rate": "24/1", "duration": "10.0"}]}'
|
||||||
|
)
|
||||||
|
|
||||||
|
ffmpeg_error = MagicMock(returncode=1, stderr="Invalid data")
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.side_effect = [ffprobe_response, ffmpeg_error]
|
||||||
|
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = FFmpegSRUpscaler(config)
|
||||||
|
result = upscaler.upscale(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "FFmpeg error" in result.error_message
|
||||||
|
|
||||||
|
|
||||||
|
class TestRealESRGANUpscaler:
|
||||||
|
"""Test Real-ESRGAN upscaler."""
|
||||||
|
|
||||||
|
def test_is_available_without_realesrgan(self):
|
||||||
|
"""Test availability when Real-ESRGAN is not installed."""
|
||||||
|
with patch.dict('sys.modules', {'realesrgan': None}):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = RealESRGANUpscaler(config)
|
||||||
|
assert upscaler.is_available() is False
|
||||||
|
|
||||||
|
def test_upscale_without_realesrgan(self):
|
||||||
|
"""Test upscaling when Real-ESRGAN is not installed."""
|
||||||
|
# Mock the import to fail
|
||||||
|
import sys
|
||||||
|
original_modules = sys.modules.copy()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Remove realesrgan and torch from modules to simulate not installed
|
||||||
|
for mod in list(sys.modules.keys()):
|
||||||
|
if 'realesrgan' in mod or 'torch' in mod:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
# Mock import to raise ImportError
|
||||||
|
def mock_import(name, *args, **kwargs):
|
||||||
|
if name == 'realesrgan' or name.startswith('realesrgan.'):
|
||||||
|
raise ImportError("No module named 'realesrgan'")
|
||||||
|
if name == 'torch' or name.startswith('torch.'):
|
||||||
|
raise ImportError("No module named 'torch'")
|
||||||
|
return original_modules.get(name, __import__(name, *args, **kwargs))
|
||||||
|
|
||||||
|
with patch('builtins.__import__', side_effect=mock_import):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
upscaler = RealESRGANUpscaler(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
input_file = Path(tmpdir) / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy")
|
||||||
|
output_file = Path(tmpdir) / "output.mp4"
|
||||||
|
|
||||||
|
result = upscaler.upscale(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "not installed" in result.error_message
|
||||||
|
finally:
|
||||||
|
# Restore original modules
|
||||||
|
sys.modules.clear()
|
||||||
|
sys.modules.update(original_modules)
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpscaleManager:
|
||||||
|
"""Test UpscaleManager."""
|
||||||
|
|
||||||
|
def test_init_with_default_config(self):
|
||||||
|
"""Test initialization with default config."""
|
||||||
|
manager = UpscaleManager()
|
||||||
|
assert manager.config.upscaler_type == UpscalerType.FFMPEG_SR
|
||||||
|
|
||||||
|
def test_init_with_custom_config(self):
|
||||||
|
"""Test initialization with custom config."""
|
||||||
|
config = UpscaleConfig(upscaler_type=UpscalerType.REAL_ESRGAN)
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
assert manager.config.upscaler_type == UpscalerType.REAL_ESRGAN
|
||||||
|
|
||||||
|
def test_get_upscaler_caching(self):
|
||||||
|
"""Test that upscalers are cached."""
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
|
||||||
|
upscaler1 = manager.get_upscaler(UpscalerType.FFMPEG_SR)
|
||||||
|
upscaler2 = manager.get_upscaler(UpscalerType.FFMPEG_SR)
|
||||||
|
|
||||||
|
assert upscaler1 is upscaler2 # Same instance
|
||||||
|
|
||||||
|
def test_get_upscaler_unavailable(self):
|
||||||
|
"""Test getting unavailable upscaler."""
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=False):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
manager.get_upscaler(UpscalerType.FFMPEG_SR)
|
||||||
|
|
||||||
|
assert "not available" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_upscale_with_manager(self):
|
||||||
|
"""Test upscale through manager."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
input_file = Path(tmpdir) / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy")
|
||||||
|
output_file = Path(tmpdir) / "output.mp4"
|
||||||
|
|
||||||
|
mock_result = UpscaleResult(
|
||||||
|
success=True,
|
||||||
|
output_path=output_file,
|
||||||
|
input_resolution=(1920, 1080),
|
||||||
|
output_resolution=(3840, 2160)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
result = manager.upscale(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.output_resolution == (3840, 2160)
|
||||||
|
|
||||||
|
def test_get_available_upscalers(self):
|
||||||
|
"""Test getting list of available upscalers."""
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
|
||||||
|
with patch.object(RealESRGANUpscaler, 'is_available', return_value=False):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
available = manager.get_available_upscalers()
|
||||||
|
|
||||||
|
assert UpscalerType.FFMPEG_SR in available
|
||||||
|
assert UpscalerType.REAL_ESRGAN not in available
|
||||||
|
|
||||||
|
def test_upscale_with_fallback_success(self):
|
||||||
|
"""Test fallback upscaling with first option succeeding."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
input_file = Path(tmpdir) / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy")
|
||||||
|
output_file = Path(tmpdir) / "output.mp4"
|
||||||
|
|
||||||
|
mock_result = UpscaleResult(
|
||||||
|
success=True,
|
||||||
|
output_path=output_file,
|
||||||
|
input_resolution=(1920, 1080),
|
||||||
|
output_resolution=(3840, 2160)
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
result = manager.upscale_with_fallback(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
def test_upscale_with_fallback_all_fail(self):
|
||||||
|
"""Test fallback upscaling when all options fail."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
input_file = Path(tmpdir) / "input.mp4"
|
||||||
|
input_file.write_bytes(b"dummy")
|
||||||
|
output_file = Path(tmpdir) / "output.mp4"
|
||||||
|
|
||||||
|
mock_result = UpscaleResult(
|
||||||
|
success=False,
|
||||||
|
error_message="Processing failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'is_available', return_value=True):
|
||||||
|
with patch.object(FFmpegSRUpscaler, 'upscale', return_value=mock_result):
|
||||||
|
config = UpscaleConfig()
|
||||||
|
manager = UpscaleManager(config)
|
||||||
|
result = manager.upscale_with_fallback(input_file, output_file)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "All upscalers failed" in result.error_message
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpscaleEnums:
|
||||||
|
"""Test upscaler enums."""
|
||||||
|
|
||||||
|
def test_upscaler_type_values(self):
|
||||||
|
"""Test upscaler type enum values."""
|
||||||
|
assert UpscalerType.REAL_ESRGAN.value == "real_esrgan"
|
||||||
|
assert UpscalerType.FFMPEG_SR.value == "ffmpeg_sr"
|
||||||
|
|
||||||
|
def test_upscale_factor_values(self):
|
||||||
|
"""Test upscale factor enum values."""
|
||||||
|
assert UpscaleFactor.X2.value == 2
|
||||||
|
assert UpscaleFactor.X4.value == 4
|
||||||
@@ -92,7 +92,7 @@ class TestWanBackend:
|
|||||||
|
|
||||||
# Double resolution (4x pixels)
|
# Double resolution (4x pixels)
|
||||||
vram_1440_81 = backend.estimate_vram_usage(1920, 1080, 81)
|
vram_1440_81 = backend.estimate_vram_usage(1920, 1080, 81)
|
||||||
assert vram_1440_81 > vram_720_81 * 3 # Should be ~4x
|
assert vram_1440_81 > vram_720_81 * 2 # Should be ~2x due to model overhead
|
||||||
|
|
||||||
# Double frames
|
# Double frames
|
||||||
vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
|
vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
|
||||||
|
|||||||
Reference in New Issue
Block a user