"""
SHAC-enabled spatial effects module (migrated from Weather-Tune effects)

This module provides spatial effects processing that integrates with the SHAC codec
for creating immersive 3D environments. Includes spatial reverb, dynamic processing,
and weather-specific acoustic environments.

Key Features:
- SHAC room modeling for realistic reverberation
- Distance-based attenuation and filtering
- Weather-specific spatial characteristics
- Full 3D positioning for all effects
"""

import logging
from typing import Dict, List, Tuple, Optional

import numpy as np
import scipy.signal as signal

# Import from shac module - adjust path based on actual SHAC installation
try:
    from shac.codec.core import SHACCodec
    from shac.codec.utils import SourceAttributes, RoomAttributes
except ImportError:
    # Lightweight stand-ins so this module can still be imported (development,
    # testing) when the full SHAC package is not installed.
    class SHACCodec:
        """Placeholder for ``shac.codec.core.SHACCodec``; has no behavior."""

    class SourceAttributes:
        """Placeholder for ``shac.codec.utils.SourceAttributes``; has no behavior."""

    class RoomAttributes:
        """Minimal stand-in for ``shac.codec.utils.RoomAttributes``.

        Stores only ``dimensions``, ``reflection_coefficients``, ``rt60`` and
        ``direct_to_reverb_ratio`` (falling back to fixed defaults); any other
        keyword arguments are accepted and ignored.
        """

        def __init__(self, **kwargs):
            # Built per call so every instance gets its own empty dict default.
            fallbacks = {
                'dimensions': (10, 3, 10),
                'reflection_coefficients': {},
                'rt60': 2.0,
                'direct_to_reverb_ratio': -3.0,
            }
            for attr_name, fallback in fallbacks.items():
                setattr(self, attr_name, kwargs.get(attr_name, fallback))


