"""
|
|
CLI entry point for storyboard video generation.
|
|
"""
|
|
|
|
import sys
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.markup import escape
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table

from src.storyboard.loader import StoryboardValidator
from src.storyboard.prompt_compiler import PromptCompiler
from src.storyboard.shot_planner import ShotPlanner
from src.core.config import ConfigLoader
from src.core.checkpoint import CheckpointManager
from src.generation.base import BackendFactory, GenerationSpec
from src.assembly.assembler import FFmpegAssembler, AssemblyConfig
from src.upscaling.upscaler import UpscaleManager, UpscaleConfig

# Shared Rich console used by every command below for styled terminal output.
console = Console()

# Root Typer application; each function decorated with @app.command()
# becomes a sub-command of this CLI.
app = typer.Typer(help="Storyboard to Video Generation Pipeline")


# Upscale factors advertised by the --upscale help text.
_SUPPORTED_UPSCALE_FACTORS = (2, 4)


@app.command()
def generate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
    resume: bool = typer.Option(False, "--resume", "-r", help="Resume from checkpoint"),
    skip_generation: bool = typer.Option(False, "--skip-generation", help="Skip generation, only assemble"),
    skip_assembly: bool = typer.Option(False, "--skip-assembly", help="Skip assembly, only generate shots"),
    upscale: Optional[int] = typer.Option(None, "--upscale", help="Upscale factor (2 or 4)"),
    dry_run: bool = typer.Option(False, "--dry-run", help="Validate storyboard without generating"),
):
    """Generate video from storyboard.

    Validates the storyboard, generates one clip per shot (or reuses clips
    recorded as completed in the project's checkpoint database), concatenates
    them with FFmpeg and optionally upscales the result.

    Exits with status 1 when the storyboard or config cannot be loaded, the
    backend cannot be initialised, there is nothing to assemble, or assembly
    fails.
    """
    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {escape(str(storyboard))}[/red]")
        raise typer.Exit(1)

    # Reject an unsupported factor before hours of generation, not after.
    if upscale is not None and upscale not in _SUPPORTED_UPSCALE_FACTORS:
        console.print(f"[red]Error: --upscale must be 2 or 4 (got {upscale})[/red]")
        raise typer.Exit(1)

    app_config = _load_app_config(config)
    storyboard_data = _load_storyboard(storyboard)

    if dry_run:
        console.print("[yellow]Dry run complete. Exiting.[/yellow]")
        return

    project = storyboard_data.project
    project_name = project.title
    slug = _project_slug(project_name)
    fps = project.fps or 24

    # Output layout: <output>/<slug>/{shots,assembled,metadata}
    project_dir = output / slug
    shots_dir = project_dir / "shots"
    assembled_dir = project_dir / "assembled"
    metadata_dir = project_dir / "metadata"
    for directory in (shots_dir, assembled_dir, metadata_dir):
        directory.mkdir(parents=True, exist_ok=True)

    checkpoint_mgr = CheckpointManager(str(project_dir / "checkpoints.db"))

    # Store an absolute path so `resume` works from any working directory.
    storyboard_path = str(storyboard.resolve())

    # Record the project before any long-running work so that `resume` can
    # find it even if this run fails or is interrupted (previously it was only
    # written after a successful assembly). NOTE(review): assumes
    # save_project_checkpoint overwrites the existing entry for this project
    # and accepts the "in_progress" status — confirm in CheckpointManager.
    checkpoint_mgr.save_project_checkpoint(
        project_name=project_name,
        storyboard_path=storyboard_path,
        output_path=str(assembled_dir / f"{slug}.mp4"),
        status="in_progress",
        num_shots=0,
    )

    prompt_compiler = PromptCompiler(project.global_style)
    shot_planner = ShotPlanner(
        fps=fps,
        max_chunk_duration=app_config.backend.max_chunk_seconds or 6.0,
    )

    if skip_generation:
        # "Only assemble": reuse clips already recorded as completed.
        generated_shots, missing_shots = _collect_completed_shots(storyboard_data, checkpoint_mgr)
        reason = "no completed clip on record"
    else:
        video_backend = _init_backend(backend, app_config)
        generated_shots, missing_shots = _generate_shots(
            storyboard_data,
            video_backend,
            checkpoint_mgr,
            prompt_compiler,
            shot_planner,
            shots_dir,
            resume,
        )
        reason = "generation failed"

    if missing_shots:
        ids = ", ".join(str(shot_id) for shot_id in missing_shots)
        console.print(
            f"[yellow]Warning: {reason} for {len(missing_shots)} shot(s): {escape(ids)}[/yellow]"
        )

    if skip_assembly:
        return

    if not generated_shots:
        console.print("[red]Error: no shot videos available to assemble[/red]")
        raise typer.Exit(1)

    final_output = _assemble_shots(generated_shots, assembled_dir / f"{slug}.mp4", fps)

    if upscale is not None:
        final_output = _upscale_video(
            final_output, assembled_dir / f"{slug}_upscaled.mp4", upscale
        )

    checkpoint_mgr.save_project_checkpoint(
        project_name=project_name,
        storyboard_path=storyboard_path,
        output_path=str(final_output),
        status="completed",
        num_shots=len(generated_shots),
    )

    console.print(f"\n[bold green]Success![/bold green] Video saved to: {escape(str(final_output))}")


