196 lines
6.4 KiB
Python
196 lines
6.4 KiB
Python
"""
|
|
Tests for WAN backend.
|
|
Note: These tests mock the actual model loading/generation.
|
|
"""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, MagicMock, mock_open
|
|
|
|
from src.generation.backends.wan import WanBackend
|
|
from src.generation.base import GenerationSpec, GenerationResult
|
|
|
|
|
|
class TestWanBackend:
    """Test WAN backend functionality.

    Covers the backend's reported properties, VRAM estimation, the
    generation guard when no model is loaded, chunk concatenation through a
    patched ``subprocess.run``, and the ``is_loaded`` flag. No real model is
    loaded and ffmpeg is never actually invoked.
    """

    @pytest.fixture
    def config_14b(self):
        """Configuration for 14B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-14B",
            "vram_gb": 12,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": True,
            "model_cache_dir": "~/.cache/test",
        }

    @pytest.fixture
    def config_1_3b(self):
        """Configuration for 1.3B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-1.3B",
            "vram_gb": 8,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": False,
            "model_cache_dir": "~/.cache/test",
        }

    def test_backend_properties_14b(self, config_14b):
        """Test backend properties for 14B model."""
        backend = WanBackend(config_14b)

        assert backend.name == "wan_wan2.1-t2v-14b"
        assert backend.supports_chunking is True
        assert backend.supports_init_frame is False
        assert backend.vram_gb == 12

    def test_backend_properties_1_3b(self, config_1_3b):
        """Test backend properties for 1.3B model."""
        backend = WanBackend(config_1_3b)

        assert backend.name == "wan_wan2.1-t2v-1.3b"
        assert backend.vram_gb == 8

    def test_dtype_str_storage(self, config_14b):
        """Test dtype string is stored correctly for fp16 and bf16."""
        backend = WanBackend(config_14b)
        assert backend.dtype_str == "fp16"

        # Build the second backend from a copy: mutating the fixture dict in
        # place could also change the first backend if it keeps a reference
        # to its config.
        backend2 = WanBackend({**config_14b, "dtype": "bf16"})
        assert backend2.dtype_str == "bf16"

    def test_estimate_vram_14b(self, config_14b):
        """Test VRAM estimation for 14B model."""
        backend = WanBackend(config_14b)

        # 720p, 81 frames (default)
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram > 12.0  # Base is 12GB with margin

        # 1080p should be higher
        vram_1080 = backend.estimate_vram_usage(1920, 1080, 81)
        assert vram_1080 > vram

    def test_estimate_vram_1_3b(self, config_1_3b):
        """Test VRAM estimation for 1.3B model."""
        backend = WanBackend(config_1_3b)

        # Should be less than 14B
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram < 12.0

    def test_estimate_vram_scaling(self, config_14b):
        """Test VRAM scales with resolution and frames."""
        backend = WanBackend(config_14b)

        # Base
        vram_720_81 = backend.estimate_vram_usage(1280, 720, 81)

        # Double resolution (4x pixels): 2560x1440 doubles both dimensions of
        # 1280x720. (1920x1080 would only be 2.25x the pixels.)
        vram_1440_81 = backend.estimate_vram_usage(2560, 1440, 81)
        assert vram_1440_81 > vram_720_81 * 3  # Should be ~4x

        # Double frames
        vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
        assert vram_720_162 > vram_720_81 * 1.5  # Should be ~2x

    def test_generate_not_loaded(self, config_14b, tmp_path):
        """Test generation fails if model not loaded."""
        backend = WanBackend(config_14b)

        spec = GenerationSpec(
            prompt="test",
            negative_prompt="",
            width=512,
            height=512,
            num_frames=16,
            fps=8,
            seed=42,
            steps=10,
            cfg_scale=6.0,
            output_path=tmp_path / "test.mp4",
        )

        result = backend.generate(spec)

        assert result.success is False
        assert "not loaded" in result.error_message.lower()

    def test_concatenate_chunks_success(self, config_14b, tmp_path):
        """Test chunk concatenation when the ffmpeg call reports success."""
        backend = WanBackend(config_14b)

        # Create dummy chunk files
        chunk_files = [
            tmp_path / "chunk0.mp4",
            tmp_path / "chunk1.mp4",
            tmp_path / "chunk2.mp4",
        ]
        for chunk in chunk_files:
            chunk.write_text("dummy video data")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess so no real ffmpeg process is started
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=0)

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is True
            mock_run.assert_called_once()

            # Verify ffmpeg command structure (first positional argument of
            # the single subprocess.run call)
            call_args = mock_run.call_args[0][0]
            assert 'ffmpeg' in call_args
            assert '-f' in call_args
            assert 'concat' in call_args

    def test_concatenate_chunks_failure(self, config_14b, tmp_path):
        """Test chunk concatenation failure (non-zero return code)."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess failure
        with patch('subprocess.run') as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stderr="ffmpeg error")

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

        assert result is False

    def test_concatenate_chunks_exception(self, config_14b, tmp_path):
        """Test chunk concatenation with exception."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess exception
        with patch('subprocess.run', side_effect=Exception("subprocess failed")):
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

        assert result is False

    def test_is_loaded_property(self, config_14b):
        """Test is_loaded property."""
        backend = WanBackend(config_14b)

        # Initially not loaded
        assert backend.is_loaded is False

        # Manually set (normally done by load())
        backend._is_loaded = True
        assert backend.is_loaded is True
|