class SpatialReverb:
    """
    Spatial reverb with SHAC room modeling and weather-specific characteristics.
    Replaces Weather-Tune's convolution reverb with full 3D processing.

    ``ROOM_PRESETS`` and ``WEATHER_ROOMS`` are shared class-level tables. They
    are treated as read-only; adjusted presets are always private copies.
    """
    
    # Reverb characteristic presets converted to spatial room models
    ROOM_PRESETS = {
        "neutral": {
            "dimensions": (10, 3, 12),  # Small room
            "rt60": 2.0,
            "absorption": {"low": 0.1, "mid": 0.2, "high": 0.4},
            "diffusion": 0.8,
            "early_level": 0.7,
            "late_level": 0.5,
            "modulation": 0.0
        },
        "hall": {
            "dimensions": (20, 8, 30),  # Concert hall
            "rt60": 2.8,
            "absorption": {"low": 0.05, "mid": 0.1, "high": 0.3},
            "diffusion": 0.9,
            "early_level": 0.8,
            "late_level": 0.6,
            "modulation": 0.1
        },
        "plate": {
            "dimensions": (2, 0.01, 1.5),  # Simulated plate
            "rt60": 1.8,
            "absorption": {"low": 0.3, "mid": 0.1, "high": 0.0},
            "diffusion": 0.95,
            "early_level": 0.3,
            "late_level": 0.7,
            "modulation": 0.2
        },
        "chamber": {
            "dimensions": (6, 3, 8),  # Small chamber
            "rt60": 1.5,
            "absorption": {"low": 0.2, "mid": 0.3, "high": 0.5},
            "diffusion": 0.7,
            "early_level": 0.6,
            "late_level": 0.4,
            "modulation": 0.05
        },
        "space": {
            "dimensions": (50, 20, 100),  # Large space
            "rt60": 4.0,
            "absorption": {"low": 0.02, "mid": 0.05, "high": 0.2},
            "diffusion": 0.6,
            "early_level": 0.2,
            "late_level": 0.8,
            "modulation": 0.3
        },
        "cavern": {
            "dimensions": (40, 15, 60),  # Cave/cavern
            "rt60": 5.0,
            "absorption": {"low": 0.01, "mid": 0.05, "high": 0.3},
            "diffusion": 0.5,
            "early_level": 0.9,
            "late_level": 0.7,
            "modulation": 0.1
        },
        "ambient": {
            "dimensions": (25, 10, 35),  # Ambient space
            "rt60": 3.5,
            "absorption": {"low": 0.1, "mid": 0.15, "high": 0.25},
            "diffusion": 0.85,
            "early_level": 0.4,
            "late_level": 0.6,
            "modulation": 0.4
        }
    }
    
    # Weather-specific room configurations
    WEATHER_ROOMS = {
        "clear": {
            "preset": "plate",
            "position_modifiers": {
                "early_spread": 0.3,
                "late_spread": 0.6,
                "height_variance": 0.2
            }
        },
        "cloudy": {
            "preset": "hall",
            "position_modifiers": {
                "early_spread": 0.6,
                "late_spread": 0.8,
                "height_variance": 0.4
            }
        },
        "rain": {
            "preset": "chamber",
            "position_modifiers": {
                "early_spread": 0.8,
                "late_spread": 0.4,
                "height_variance": 0.3
            }
        },
        "storm": {
            "preset": "space",
            "position_modifiers": {
                "early_spread": 0.9,
                "late_spread": 0.95,
                "height_variance": 0.8
            }
        },
        "snow": {
            "preset": "ambient",
            "position_modifiers": {
                "early_spread": 0.4,
                "late_spread": 0.8,
                "height_variance": 0.1
            }
        },
        "fog": {
            "preset": "ambient",
            "position_modifiers": {
                "early_spread": 0.2,
                "late_spread": 0.3,
                "height_variance": 0.2
            }
        }
    }
    
    def __init__(self, sample_rate: int = 48000):
        self.sample_rate = sample_rate
        self.codec = SHACCodec()
        
    def create_spatial_reverb(self, 
                            character: str = "neutral",
                            duration: float = 2.0,
                            position: Tuple[float, float, float] = (0, 0, 0),
                            weather_type: Optional[str] = None,
                            intensity: float = 0.5) -> Dict:
        """
        Create a spatial reverb with SHAC positioning.
        
        If ``weather_type`` names an entry of ``WEATHER_ROOMS``, that entry's
        room preset replaces ``character``. Its RT60 is then scaled by
        ``0.8 + 0.4 * intensity`` and its absorption by ``1 - 0.3 * intensity``.
        
        Args:
            character: Reverb character preset (unknown names fall back to "neutral")
            duration: Reverb duration in seconds
            position: 3D position (x, y, z) in meters
            weather_type: Optional weather-specific modifications
            intensity: Weather intensity (0-1)
        
        Returns:
            Dictionary with keys ``early_reflections``, ``late_reverb``, ``room``,
            ``preset`` (a private copy of the resolved preset), ``position``,
            ``wet_level`` and ``dry_level``.
        """
        # Get base preset
        base = self.ROOM_PRESETS.get(character, self.ROOM_PRESETS["neutral"])
        weather_config = self.WEATHER_ROOMS.get(weather_type) if weather_type else None
        if weather_config is not None:
            base = self.ROOM_PRESETS.get(weather_config["preset"], base)
        
        # Copy before modifying: ``base`` is an entry of the shared class-level
        # ROOM_PRESETS table. Mutating it in place would compound the weather
        # adjustments on every call, across all instances.
        preset = dict(base)
        preset["absorption"] = dict(base["absorption"])
        
        if weather_config is not None:
            # Modify RT60 based on intensity (0.8x .. 1.2x)
            preset["rt60"] = preset["rt60"] * (0.8 + 0.4 * intensity)
            
            # Adjust absorption based on weather (more intense -> more reflective)
            for band in preset["absorption"]:
                preset["absorption"][band] *= (1.0 - 0.3 * intensity)
        
        # Create room attributes for SHAC
        room = self._create_room_attributes(preset, duration)
        
        # Generate early reflections pattern
        early_reflections = self._generate_early_reflections(
            preset, position, weather_type
        )
        
        # Create late reverb using SHAC room modeling
        late_reverb = self._create_late_reverb(preset, room, position)
        
        return {
            "early_reflections": early_reflections,
            "late_reverb": late_reverb,
            "room": room,
            "preset": preset,
            "position": position,
            "wet_level": 0.5,
            "dry_level": 0.5
        }
    
    def _create_room_attributes(self, preset: Dict, duration: float) -> "RoomAttributes":
        """
        Create SHAC RoomAttributes from a preset.
        
        The preset's dimensions are scaled by ``duration / preset["rt60"]``.
        Absorption bands become reflection coefficients (``1 - absorption``):
        floor uses "low", ceiling uses "high", and the four walls use "mid".
        The room's RT60 is set to ``duration``.
        """
        dimensions = preset["dimensions"]
        
        # Scale dimensions based on desired duration
        scale = duration / preset["rt60"]
        scaled_dimensions = tuple(d * scale for d in dimensions)
        
        # Convert absorption to reflection coefficients
        reflection_coeffs = {
            "floor": 1.0 - preset["absorption"]["low"],
            "ceiling": 1.0 - preset["absorption"]["high"],
            "left": 1.0 - preset["absorption"]["mid"],
            "right": 1.0 - preset["absorption"]["mid"],
            "front": 1.0 - preset["absorption"]["mid"],
            "back": 1.0 - preset["absorption"]["mid"]
        }
        
        return RoomAttributes(
            dimensions=scaled_dimensions,
            reflection_coefficients=reflection_coeffs,
            rt60=duration,
            direct_to_reverb_ratio=-3.0  # dB
        )
    
    def _generate_early_reflections(self, 
                                  preset: Dict, 
                                  position: Tuple[float, float, float],
                                  weather_type: Optional[str]) -> List[Dict]:
        """
        Generate spatial early reflection patterns.
        
        The first entry is the direct sound at ``position`` (time 0). It is
        followed by reflections spread non-linearly over the first 150 ms, with
        amplitude decreasing and a random positional spread growing over time.
        Uses the global ``np.random`` state, so results vary unless it is seeded.
        """
        reflections = []
        
        # Get weather-specific position modifiers
        if weather_type and weather_type in self.WEATHER_ROOMS:
            mods = self.WEATHER_ROOMS[weather_type]["position_modifiers"]
        else:
            mods = {"early_spread": 0.5, "late_spread": 0.7, "height_variance": 0.3}
        
        # First reflection is direct sound at source position
        reflections.append({
            "time": 0.0,
            "amplitude": preset["early_level"],
            "position": position
        })
        
        # Generate spatial early reflections
        num_reflections = int(10 + preset["early_level"] * 10)
        
        for i in range(1, num_reflections):
            # Time increases non-linearly
            time_pos = np.power(i / num_reflections, 1.5)
            time = time_pos * 0.15  # Max 150ms
            
            # Amplitude decreases with time
            amplitude = preset["early_level"] * (1.0 - time_pos) * 0.7
            
            # Spatial position spreads with time
            spread = mods["early_spread"] * time_pos
            
            # Random position offset from source
            offset = np.random.normal(0, spread, 3)
            offset[1] *= mods["height_variance"]  # Less variance in height
            
            reflection_pos = (
                position[0] + offset[0],
                position[1] + offset[1],
                position[2] + offset[2]
            )
            
            reflections.append({
                "time": time,
                "amplitude": amplitude,
                "position": reflection_pos
            })
        
        return reflections
    
    def _create_late_reverb(self, 
                           preset: Dict, 
                           room: "RoomAttributes",
                           position: Tuple[float, float, float]) -> Dict:
        """
        Create late reverb parameters using SHAC room modeling.
        
        Places 16 virtual sources on a ring around ``position``. The ring radius
        is half of ``room.dimensions[0]``, with a random elevation within ±30°
        (global ``np.random`` state). Level, diffusion and modulation come from
        the preset.
        """
        # Generate diffuse field positions for late reverb
        num_sources = 16  # Number of decorrelated sources for diffuse field
        
        positions = []
        for i in range(num_sources):
            # Distribute sources around the listener
            angle = (i / num_sources) * 2 * np.pi
            elevation = (np.random.random() - 0.5) * np.pi / 3
            distance = room.dimensions[0] * 0.5  # Half room width
            
            x = distance * np.cos(angle) * np.cos(elevation)
            y = distance * np.sin(elevation)
            z = distance * np.sin(angle) * np.cos(elevation)
            
            positions.append((
                position[0] + x,
                position[1] + y,
                position[2] + z
            ))
        
        return {
            "positions": positions,
            "level": preset["late_level"],
            "diffusion": preset["diffusion"],
            "modulation": preset["modulation"],
            "room": room
        }
    
    def process_audio(self, 
                     audio: np.ndarray,
                     reverb_params: Dict,
                     wet_dry: float = 0.5) -> np.ndarray:
        """
        Process audio through spatial reverb.
        
        Args:
            audio: Input audio (mono or multi-channel; integer input is converted to float)
            reverb_params: Parameters from create_spatial_reverb
                (its "wet_level"/"dry_level" entries are not consulted here)
            wet_dry: Mix between dry (0) and wet (1) signal
        
        Returns:
            Mono processed audio with spatial reverb
        """
        audio = np.asarray(audio)
        # Ensure mono for spatial processing. This averages along axis 0, which
        # assumes a channels-first layout (TODO confirm with callers).
        if audio.ndim > 1:
            audio = np.mean(audio, axis=0)
        elif not np.issubdtype(audio.dtype, np.floating):
            # Integer PCM: work in float, otherwise the in-place accumulation
            # in the early-reflection stage fails on a float -> int cast.
            audio = audio.astype(np.float64)
        
        # Process early reflections
        early_out = self._process_early_reflections(
            audio, 
            reverb_params["early_reflections"]
        )
        
        # Process late reverb
        late_out = self._process_late_reverb(
            audio,
            reverb_params["late_reverb"]
        )
        
        # Mix early and late
        wet = early_out + late_out
        
        # Mix wet and dry signals
        return audio * (1.0 - wet_dry) + wet * wet_dry
    
    def _process_early_reflections(self, 
                                 audio: np.ndarray,
                                 reflections: List[Dict]) -> np.ndarray:
        """
        Sum delayed, scaled copies of ``audio``, one per reflection.
        
        Only each reflection's "time" and "amplitude" are used. Reflections
        delayed past the end of the input are skipped.
        """
        output = np.zeros_like(audio)
        
        for reflection in reflections:
            # Calculate delay in samples
            delay_samples = int(reflection["time"] * self.sample_rate)
            
            if delay_samples < len(audio):
                # Apply delay (np.roll wraps around, so clear the wrapped head)
                delayed = np.roll(audio, delay_samples)
                delayed[:delay_samples] = 0
                
                output += delayed * reflection["amplitude"]
        
        return output
    
    def _process_late_reverb(self,
                           audio: np.ndarray,
                           late_params: Dict) -> np.ndarray:
        """
        Process late reverb using a synthetic diffuse tail.
        
        A Gaussian-noise tail with an exponential envelope is built. The
        envelope reaches -60 dB at ``room.rt60``; the tail is limited to half the
        RT60, capped at 2 s. The tail is optionally low-pass smoothed, convolved
        causally with ``audio``, scaled by "level" and amplitude-modulated.
        Uses the global ``np.random`` state. Returns silence when the tail would
        be empty (RT60 <= 0) or the input is empty.
        """
        room = late_params["room"]
        rt60 = room.rt60
        
        # Limit tail length: half the RT60, at most two seconds of samples
        reverb_length = min(int(rt60 * self.sample_rate * 0.5), self.sample_rate * 2)
        if reverb_length <= 0 or len(audio) == 0:
            return np.zeros_like(audio)
        
        # ln(0.001) spread over rt60 seconds -> envelope is -60 dB at rt60
        decay_rate = np.log(0.001) / (rt60 * self.sample_rate)
        
        # Generate reverb tail
        tail = np.random.normal(0, 1, reverb_length)
        tail *= np.exp(decay_rate * np.arange(reverb_length))
        
        # Apply diffusion
        if late_params["diffusion"] > 0:
            # Smooth the noise tail with a 2nd-order Butterworth low-pass whose
            # normalized cutoff is the diffusion amount (kept below Nyquist).
            try:
                cutoff = min(late_params["diffusion"], 0.99)
                b, a = signal.butter(2, cutoff, 'low')
                if len(tail) > 3 * max(len(a), len(b)):  # filtfilt's default padlen
                    tail = signal.filtfilt(b, a, tail)
            except (ValueError, np.linalg.LinAlgError) as e:
                logging.getLogger(__name__).debug(
                    "Filter failed, using unfiltered tail: %s", e)
        
        # Causal convolution: keep the first len(audio) samples of the full
        # result so the tail starts at the dry onset. mode='same' would centre
        # the tail on it and smear reverb ahead of the sound. Tail samples past
        # len(audio) cannot reach those outputs, so truncating first is exact.
        tail = tail[:len(audio)]
        reverb = signal.fftconvolve(audio, tail)[:len(audio)]
        
        # Apply late level
        output = reverb * late_params["level"]
        
        # Apply modulation if needed
        if late_params["modulation"] > 0:
            mod_freq = 0.5 + late_params["modulation"] * 2.0  # 0.5-2.5 Hz
            modulation = np.sin(2 * np.pi * mod_freq * np.arange(len(output)) / self.sample_rate)
            modulation = 1.0 + modulation * late_params["modulation"] * 0.1
            output *= modulation
        
        return output