def _project_slug(title):
    """Return the directory/file-name form of a project title (spaces -> underscores)."""
    # Must match the substitution the `resume` command applies to its argument.
    return title.replace(" ", "_")


def _load_app_config(config):
    """Load the pipeline config from *config*, or via ConfigLoader.load() if not given.

    Raises:
        typer.Exit: status 1 if loading fails.
    """
    try:
        if config:
            return ConfigLoader.load(config)
        return ConfigLoader.load()
    except Exception as e:
        console.print(f"[red]Error loading config: {escape(str(e))}[/red]")
        raise typer.Exit(1) from e


def _load_storyboard(storyboard):
    """Load and validate the storyboard file.

    Raises:
        typer.Exit: status 1 if validation fails.
    """
    console.print("[bold blue]Validating storyboard...[/bold blue]")
    validator = StoryboardValidator()
    try:
        storyboard_data = validator.load(storyboard)
    except Exception as e:
        console.print(f"[red]Error validating storyboard: {escape(str(e))}[/red]")
        raise typer.Exit(1) from e
    console.print(f"[green]✓[/green] Storyboard validated: {len(storyboard_data.shots)} shots")
    return storyboard_data


def _init_backend(backend, app_config):
    """Create the named generation backend from its config section.

    Raises:
        typer.Exit: status 1 if the backend cannot be created.
    """
    console.print(f"[bold blue]Initializing {escape(backend)} backend...[/bold blue]")
    try:
        backend_config = app_config.get_backend_config(backend)
        video_backend = BackendFactory.create_backend(backend, backend_config)
    except Exception as e:
        console.print(f"[red]Error initializing backend: {escape(str(e))}[/red]")
        raise typer.Exit(1) from e
    console.print(f"[green]✓[/green] Backend initialized: {escape(backend)}")
    return video_backend


def _cached_shot_video(checkpoint_mgr, project_name, shot_id):
    """Return the clip of a shot checkpointed as completed, or None.

    A checkpoint whose video file no longer exists is treated as missing, so the
    shot is regenerated instead of handing a dangling path to the assembler.
    """
    checkpoint = checkpoint_mgr.get_shot_checkpoint(project_name, shot_id)
    if not checkpoint or checkpoint.status != "completed" or not checkpoint.video_path:
        return None
    video_path = Path(checkpoint.video_path)
    return video_path if video_path.exists() else None


def _collect_completed_shots(storyboard_data, checkpoint_mgr):
    """Gather previously generated clips in storyboard order.

    Returns:
        (videos, missing_ids): existing clip paths, and ids of shots without one.
    """
    project_name = storyboard_data.project.title
    videos, missing = [], []
    for shot in storyboard_data.shots:
        video = _cached_shot_video(checkpoint_mgr, project_name, shot.id)
        if video is None:
            missing.append(shot.id)
        else:
            videos.append(video)
    return videos, missing


