# Source: video-gen/src/cli/main.py (Python, 322 lines, 13 KiB)

"""
CLI entry point for storyboard video generation.
"""
import sys
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.markup import escape
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.generation.base import BackendFactory
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig
# Root Typer application; every function decorated with @app.command()
# below is registered on it as a subcommand.
app = typer.Typer(help="Storyboard to Video Generation Pipeline")

# Single rich console shared by all commands for user-facing output.
console: Console = Console()
def _finish_task(progress: Progress, task_id, description: str) -> None:
    """Set a per-shot progress task's final *description* and stop its spinner."""
    # Shot tasks are created indeterminate (total=None); rich only treats a
    # task as finished once it has a total that `completed` has reached.
    progress.update(task_id, description=description, total=1, completed=1)


def _cached_shot_path(checkpoint_mgr: CheckpointManager, project_name: str, shot_id) -> Optional[Path]:
    """Return the video of a shot checkpointed as "completed", else None.

    A completed checkpoint whose video file is no longer on disk counts as
    missing, so the shot is regenerated (or reported) instead of a dangling
    path being handed to the assembler.
    """
    checkpoint = checkpoint_mgr.get_shot_checkpoint(project_name, shot_id)
    if not checkpoint or checkpoint.status != "completed" or not checkpoint.video_path:
        return None
    video_path = Path(checkpoint.video_path)
    return video_path if video_path.exists() else None


def _collect_cached_shots(checkpoint_mgr: CheckpointManager, storyboard_data):
    """Gather already-generated shot videos without running a backend.

    Returns ``(shot_paths, missing_ids)``: the videos of shots that have a
    usable completed checkpoint (storyboard order) and the ids of all others.
    """
    project_name = storyboard_data.project.title
    shot_paths, missing_ids = [], []
    for shot in storyboard_data.shots:
        cached = _cached_shot_path(checkpoint_mgr, project_name, shot.id)
        if cached is None:
            missing_ids.append(shot.id)
        else:
            shot_paths.append(cached)
    return shot_paths, missing_ids


def _generate_shots(
    storyboard_data,
    video_backend,
    prompt_compiler: PromptCompiler,
    shot_planner: ShotPlanner,
    checkpoint_mgr: CheckpointManager,
    shots_dir: Path,
    resume: bool,
):
    """Generate every storyboard shot with *video_backend*, checkpointing each one.

    With *resume*, shots that already have a usable completed checkpoint are
    reused instead of regenerated. Returns ``(shot_paths, failed_ids)``: the
    shot videos available after this run (storyboard order) and the ids of
    shots whose generation failed.
    """
    from src.generation.base import GenerationSpec

    project_name = storyboard_data.project.title
    shot_paths, failed_ids = [], []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        for shot in storyboard_data.shots:
            task_id = progress.add_task(f"Shot {shot.id}", total=None)

            if resume:
                cached = _cached_shot_path(checkpoint_mgr, project_name, shot.id)
                if cached is not None:
                    shot_paths.append(cached)
                    _finish_task(progress, task_id, f"[green]✓[/green] Shot {shot.id} (cached)")
                    continue

            # Prompt compilation and planning sit inside the try so that one bad
            # shot is recorded as failed instead of aborting the whole run.
            try:
                prompt = prompt_compiler.compile_shot_prompt(shot)
                negative_prompt = prompt_compiler.compile_negative_prompt(shot)
                shot_plan = shot_planner.plan_shot(shot)
                spec = GenerationSpec(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=shot_plan.width,
                    height=shot_plan.height,
                    num_frames=shot_plan.total_frames,
                    fps=shot_plan.fps,
                    seed=shot.generation.seed,
                )
                result = video_backend.generate(spec, output_dir=shots_dir)
            except Exception as e:  # any error here fails this shot only
                error_message = str(e)
            else:
                if result.success:
                    shot_paths.append(result.video_path)
                    checkpoint_mgr.save_shot_checkpoint(
                        project_name=project_name,
                        shot_id=shot.id,
                        status="completed",
                        video_path=str(result.video_path),
                        metadata=result.metadata,
                    )
                    _finish_task(progress, task_id, f"[green]✓[/green] Shot {shot.id}")
                    continue
                error_message = result.error_message

            failed_ids.append(shot.id)
            _finish_task(
                progress,
                task_id,
                f"[red]✗[/red] Shot {shot.id}: {escape(str(error_message))}",
            )
            checkpoint_mgr.save_shot_checkpoint(
                project_name=project_name,
                shot_id=shot.id,
                status="failed",
                error_message=error_message,
            )
    return shot_paths, failed_ids