class SpatialCompressor:
    """
    Distance-aware dynamic range compressor with spatial characteristics.

    ``create_spatial_compressor`` derives parameters from the source's distance
    to a listener at the origin. ``process_audio`` applies them to a 1-D signal.
    """
    
    def __init__(self, sample_rate: int = 48000):
        self.sample_rate = sample_rate
    
    def create_spatial_compressor(self,
                                threshold: float = -24.0,
                                ratio: float = 4.0,
                                attack: float = 0.003,
                                release: float = 0.25,
                                position: Tuple[float, float, float] = (0, 0, 0),
                                distance_factor: float = 1.0) -> Dict:
        """
        Create a spatial compressor with distance-based adjustments.
        
        For sources farther than 1 m the threshold is raised (by up to 12 dB),
        the ratio is lowered but never below 1:1, and the attack is lengthened.
        
        Args:
            threshold: Threshold in dB
            ratio: Compression ratio
            attack: Attack time in seconds
            release: Release time in seconds
            position: 3D position of the compressor
            distance_factor: How much distance affects compression
        
        Returns:
            Dict with threshold, ratio, attack, release, position, distance and
            makeup_gain (linear).
        """
        # Calculate distance from listener (assumed at origin)
        distance = np.sqrt(sum(p**2 for p in position))
        
        # Adjust parameters based on distance
        if distance > 1.0:
            # Further sources get less aggressive compression
            distance_mod = 1.0 / (1.0 + distance * distance_factor)
            
            # Adjust threshold (higher for distant sources)
            threshold = threshold + (1.0 - distance_mod) * 12.0
            
            # Lower ratio for distant sources. Below 1:1 the gain stage would
            # boost signals above threshold (an expander), so stop at unity.
            ratio = max(1.0, ratio * distance_mod)
            
            # Slower attack for distant sources
            attack = attack * (1.0 + distance * 0.5)
        
        return {
            "threshold": threshold,
            "ratio": ratio,
            "attack": attack,
            "release": release,
            "position": position,
            "distance": distance,
            "makeup_gain": self._calculate_makeup_gain(threshold, ratio)
        }
    
    def _calculate_makeup_gain(self, threshold: float, ratio: float) -> float:
        """
        Estimate a linear makeup gain that partially restores compressed level.
        
        Static dB-domain estimate: a 0 dB peak is reduced by
        ``threshold / ratio - threshold`` dB. Half of that is restored, the gain
        is never below unity (1.0 when nothing is compressed) and it is capped
        at 2.0 (about +6 dB) to avoid over-compensation.
        """
        ratio = max(ratio, 1.0)
        reduction_db = max(0.0, threshold / ratio - threshold)
        makeup_db = 0.5 * reduction_db
        return min(10.0 ** (makeup_db / 20.0), 2.0)
    
    def process_audio(self,
                     audio: np.ndarray,
                     compressor_params: Dict) -> np.ndarray:
        """
        Process audio through spatial compressor.
        
        A peak envelope follower uses one-pole smoothing with attack/release
        lengths in samples (at least one sample each). Above the linear
        threshold the output level is ``threshold + (envelope - threshold) / ratio``.
        
        Args:
            audio: 1-D input audio (integer input is processed as float64)
            compressor_params: Parameters from create_spatial_compressor
        
        Returns:
            Compressed audio (float; float input keeps its dtype)
        """
        samples = np.asarray(audio)
        if not np.issubdtype(samples.dtype, np.floating):
            samples = samples.astype(np.float64)
        
        # Convert parameters to linear values
        threshold_linear = 10 ** (compressor_params["threshold"] / 20)
        ratio = compressor_params["ratio"]
        makeup_gain = compressor_params["makeup_gain"]
        
        # Time constants in samples; at least 1 so very short (or zero)
        # attack/release times cannot divide by zero and yield inf/NaN.
        attack_samples = max(1, int(compressor_params["attack"] * self.sample_rate))
        release_samples = max(1, int(compressor_params["release"] * self.sample_rate))
        
        # Envelope follower (inherently sequential, so a plain loop)
        envelope = np.empty(len(samples), dtype=np.float64)
        level = 0.0
        for i, input_level in enumerate(np.abs(samples)):
            if input_level > level:
                level += (input_level - level) / attack_samples   # Attack
            else:
                level -= (level - input_level) / release_samples  # Release
            envelope[i] = level
        
        # Gain reduction only where the envelope exceeds the threshold
        gain = np.ones_like(envelope)
        above = envelope > threshold_linear
        gain[above] = 1.0 - (envelope[above] - threshold_linear) * (1 - 1 / ratio) / envelope[above]
        
        # Apply gain and makeup gain
        output = samples * gain * makeup_gain
        return output.astype(samples.dtype, copy=False)