def _generate_shots(
    storyboard_data,
    video_backend,
    checkpoint_mgr,
    prompt_compiler,
    shot_planner,
    shots_dir,
    resume,
):
    """Generate every shot, reusing completed checkpoints when *resume* is set.

    Each shot's outcome is written to the checkpoint database. A failure in one
    shot (prompt compilation, planning or generation) is recorded and does not
    abort the remaining shots.

    Returns:
        (videos, failed_ids): clip paths in storyboard order, and ids of shots
        that produced no clip.
    """
    project_name = storyboard_data.project.title
    videos, failed = [], []

    console.print(f"[bold blue]Generating {len(storyboard_data.shots)} shots...[/bold blue]")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        for shot in storyboard_data.shots:
            # Descriptions are rendered as Rich markup; escape user-supplied ids.
            label = escape(f"Shot {shot.id}")
            task_id = progress.add_task(label, total=None)

            if resume:
                cached = _cached_shot_video(checkpoint_mgr, project_name, shot.id)
                if cached is not None:
                    progress.update(task_id, description=f"[green]✓[/green] {label} (cached)")
                    videos.append(cached)
                    continue

            try:
                prompt = prompt_compiler.compile_shot_prompt(shot)
                negative_prompt = prompt_compiler.compile_negative_prompt(shot)
                shot_plan = shot_planner.plan_shot(shot)
                spec = GenerationSpec(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=shot_plan.width,
                    height=shot_plan.height,
                    num_frames=shot_plan.total_frames,
                    fps=shot_plan.fps,
                    seed=shot.generation.seed,
                )
                result = video_backend.generate(spec, output_dir=shots_dir)
            except Exception as e:
                # Best effort: one broken shot must not abort the whole run.
                error_message = str(e)
            else:
                if result.success:
                    videos.append(result.video_path)
                    checkpoint_mgr.save_shot_checkpoint(
                        project_name=project_name,
                        shot_id=shot.id,
                        status="completed",
                        video_path=str(result.video_path),
                        metadata=result.metadata,
                    )
                    progress.update(task_id, description=f"[green]✓[/green] {label}")
                    continue
                error_message = result.error_message

            progress.update(
                task_id,
                description=f"[red]✗[/red] {label}: {escape(str(error_message))}",
            )
            checkpoint_mgr.save_shot_checkpoint(
                project_name=project_name,
                shot_id=shot.id,
                status="failed",
                error_message=error_message,
            )
            failed.append(shot.id)

    return videos, failed


def _assemble_shots(shot_videos, final_output, fps):
    """Concatenate *shot_videos* into *final_output* with FFmpeg.

    Raises:
        typer.Exit: status 1 if assembly fails.
    """
    console.print("[bold blue]Assembling video...[/bold blue]")
    assembler = FFmpegAssembler()
    assembly_config = AssemblyConfig(fps=fps, add_shot_labels=False)
    result = assembler.assemble(shot_videos, final_output, assembly_config)
    if not result.success:
        console.print(f"[red]Error assembling video: {escape(str(result.error_message))}[/red]")
        raise typer.Exit(1)
    console.print(f"[green]✓[/green] Video assembled: {escape(str(final_output))}")
    return final_output


def _upscale_video(source, target, factor):
    """Upscale *source* into *target*; on failure warn and return *source*."""
    console.print(f"[bold blue]Upscaling {factor}x...[/bold blue]")
    upscale_mgr = UpscaleManager(UpscaleConfig(factor=factor, upscaler_type="ffmpeg_sr"))
    upscale_result = upscale_mgr.upscale(source, target)
    if upscale_result.success:
        console.print(f"[green]✓[/green] Upscaled video: {escape(str(target))}")
        return target
    console.print(
        f"[yellow]Warning: Upscaling failed: {escape(str(upscale_result.error_message))}[/yellow]"
    )
    return source