@app.command()
def generate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
    resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
    skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
    skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
    upscale: Optional[int] = typer.Option(None, "--upscale", help="Upscale factor (2 or 4)"),
    dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"),
):
    """Generate video from storyboard."""
    # Typer prints the docstring above as --help, so implementation notes live
    # in comments. Pipeline: validate inputs -> generate (or reuse
    # checkpointed) shots -> assemble -> optionally upscale. External text
    # embedded in rich markup (exception messages, paths, ids) is passed
    # through escape() so brackets in it print literally.

    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {escape(str(storyboard))}[/red]")
        raise typer.Exit(1)

    # Reject an unsupported factor now rather than after every shot is generated.
    if upscale is not None and upscale not in (2, 4):
        console.print(f"[red]Error: --upscale must be 2 or 4, got {upscale}[/red]")
        raise typer.Exit(1)

    try:
        app_config = ConfigLoader.load(config) if config else ConfigLoader.load()
    except Exception as e:
        console.print(f"[red]Error loading config: {escape(str(e))}[/red]")
        raise typer.Exit(1)

    console.print("[bold blue]Validating storyboard...[/bold blue]")
    try:
        storyboard_data = StoryboardValidator().load(storyboard)
    except Exception as e:
        console.print(f"[red]Error validating storyboard: {escape(str(e))}[/red]")
        raise typer.Exit(1)
    console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")

    if dry_run:
        console.print("[yellow]Dry run complete. Exiting.[/yellow]")
        return

    # Output layout; the directory name must match what `resume` derives from
    # the project title.
    project_name = storyboard_data.project.title
    safe_title = project_name.replace(" ", "_")
    project_dir = output / safe_title
    shots_dir = project_dir / "shots"
    assembled_dir = project_dir / "assembled"
    metadata_dir = project_dir / "metadata"
    for directory in (shots_dir, assembled_dir, metadata_dir):
        directory.mkdir(parents=True, exist_ok=True)

    checkpoint_mgr = CheckpointManager(str(project_dir / "checkpoints.db"))
    # Absolute so `resume` can still find the storyboard from another directory.
    storyboard_path = str(storyboard.resolve())
    final_output = assembled_dir / f"{safe_title}.mp4"

    def record_project(status: str, num_shots: int) -> None:
        """Save this run's project checkpoint with the current final_output."""
        checkpoint_mgr.save_project_checkpoint(
            project_name=project_name,
            storyboard_path=storyboard_path,
            output_path=str(final_output),
            status=status,
            num_shots=num_shots,
        )

    # Record the run before any work so `resume` can locate it even if this
    # run fails or is interrupted before assembly.
    # NOTE(review): assumes the checkpoint store accepts an "in_progress" status.
    record_project("in_progress", 0)

    fps = storyboard_data.project.fps or 24
    prompt_compiler = PromptCompiler(storyboard_data.project.global_style)
    shot_planner = ShotPlanner(
        fps=fps,
        max_chunk_duration=app_config.backend.max_chunk_seconds or 6.0,
    )

    if skip_generation:
        # "Only assemble": reuse whatever shots earlier runs completed.
        generated_shots, unavailable = _collect_cached_shots(checkpoint_mgr, storyboard_data)
        if unavailable:
            console.print(
                "[yellow]Warning: no completed video for shot(s): "
                f"{escape(', '.join(str(i) for i in unavailable))}[/yellow]"
            )
    else:
        console.print(f"[bold blue]Initializing {backend} backend...[/bold blue]")
        try:
            backend_config = app_config.get_backend_config(backend)
            video_backend = BackendFactory.create_backend(backend, backend_config)
        except Exception as e:
            console.print(f"[red]Error initializing backend: {escape(str(e))}[/red]")
            raise typer.Exit(1)
        console.print(f"[green]✓[/green] Backend initialized: {backend}")

        console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]")
        generated_shots, unavailable = _generate_shots(
            storyboard_data,
            video_backend,
            prompt_compiler,
            shot_planner,
            checkpoint_mgr,
            shots_dir,
            resume,
        )
        if unavailable:
            console.print(
                f"[yellow]Warning: {len(unavailable)} shot(s) failed: "
                f"{escape(', '.join(str(i) for i in unavailable))}[/yellow]"
            )

    if skip_assembly:
        return

    if not generated_shots:
        console.print("[red]Error: no generated shots available to assemble[/red]")
        record_project("failed", 0)
        raise typer.Exit(1)

    console.print("[bold blue]Assembling video...[/bold blue]")
    assembler = FFmpegAssembler()
    assembly_config = AssemblyConfig(fps=fps, add_shot_labels=False)
    result = assembler.assemble(generated_shots, final_output, assembly_config)
    if not result.success:
        console.print(f"[red]Error assembling video: {escape(str(result.error_message))}[/red]")
        record_project("failed", 0)
        raise typer.Exit(1)
    console.print(f"[green]✓[/green] Video assembled: {escape(str(final_output))}")

    if upscale:
        console.print(f"[bold blue]Upscaling {upscale}x...[/bold blue]")
        upscale_mgr = UpscaleManager(UpscaleConfig(factor=upscale, upscaler_type="ffmpeg_sr"))
        upscaled_output = assembled_dir / f"{safe_title}_upscaled.mp4"
        upscale_result = upscale_mgr.upscale(final_output, upscaled_output)
        if upscale_result.success:
            console.print(f"[green]✓[/green] Upscaled video: {escape(str(upscaled_output))}")
            final_output = upscaled_output
        else:
            # Upscaling is optional; keep the assembled video and carry on.
            console.print(
                f"[yellow]Warning: Upscaling failed: {escape(str(upscale_result.error_message))}[/yellow]"
            )

    record_project("completed", len(generated_shots))
    console.print(f"\n[bold green]Success![/bold green] Video saved to: {escape(str(final_output))}")
