# File: video-gen/tests/unit/test_wan_backend.py
# (196 lines, 6.4 KiB, Python)

"""
Tests for WAN backend.
Note: These tests mock the actual model loading/generation.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock, mock_open
from src.generation.backends.wan import WanBackend
from src.generation.base import GenerationSpec, GenerationResult
class TestWanBackend:
    """Unit tests for ``WanBackend``.

    No real model is loaded and no real ffmpeg process is started:
    ``subprocess.run`` is patched wherever chunk concatenation is exercised,
    and generation is only tested in the "model not loaded" state.
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def config_14b(self):
        """Return a backend configuration dict for the 14B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-14B",
            "vram_gb": 12,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": True,
            "model_cache_dir": "~/.cache/test",
        }

    @pytest.fixture
    def config_1_3b(self):
        """Return a backend configuration dict for the 1.3B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-1.3B",
            "vram_gb": 8,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": False,
            "model_cache_dir": "~/.cache/test",
        }

    # ------------------------------------------------------------------
    # Static properties
    # ------------------------------------------------------------------

    def test_backend_properties_14b(self, config_14b):
        """Check name, capability flags and VRAM figure for the 14B config."""
        backend = WanBackend(config_14b)

        assert backend.name == "wan_wan2.1-t2v-14b"
        assert backend.supports_chunking is True
        assert backend.supports_init_frame is False
        assert backend.vram_gb == 12

    def test_backend_properties_1_3b(self, config_1_3b):
        """Check name and VRAM figure for the 1.3B config."""
        backend = WanBackend(config_1_3b)

        assert backend.name == "wan_wan2.1-t2v-1.3b"
        assert backend.vram_gb == 8

    def test_dtype_str_storage(self, config_14b):
        """Check that the configured dtype string is exposed as ``dtype_str``."""
        backend = WanBackend(config_14b)
        assert backend.dtype_str == "fp16"

        # Build a modified copy instead of mutating the fixture dict, so the
        # first backend can never observe the change through a shared reference.
        bf16_config = {**config_14b, "dtype": "bf16"}
        backend2 = WanBackend(bf16_config)
        assert backend2.dtype_str == "bf16"

    # ------------------------------------------------------------------
    # VRAM estimation
    # ------------------------------------------------------------------

    def test_estimate_vram_14b(self, config_14b):
        """Check the 14B estimate exceeds 12 GB and grows with resolution."""
        backend = WanBackend(config_14b)

        # 720p, 81 frames (default)
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram > 12.0  # Base is 12GB with margin

        # 1080p should be higher
        vram_1080 = backend.estimate_vram_usage(1920, 1080, 81)
        assert vram_1080 > vram

    def test_estimate_vram_1_3b(self, config_1_3b):
        """Check the 1.3B estimate at 720p / 81 frames stays below 12 GB."""
        backend = WanBackend(config_1_3b)

        # Should be less than 14B
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram < 12.0

    def test_estimate_vram_scaling(self, config_14b):
        """Check the estimate scales with both pixel count and frame count."""
        backend = WanBackend(config_14b)

        # Baseline: 1280x720, 81 frames.
        vram_720_81 = backend.estimate_vram_usage(1280, 720, 81)

        # Double each dimension -> exactly 4x the pixels (2560x1440).
        # 1920x1080 is only 2.25x the pixels of 1280x720, so it cannot be
        # used to check a "~4x" expectation against a 3x threshold.
        vram_1440_81 = backend.estimate_vram_usage(2560, 1440, 81)
        assert vram_1440_81 > vram_720_81 * 3  # Should be ~4x

        # Double the frame count at the baseline resolution.
        vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
        assert vram_720_162 > vram_720_81 * 1.5  # Should be ~2x

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def test_generate_not_loaded(self, config_14b, tmp_path):
        """Check ``generate`` reports failure when the model was never loaded."""
        backend = WanBackend(config_14b)
        spec = GenerationSpec(
            prompt="test",
            negative_prompt="",
            width=512,
            height=512,
            num_frames=16,
            fps=8,
            seed=42,
            steps=10,
            cfg_scale=6.0,
            output_path=tmp_path / "test.mp4",
        )

        result = backend.generate(spec)

        assert result.success is False
        assert "not loaded" in result.error_message.lower()

    # ------------------------------------------------------------------
    # Chunk concatenation (subprocess is always mocked)
    # ------------------------------------------------------------------

    def test_concatenate_chunks_success(self, config_14b, tmp_path):
        """Check concatenation succeeds and issues one ffmpeg concat command."""
        backend = WanBackend(config_14b)

        # Create dummy chunk files
        chunk_files = [
            tmp_path / "chunk0.mp4",
            tmp_path / "chunk1.mp4",
            tmp_path / "chunk2.mp4",
        ]
        for chunk in chunk_files:
            chunk.write_text("dummy video data")
        output_path = tmp_path / "output.mp4"

        # Mock subprocess so no real ffmpeg process is spawned.
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=0)
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

        assert result is True
        mock_run.assert_called_once()

        # Verify ffmpeg command structure: the first positional argument of
        # the mocked call is inspected as the command.
        call_args = mock_run.call_args[0][0]
        assert 'ffmpeg' in call_args
        assert '-f' in call_args
        assert 'concat' in call_args

    def test_concatenate_chunks_failure(self, config_14b, tmp_path):
        """Check concatenation returns False on a non-zero return code."""
        backend = WanBackend(config_14b)
        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")
        output_path = tmp_path / "output.mp4"

        # Mock subprocess failure
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stderr="ffmpeg error")
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

        assert result is False

    def test_concatenate_chunks_exception(self, config_14b, tmp_path):
        """Check concatenation returns False when ``subprocess.run`` raises."""
        backend = WanBackend(config_14b)
        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")
        output_path = tmp_path / "output.mp4"

        # Mock subprocess exception
        with patch('subprocess.run', side_effect=Exception("subprocess failed")):
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

        assert result is False

    # ------------------------------------------------------------------
    # Load state
    # ------------------------------------------------------------------

    def test_is_loaded_property(self, config_14b):
        """Check ``is_loaded`` starts False and follows ``_is_loaded``."""
        backend = WanBackend(config_14b)

        # Initially not loaded
        assert backend.is_loaded is False

        # Manually set (normally done by load())
        backend._is_loaded = True
        assert backend.is_loaded is True