"""
Tests for WAN backend.

Note: These tests mock the actual model loading/generation.
"""

from pathlib import Path
from unittest.mock import Mock, MagicMock, mock_open, patch

import pytest

from src.generation.backends.wan import WanBackend
from src.generation.base import GenerationSpec, GenerationResult


class TestWanBackend:
    """Test WAN backend functionality."""

    @pytest.fixture
    def config_14b(self):
        """Configuration for 14B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-14B",
            "vram_gb": 12,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": True,
            "model_cache_dir": "~/.cache/test",
        }

    @pytest.fixture
    def config_1_3b(self):
        """Configuration for 1.3B model."""
        return {
            "model_id": "Wan-AI/Wan2.1-T2V-1.3B",
            "vram_gb": 8,
            "dtype": "fp16",
            "enable_vae_slicing": True,
            "enable_vae_tiling": False,
            "model_cache_dir": "~/.cache/test",
        }

    def test_backend_properties_14b(self, config_14b):
        """Test backend properties for 14B model."""
        backend = WanBackend(config_14b)

        assert backend.name == "wan_wan2.1-t2v-14b"
        assert backend.supports_chunking is True
        assert backend.supports_init_frame is False
        assert backend.vram_gb == 12

    def test_backend_properties_1_3b(self, config_1_3b):
        """Test backend properties for 1.3B model."""
        backend = WanBackend(config_1_3b)

        assert backend.name == "wan_wan2.1-t2v-1.3b"
        assert backend.vram_gb == 8

    def test_dtype_str_storage(self, config_14b):
        """Test dtype string is stored correctly."""
        backend = WanBackend(config_14b)
        assert backend.dtype_str == "fp16"

        # Build a modified copy rather than mutating the fixture dict in
        # place, so the first backend's config is left untouched.
        config_bf16 = {**config_14b, "dtype": "bf16"}
        backend2 = WanBackend(config_bf16)
        assert backend2.dtype_str == "bf16"

    def test_estimate_vram_14b(self, config_14b):
        """Test VRAM estimation for 14B model."""
        backend = WanBackend(config_14b)

        # 720p, 81 frames (default)
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram > 12.0  # Base is 12GB with margin

        # 1080p should be higher
        vram_1080 = backend.estimate_vram_usage(1920, 1080, 81)
        assert vram_1080 > vram

    def test_estimate_vram_1_3b(self, config_1_3b):
        """Test VRAM estimation for 1.3B model."""
        backend = WanBackend(config_1_3b)

        # Should be less than 14B
        vram = backend.estimate_vram_usage(1280, 720, 81)
        assert vram < 12.0

    def test_estimate_vram_scaling(self, config_14b):
        """Test VRAM scales with resolution and frames."""
        backend = WanBackend(config_14b)

        # Base: 720p, 81 frames
        vram_720_81 = backend.estimate_vram_usage(1280, 720, 81)

        # Double resolution in each dimension: 2560x1440 has 4x the pixels
        # of 1280x720. (1920x1080 is only 2.25x the pixels, which does not
        # match the ~4x expectation asserted below.)
        vram_1440_81 = backend.estimate_vram_usage(2560, 1440, 81)
        assert vram_1440_81 > vram_720_81 * 3  # Should be ~4x

        # Double frames
        vram_720_162 = backend.estimate_vram_usage(1280, 720, 162)
        assert vram_720_162 > vram_720_81 * 1.5  # Should be ~2x

    def test_generate_not_loaded(self, config_14b, tmp_path):
        """Test generation fails if model not loaded."""
        backend = WanBackend(config_14b)

        spec = GenerationSpec(
            prompt="test",
            negative_prompt="",
            width=512,
            height=512,
            num_frames=16,
            fps=8,
            seed=42,
            steps=10,
            cfg_scale=6.0,
            output_path=tmp_path / "test.mp4",
        )

        result = backend.generate(spec)

        assert result.success is False
        assert "not loaded" in result.error_message.lower()

    def test_concatenate_chunks_success(self, config_14b, tmp_path):
        """Test chunk concatenation."""
        backend = WanBackend(config_14b)

        # Create dummy chunk files
        chunk_files = [
            tmp_path / "chunk0.mp4",
            tmp_path / "chunk1.mp4",
            tmp_path / "chunk2.mp4",
        ]
        for chunk in chunk_files:
            chunk.write_text("dummy video data")

        output_path = tmp_path / "output.mp4"

        # Mock subprocess so ffmpeg is never actually invoked
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=0)

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is True
            mock_run.assert_called_once()

            # Verify ffmpeg command structure (first positional argument)
            call_args = mock_run.call_args[0][0]
            assert "ffmpeg" in call_args
            assert "-f" in call_args
            assert "concat" in call_args

    def test_concatenate_chunks_failure(self, config_14b, tmp_path):
        """Test chunk concatenation failure."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")
        output_path = tmp_path / "output.mp4"

        # Mock subprocess failure (non-zero exit code)
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(returncode=1, stderr="ffmpeg error")

            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_concatenate_chunks_exception(self, config_14b, tmp_path):
        """Test chunk concatenation with exception."""
        backend = WanBackend(config_14b)

        chunk_files = [tmp_path / "chunk0.mp4"]
        chunk_files[0].write_text("dummy")
        output_path = tmp_path / "output.mp4"

        # Mock subprocess raising instead of returning
        with patch("subprocess.run", side_effect=Exception("subprocess failed")):
            result = backend._concatenate_chunks(chunk_files, output_path, 24)

            assert result is False

    def test_is_loaded_property(self, config_14b):
        """Test is_loaded property."""
        backend = WanBackend(config_14b)

        # Initially not loaded
        assert backend.is_loaded is False

        # Manually set (normally done by load())
        backend._is_loaded = True
        assert backend.is_loaded is True