@app.command()
def validate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"),
):
    """Validate a storyboard file."""
    # Typer prints the docstring above as --help; notes live in comments.
    # Exits with code 1 if the file is missing or fails to load/validate,
    # otherwise prints a summary (and, with --verbose, a per-shot table).
    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {escape(str(storyboard))}[/red]")
        raise typer.Exit(1)

    # Only loading is guarded: an exception here means the storyboard is
    # invalid. Display code below is kept outside so its bugs are not
    # misreported as validation failures.
    try:
        data = StoryboardValidator().load(storyboard)
    except Exception as e:
        # escape(): validator messages may contain "[...]" that rich would
        # otherwise parse as markup and swallow or reject.
        console.print(f"[red]Validation failed: {escape(str(e))}[/red]")
        raise typer.Exit(1)

    project = data.project
    console.print("[green]✓[/green] Storyboard is valid")
    console.print(f"\n[bold]Title:[/bold] {escape(str(project.title))}")
    console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
    console.print(f"[bold]FPS:[/bold] {project.fps or 'Not specified'}")
    if project.resolution:
        resolution = f"{project.resolution.width}x{project.resolution.height}"
    else:
        resolution = "Not specified"
    console.print(f"[bold]Resolution:[/bold] {resolution}")

    if verbose:
        table = Table(title="Shots")
        table.add_column("ID", style="cyan")
        table.add_column("Prompt", style="green")
        table.add_column("Duration", style="yellow")
        for shot in data.shots:
            # Table cells must be renderables (str), and string cells are parsed
            # as markup: stringify the id and escape the free-text prompt.
            table.add_row(
                escape(str(shot.id)),
                escape((shot.prompt or "")[:50]),
                f"{shot.duration_s}s",
            )
        console.print(table)
@app.command()
def resume(
    project: str = typer.Argument(..., help="Project name"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
):
    """Resume a failed or interrupted project."""
    # Typer prints the docstring above as --help; notes live in comments.
    # Looks up the project's checkpoint DB, reads the storyboard path from the
    # project checkpoint, and re-runs `generate` with resume=True so completed
    # shots are reused.

    # Directory naming must match the one `generate` derives from the title.
    project_dir = output / project.replace(" ", "_")
    checkpoint_db = project_dir / "checkpoints.db"
    if not checkpoint_db.exists():
        console.print(f"[red]Error: No checkpoint found for project '{escape(project)}'[/red]")
        raise typer.Exit(1)

    checkpoint_mgr = CheckpointManager(str(checkpoint_db))
    project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)
    if not project_checkpoint:
        console.print("[red]Error: No project checkpoint found[/red]")
        raise typer.Exit(1)

    console.print(f"[bold blue]Resuming project:[/bold blue] {escape(project)}")
    console.print(f"Storyboard: {escape(str(project_checkpoint.storyboard_path))}")

    # generate() is called as a plain Python function here, so Typer does not
    # substitute its defaults: every omitted parameter would be a Typer
    # OptionInfo object (truthy), e.g. enabling --dry-run/--skip-generation and
    # passing an OptionInfo as the config path. Pass every option explicitly.
    generate(
        storyboard=Path(project_checkpoint.storyboard_path),
        output=output,
        config=config,
        backend=backend,
        resume=True,
        skip_generation=False,
        skip_assembly=False,
        upscale=None,
        dry_run=False,
    )
@app.command()
def list_backends():
    """List available generation backends."""
    # Typer prints the docstring above as --help; notes live in comments.
    # Column headers paired with their rich styles, left to right.
    columns = (
        ("Name", "cyan"),
        ("Type", "green"),
        ("Description", "white"),
    )
    # One (name, type, description) row per backend, in display order.
    backends = (
        ("wan", "T2V", "WAN 2.x text-to-video"),
        ("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)"),
        ("svd", "I2V", "Stable Video Diffusion (fallback)"),
    )

    table = Table(title="Available Backends")
    for header, colour in columns:
        table.add_column(header, style=colour)
    for row in backends:
        table.add_row(*row)
    console.print(table)
def main():
    """Entry point for the CLI."""
    # Hand control to Typer, which parses sys.argv and dispatches to the
    # matching @app.command() function.
    app()


# Run the CLI only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()