class SpatialFilter:
    """
    Distance and direction-based filtering for spatial audio.

    All filters are zero-phase (``filtfilt``) Butterworth low-passes applied
    along the last axis of ``audio``.
    """
    
    @staticmethod
    def _normalized_cutoff(cutoff_hz: float, floor_hz: float, sample_rate: int) -> float:
        """
        Clamp ``cutoff_hz`` and express it as a fraction of Nyquist.
        
        The cutoff is raised to at least ``floor_hz``, then limited to 100 Hz
        below Nyquist. The upper limit wins, so low sample rates still yield a
        valid ``signal.butter`` design. The result is kept strictly inside (0, 1).
        """
        nyquist = sample_rate / 2
        cutoff = min(max(cutoff_hz, floor_hz), nyquist - 100)
        return min(max(cutoff / nyquist, 1e-6), 0.999)
    
    @staticmethod
    def _zero_phase(b: np.ndarray, a: np.ndarray, audio: np.ndarray) -> np.ndarray:
        """Run ``filtfilt`` along the last axis with a padlen that fits short inputs."""
        padlen = min(np.shape(audio)[-1] - 1, 3 * max(len(a), len(b)))
        return signal.filtfilt(b, a, audio, padlen=padlen)
    
    @staticmethod
    def apply_distance_filtering(audio: np.ndarray,
                               distance: float,
                               sample_rate: int = 48000) -> np.ndarray:
        """
        Apply distance-based high-frequency attenuation.
        
        Sources within 1 m, and empty input, are returned unchanged.
        
        Args:
            audio: Input audio
            distance: Distance in meters
            sample_rate: Sample rate
        
        Returns:
            Filtered audio
        """
        if distance <= 1.0 or np.size(audio) == 0:
            return audio
        
        # Air absorption removes high frequencies with distance. Model it as a
        # 2nd-order low-pass whose cutoff rolls off exponentially from 20 kHz,
        # with a 1 kHz floor.
        cutoff_freq = 20000 * np.exp(-distance * 0.001)
        normal_cutoff = SpatialFilter._normalized_cutoff(cutoff_freq, 1000, sample_rate)
        
        b, a = signal.butter(2, normal_cutoff, btype='low')
        return SpatialFilter._zero_phase(b, a, audio)
    
    @staticmethod
    def apply_occlusion_filtering(audio: np.ndarray,
                                occlusion: float,
                                sample_rate: int = 48000) -> np.ndarray:
        """
        Apply occlusion filtering (object between source and listener).
        
        Args:
            audio: Input audio
            occlusion: Occlusion amount (0-1; larger values are treated as 1,
                and values <= 0 return the input unchanged)
            sample_rate: Sample rate
        
        Returns:
            Filtered audio, attenuated by up to 70%
        """
        if occlusion <= 0 or np.size(audio) == 0:
            return audio
        
        # Clamp so the level factor below can never go negative (phase flip)
        occlusion = min(occlusion, 1.0)
        
        # Occluded sounds lose high frequencies (4th-order low-pass, 200 Hz floor)
        cutoff_freq = 20000 * (1.0 - occlusion * 0.9)
        normal_cutoff = SpatialFilter._normalized_cutoff(cutoff_freq, 200, sample_rate)
        
        b, a = signal.butter(4, normal_cutoff, btype='low')
        filtered = SpatialFilter._zero_phase(b, a, audio)
        
        # Reduce overall level based on occlusion
        attenuation = 1.0 - occlusion * 0.7  # Max 70% attenuation
        
        return filtered * attenuation