@app.command()
def validate(
    storyboard: Path = typer.Argument(..., help="Path to storyboard JSON file"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed information"),
):
    """Validate a storyboard file.

    Prints the title, shot count, FPS and resolution. With --verbose, also
    prints a table of shots. Exits with status 1 if the file is missing or
    fails validation.
    """
    if not storyboard.exists():
        console.print(f"[red]Error: Storyboard file not found: {escape(str(storyboard))}[/red]")
        raise typer.Exit(1)

    validator = StoryboardValidator()

    # Only the load belongs in the try: a display problem below is a bug,
    # not a validation failure, and must not be reported as one.
    try:
        data = validator.load(storyboard)
    except Exception as e:
        # Validation messages often contain "[...]" fragments; escape them so
        # Rich does not swallow them as markup tags.
        console.print(f"[red]Validation failed: {escape(str(e))}[/red]")
        raise typer.Exit(1) from e

    project = data.project
    console.print("[green]✓[/green] Storyboard is valid")
    console.print(f"\n[bold]Title:[/bold] {escape(str(project.title))}")
    console.print(f"[bold]Shots:[/bold] {len(data.shots)}")
    console.print(f"[bold]FPS:[/bold] {project.fps or 'Not specified'}")
    resolution = (
        f"{project.resolution.width}x{project.resolution.height}"
        if project.resolution
        else "Not specified"
    )
    console.print(f"[bold]Resolution:[/bold] {resolution}")

    if verbose:
        table = Table(title="Shots")
        table.add_column("ID", style="cyan")
        table.add_column("Prompt", style="green")
        table.add_column("Duration", style="yellow")

        for shot in data.shots:
            prompt = shot.prompt or ""
            preview = prompt if len(prompt) <= 50 else prompt[:50] + "…"
            # Table.add_row only accepts str/renderables, so stringify the id;
            # escape both cells because table strings are parsed as markup.
            table.add_row(escape(str(shot.id)), escape(preview), f"{shot.duration_s}s")

        console.print(table)


@app.command()
def resume(
    project: str = typer.Argument(..., help="Project name"),
    output: Path = typer.Option(Path("outputs"), "--output", "-o", help="Output directory"),
    config: Optional[Path] = typer.Option(None, "--config", "-c", help="Path to config file"),
    backend: str = typer.Option("wan", "--backend", "-b", help="Generation backend (wan, svd)"),
):
    """Resume a failed or interrupted project.

    Looks up the project's checkpoint database under the output directory and
    re-runs generation with resume enabled, so completed shots are reused.
    """
    project_dir = output / project.replace(" ", "_")
    checkpoint_db = project_dir / "checkpoints.db"

    if not checkpoint_db.exists():
        console.print(f"[red]Error: No checkpoint found for project '{escape(project)}'[/red]")
        raise typer.Exit(1)

    checkpoint_mgr = CheckpointManager(str(checkpoint_db))
    project_checkpoint = checkpoint_mgr.get_project_checkpoint(project)

    if not project_checkpoint:
        console.print("[red]Error: No project checkpoint found[/red]")
        raise typer.Exit(1)

    storyboard_path = Path(project_checkpoint.storyboard_path)
    console.print(f"[bold blue]Resuming project:[/bold blue] {escape(project)}")
    console.print(f"Storyboard: {escape(str(storyboard_path))}")

    # Pass every parameter explicitly. Called as a plain function (not through
    # Typer), omitted parameters would default to the typer.Option(...) marker
    # objects. Those are truthy, so config would be treated as a path, and
    # skip_generation and dry_run would be switched on.
    generate(
        storyboard=storyboard_path,
        output=output,
        config=config,
        backend=backend,
        resume=True,
        skip_generation=False,
        skip_assembly=False,
        upscale=None,
        dry_run=False,
    )


@app.command()
def list_backends():
    """List available generation backends."""
    # (name, type, description) for every backend this CLI advertises.
    backend_rows = (
        ("wan", "T2V", "WAN 2.x text-to-video"),
        ("wan-1.3b", "T2V", "WAN 1.3B (faster, lower quality)"),
        ("svd", "I2V", "Stable Video Diffusion (fallback)"),
    )
    column_specs = (("Name", "cyan"), ("Type", "green"), ("Description", "white"))

    backends_table = Table(title="Available Backends")
    for header, colour in column_specs:
        backends_table.add_column(header, style=colour)
    for row in backend_rows:
        backends_table.add_row(*row)

    console.print(backends_table)


def main():
    """Run the Typer application that dispatches to the CLI sub-commands."""
    app()


if __name__ == "__main__":
    # Allow running this module directly as a script.
    main()