class WeatherEffects:
    """
    Weather-specific spatial effects combining all processors.

    The chain applied by ``process_complete_chain`` is: distance filtering,
    then occlusion filtering, then compression, then spatial reverb.
    """
    
    def __init__(self, sample_rate: int = 48000):
        self.sample_rate = sample_rate
        self.reverb = SpatialReverb(sample_rate)
        self.compressor = SpatialCompressor(sample_rate)
        self.filter = SpatialFilter()
    
    def create_weather_environment(self,
                                 weather_type: str,
                                 intensity: float,
                                 size: float = 0.5) -> Dict:
        """
        Create a complete weather-based spatial environment.
        
        Unknown ``weather_type`` values use the "clear" configuration.
        
        Args:
            weather_type: Type of weather
            intensity: Weather intensity (0-1)
            size: Environment size (0-1)
        
        Returns:
            Complete effect chain parameters: "reverb", "compressor", "filter"
            (distance_scale, occlusion), "spatial" (width; not consumed by
            ``process_complete_chain``), plus weather_type, intensity and size.
        """
        # Define spatial characteristics for each weather type
        weather_configs = {
            "clear": {
                "reverb_character": "plate",
                "reverb_duration": 1.5 + size * 1.0,
                "reverb_position": (0, 5, 0),  # Above listener
                "compression_distance_factor": 0.5,
                "filter_distance_scale": 1.0,
                "spatial_width": 0.6
            },
            "cloudy": {
                "reverb_character": "hall",
                "reverb_duration": 2.0 + size * 1.5,
                "reverb_position": (0, 8, 0),  # Higher up
                "compression_distance_factor": 0.7,
                "filter_distance_scale": 1.2,
                "spatial_width": 0.7
            },
            "rain": {
                "reverb_character": "chamber",
                "reverb_duration": 1.8 + size * 1.2,
                "reverb_position": (0, 2, 0),  # Closer
                "compression_distance_factor": 0.8,
                "filter_distance_scale": 1.5,
                "spatial_width": 0.4
            },
            "storm": {
                "reverb_character": "space",
                "reverb_duration": 3.0 + size * 2.0,
                "reverb_position": (0, 10, 0),  # Very high
                "compression_distance_factor": 1.0,
                "filter_distance_scale": 2.0,
                "spatial_width": 0.9
            },
            "snow": {
                "reverb_character": "ambient",
                "reverb_duration": 2.5 + size * 2.0,
                "reverb_position": (0, 6, 0),
                "compression_distance_factor": 0.3,
                "filter_distance_scale": 0.8,  # Less distance filtering (absorption)
                "spatial_width": 0.8
            },
            "fog": {
                "reverb_character": "ambient",
                "reverb_duration": 3.0 + size * 2.5,
                "reverb_position": (0, 1, 0),  # Very close
                "compression_distance_factor": 0.5,
                "filter_distance_scale": 3.0,  # Heavy distance filtering
                "spatial_width": 0.3
            }
        }
        
        config = weather_configs.get(weather_type, weather_configs["clear"])
        
        # Create reverb
        reverb_params = self.reverb.create_spatial_reverb(
            character=config["reverb_character"],
            duration=config["reverb_duration"],
            position=config["reverb_position"],
            weather_type=weather_type,
            intensity=intensity
        )
        
        # Create compression settings
        compressor_params = self.compressor.create_spatial_compressor(
            threshold=-20 - intensity * 4,  # Lower threshold for intense weather
            ratio=3.0 + intensity * 2.0,    # Higher ratio for intense weather
            attack=0.005,
            release=0.3,
            position=(0, 0, 5),  # Default position
            distance_factor=config["compression_distance_factor"]
        )
        
        # Set filter parameters
        filter_params = {
            "distance_scale": config["filter_distance_scale"] * (1.0 + intensity * 0.5),
            "occlusion": 0.0  # No occlusion by default
        }
        
        # Set spatial width
        spatial_params = {
            "width": config["spatial_width"] * (0.8 + intensity * 0.2),
            "weather_type": weather_type,
            "intensity": intensity
        }
        
        return {
            "reverb": reverb_params,
            "compressor": compressor_params,
            "filter": filter_params,
            "spatial": spatial_params,
            "weather_type": weather_type,
            "intensity": intensity,
            "size": size
        }
    
    def process_complete_chain(self,
                             audio: np.ndarray,
                             environment: Dict,
                             position: Tuple[float, float, float] = (0, 0, 0)) -> np.ndarray:
        """
        Process audio through complete weather effect chain.
        
        Args:
            audio: Input audio
            environment: Environment parameters from create_weather_environment
            position: Source position for distance calculations
        
        Returns:
            Processed audio
        """
        filter_params = environment["filter"]
        
        # Calculate distance for filtering
        distance = np.sqrt(sum(p**2 for p in position))
        
        # Apply distance filtering
        filtered = self.filter.apply_distance_filtering(
            audio, 
            distance * filter_params["distance_scale"],
            self.sample_rate
        )
        
        # Honour the environment's occlusion setting. The default of 0.0 returns
        # the signal unchanged; .get keeps older environment dicts working.
        filtered = self.filter.apply_occlusion_filtering(
            filtered,
            filter_params.get("occlusion", 0.0),
            self.sample_rate
        )
        
        # Apply compression with spatial awareness
        compressed = self.compressor.process_audio(
            filtered,
            environment["compressor"]
        )
        
        # Apply reverb (wetter for more intense weather)
        return self.reverb.process_audio(
            compressed,
            environment["reverb"],
            wet_dry=0.3 + environment["intensity"] * 0.2
        )


# Public API: the names exported by ``from <module> import *``.
__all__ = ['SpatialReverb', 'SpatialCompressor', 'SpatialFilter', 'WeatherEffects']