"""
Core SHAC Codec Module

This module contains the main SHACCodec class that ties together all the
functionality provided by the other modules.
"""

import numpy as np
import math
import time
from typing import Dict, List, Tuple, Optional, Union, Callable
import queue
import threading
import logging
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed

from .math_utils import AmbisonicNormalization, real_spherical_harmonic
from .encoders import encode_mono_source, encode_stereo_source, convert_to_cartesian, convert_to_spherical
from .processors import rotate_ambisonics, decode_to_speakers
from .binauralizer import binauralize_ambisonics, load_hrtf_database, apply_frequency_dependent_effects
from .utils import SourceAttributes, RoomAttributes, BinauralRendererConfig, HRTFInterpolationMethod, AmbisonicOrdering
from .io import SHACFileWriter, SHACFileReader

# Set up logger
logger = logging.getLogger(__name__)


class AudioMemoryPool:
    """
    Elegant memory management for low-allocation audio processing.

    Pre-allocates fixed-size float32 buffers grouped by purpose and hands out
    views into them, so hot paths can avoid allocating fresh arrays. Requests
    that are larger than the pooled buffers, or that arrive when every buffer
    of a type is busy, fall back to a newly allocated zero array.

    No locking is performed: callers sharing one pool across threads must
    synchronize access themselves.
    """

    def __init__(self, max_channels: int = 64, max_samples: int = 480000):  # 10 seconds at 48kHz
        """
        Initialize memory pools with pre-allocated buffers.

        Args:
            max_channels: Maximum number of ambisonic channels to support
            max_samples: Maximum samples per buffer (10 seconds at 48kHz by default)

        Note:
            Each pooled buffer is ``max_channels * max_samples * 4`` bytes
            (about 117 MiB with the defaults), and 18 are created.
        """
        self.max_channels = max_channels
        self.max_samples = max_samples

        # Pre-allocate buffer pools
        self.pools = {
            'temp_buffers': [np.zeros((max_channels, max_samples), dtype=np.float32) for _ in range(8)],
            'ambisonic_buffers': [np.zeros((max_channels, max_samples), dtype=np.float32) for _ in range(4)],
            'reflection_buffers': [np.zeros((max_channels, max_samples), dtype=np.float32) for _ in range(4)],
            'output_buffers': [np.zeros((max_channels, max_samples), dtype=np.float32) for _ in range(2)],
        }

        # Track which buffers are in use (parallel to self.pools)
        self.in_use = {pool_name: [False] * len(buffers) for pool_name, buffers in self.pools.items()}

        # Statistics
        self.stats = {
            'acquisitions': 0,
            'releases': 0,
            'cache_hits': 0,
            'cache_misses': 0
        }

        logger.info(f"AudioMemoryPool initialized: {max_channels} channels, {max_samples} samples per buffer")

    def acquire(self, buffer_type: str, channels: int, samples: int) -> np.ndarray:
        """
        Acquire a zeroed float32 buffer of shape (channels, samples).

        Args:
            buffer_type: Type of buffer ('temp_buffers', 'ambisonic_buffers', etc.)
            channels: Number of channels needed
            samples: Number of samples needed

        Returns:
            A zero-filled view into a pooled buffer, or a freshly allocated
            zero array when the request is oversized, the type is unknown, or
            the pool for that type is exhausted. Pass it to ``release`` when done.
        """
        self.stats['acquisitions'] += 1

        # Validate request
        if channels > self.max_channels or samples > self.max_samples:
            # Fallback: create new buffer for oversized requests
            self.stats['cache_misses'] += 1
            logger.warning(f"Buffer request exceeds pool size: {channels}ch x {samples}smp, creating new buffer")
            return np.zeros((channels, samples), dtype=np.float32)

        # Find available buffer in pool
        if buffer_type in self.pools:
            pool_usage = self.in_use[buffer_type]

            for i, busy in enumerate(pool_usage):
                if not busy:
                    pool_usage[i] = True
                    self.stats['cache_hits'] += 1

                    view = self.pools[buffer_type][i][:channels, :samples]
                    # Clear only the region the caller will see; zeroing the
                    # whole pooled buffer would touch far more memory than needed.
                    view.fill(0)
                    return view

        # Fallback: create new buffer if pool exhausted
        self.stats['cache_misses'] += 1
        logger.debug(f"Pool {buffer_type} exhausted, creating new buffer")
        return np.zeros((channels, samples), dtype=np.float32)

    @staticmethod
    def _owning_array(buffer: np.ndarray) -> Optional[np.ndarray]:
        """Return the ndarray at the end of *buffer*'s ``.base`` chain, or None if it is not a view."""
        owner = None
        base = getattr(buffer, 'base', None)
        while isinstance(base, np.ndarray):
            owner = base
            base = base.base
        return owner

    def release(self, buffer: np.ndarray) -> None:
        """
        Release a buffer back to the pool.

        Any view (or view of a view, e.g. a row or a reshape) of a pooled
        buffer frees that pooled buffer. Arrays that did not come from the
        pool are ignored. The buffer is not cleared here because ``acquire``
        clears the region it hands out.

        Args:
            buffer: Buffer previously returned by ``acquire`` (or a view of it)
        """
        self.stats['releases'] += 1

        owner = self._owning_array(buffer)
        if owner is None:
            # Not a view: a fallback allocation, left to the garbage collector
            return

        for pool_name, pool_buffers in self.pools.items():
            for i, pool_buffer in enumerate(pool_buffers):
                if owner is pool_buffer:
                    self.in_use[pool_name][i] = False
                    return

    def acquire_like(self, reference_array: np.ndarray, buffer_type: str = 'temp_buffers') -> np.ndarray:
        """
        Acquire buffer with same shape as reference array.

        Args:
            reference_array: 2-D array whose (channels, samples) shape is matched
            buffer_type: Type of buffer to acquire

        Returns:
            Buffer with same shape as reference
        """
        return self.acquire(buffer_type, reference_array.shape[0], reference_array.shape[1])

    def get_stats(self) -> Dict:
        """Get memory pool usage statistics, including hit rate and active buffer count."""
        cache_hit_rate = (self.stats['cache_hits'] /
                          max(1, self.stats['cache_hits'] + self.stats['cache_misses']))

        return {
            **self.stats,
            'cache_hit_rate': cache_hit_rate,
            'active_buffers': sum(sum(usage) for usage in self.in_use.values())
        }

    def reset_stats(self):
        """Reset statistics counters to zero."""
        self.stats = {key: 0 for key in self.stats.keys()}


# Process-wide memory pool; populated lazily by get_memory_pool()
_global_memory_pool = None


def get_memory_pool() -> AudioMemoryPool:
    """
    Return the shared AudioMemoryPool.

    The pool is constructed on the first call; every later call returns that
    same instance.
    """
    global _global_memory_pool
    pool = _global_memory_pool
    if pool is None:
        pool = AudioMemoryPool()
        _global_memory_pool = pool
    return pool


class ParallelEncoder:
    """
    Elegant parallel processing for multi-layer compositions.

    Encodes several audio layers concurrently on a ``ThreadPoolExecutor``.
    Because the workers are threads, how much they actually overlap depends on
    the encoding routines releasing the GIL while they run.
    """

    def __init__(self, max_workers: Optional[int] = None):
        """
        Initialize parallel encoder.

        Args:
            max_workers: Maximum number of worker threads
                (None = min(8, number of CPUs))
        """
        if max_workers is None:
            # Use reasonable number of cores, but don't overwhelm the system
            self.max_workers = min(8, multiprocessing.cpu_count())
        else:
            self.max_workers = max_workers

        self.stats = {
            'parallel_jobs': 0,
            'sequential_fallbacks': 0,
            'total_layers_processed': 0,
            'total_speedup': 0.0
        }

        logger.info(f"ParallelEncoder initialized with {self.max_workers} workers")

    def encode_layers_parallel(self, source_dict: Dict[str, Dict],
                               codec_settings: Dict) -> Dict[str, np.ndarray]:
        """
        Encode multiple layers concurrently.

        A layer whose worker raises is logged and re-encoded once on the calling
        thread after the pool has drained; if that retry raises too, the
        exception propagates. If the thread-pool machinery itself fails, every
        layer is encoded sequentially instead.

        Args:
            source_dict: Dictionary of {layer_name: {'audio_data': ..., 'metadata': ...}}
            codec_settings: Codec configuration; 'sample_rate' is required and
                'order' defaults to 1

        Returns:
            Dictionary of {layer_name: encoded_ambisonic_signals}, with keys in
            the same order as ``source_dict``
        """
        if len(source_dict) <= 1:
            # Single layer - no benefit from parallel processing
            return self._encode_layers_sequential(source_dict, codec_settings)

        start_time = time.perf_counter()
        results = {}
        layer_seconds = []   # encode time of each layer, measured inside its worker
        failed_layers = []

        try:
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                future_to_layer = {
                    executor.submit(self._timed_encode,
                                    self._make_task(layer_name, layer_data, codec_settings)): layer_name
                    for layer_name, layer_data in source_dict.items()
                }

                # Collect results as they complete
                for future in as_completed(future_to_layer):
                    layer_name = future_to_layer[future]
                    try:
                        encoded_signals, seconds = future.result()
                    except Exception as e:
                        logger.error(f"Failed to encode layer {layer_name}: {e}")
                        failed_layers.append(layer_name)
                        continue
                    results[layer_name] = encoded_signals
                    layer_seconds.append(seconds)
                    logger.debug(f"Completed encoding for layer: {layer_name}")
        except Exception as e:
            logger.warning(f"Parallel encoding failed, falling back to sequential: {e}")
            results = self._encode_layers_sequential(source_dict, codec_settings)
            self.stats['sequential_fallbacks'] += 1
            return results

        # Best-effort retry of failed layers on the calling thread. A second
        # failure propagates directly instead of re-encoding every layer again.
        for layer_name in failed_layers:
            results[layer_name] = self._encode_single_layer(
                self._make_task(layer_name, source_dict[layer_name], codec_settings))
            self.stats['sequential_fallbacks'] += 1

        # Update statistics
        elapsed_time = time.perf_counter() - start_time
        speedup = self._estimate_speedup(layer_seconds, elapsed_time)

        self.stats['parallel_jobs'] += 1
        self.stats['total_layers_processed'] += len(source_dict)
        self.stats['total_speedup'] += speedup

        logger.info(f"Parallel encoding completed: {len(source_dict)} layers, "
                    f"{speedup:.1f}x speedup, {elapsed_time:.2f}s total")

        return {layer_name: results[layer_name] for layer_name in source_dict}

    @staticmethod
    def _make_task(layer_name: str, layer_data: Dict, codec_settings: Dict) -> Dict:
        """Build the task dictionary consumed by _encode_single_layer."""
        return {
            'layer_name': layer_name,
            'audio_data': layer_data['audio_data'],
            'metadata': layer_data['metadata'],
            'codec_settings': codec_settings
        }

    def _timed_encode(self, task: Dict) -> Tuple[np.ndarray, float]:
        """Run _encode_single_layer and also return its duration in seconds."""
        started = time.perf_counter()
        encoded = self._encode_single_layer(task)
        return encoded, time.perf_counter() - started

    @staticmethod
    def _estimate_speedup(layer_seconds: List[float], wall_seconds: float) -> float:
        """
        Estimate speedup as (sum of per-layer encode times) / wall-clock time.

        This is an estimate: contention between threads inflates the measured
        per-layer times. Returns 1.0 when either quantity is not positive.
        """
        sequential_seconds = sum(layer_seconds)
        if wall_seconds <= 0 or sequential_seconds <= 0:
            return 1.0
        return sequential_seconds / wall_seconds

    def _encode_single_layer(self, task: Dict) -> np.ndarray:
        """
        Encode a single layer (designed for parallel execution).

        1-D audio is encoded as a mono source at metadata['position'] (default
        origin). Audio shaped (2, n) is encoded as two mono sources at 0.7x
        gain, the second with its azimuth offset by 0.1. Any other shape is
        treated as already-encoded ambisonics and only scaled by the gain.

        Args:
            task: Dictionary with layer_name, audio_data, metadata, codec_settings

        Returns:
            Encoded ambisonic signals
        """
        layer_name = task['layer_name']
        audio_data = task['audio_data']
        metadata = task['metadata']
        codec_settings = task['codec_settings']

        gain = metadata.get('gain', 1.0)
        # Not used by the encoders below, but kept so a missing 'sample_rate'
        # still raises KeyError as a configuration error.
        sample_rate = codec_settings['sample_rate']  # noqa: F841
        order = codec_settings.get('order', 1)

        is_mono = len(audio_data.shape) == 1
        is_stereo = len(audio_data.shape) == 2 and audio_data.shape[0] == 2

        if is_mono or is_stereo:
            # Only source-based encodings need the spherical position
            spherical_pos = convert_to_spherical(metadata.get('position', [0, 0, 0]))
            azimuth, elevation, distance = spherical_pos[0], spherical_pos[1], spherical_pos[2]

            if is_mono:
                encoded = encode_mono_source(audio_data, azimuth, elevation, distance, order, gain)
            else:
                # Stereo: encode each channel separately (reduced level, slight
                # azimuth offset for the right channel) and sum
                left_encoded = encode_mono_source(
                    audio_data[0], azimuth, elevation, distance, order, gain * 0.7
                )
                right_encoded = encode_mono_source(
                    audio_data[1], azimuth + 0.1, elevation, distance, order, gain * 0.7
                )
                encoded = left_encoded + right_encoded
        else:
            # Multi-channel - assume already ambisonic
            encoded = audio_data * gain

        logger.debug(f"Encoded layer {layer_name}: {encoded.shape}")
        return encoded

    def _encode_layers_sequential(self, source_dict: Dict[str, Dict],
                                  codec_settings: Dict) -> Dict[str, np.ndarray]:
        """
        Encode every layer on the calling thread (single layers or error recovery).

        Args:
            source_dict: Dictionary of {layer_name: {audio_data, metadata}}
            codec_settings: Codec configuration

        Returns:
            Dictionary of {layer_name: encoded_ambisonic_signals}
        """
        return {
            layer_name: self._encode_single_layer(self._make_task(layer_name, layer_data, codec_settings))
            for layer_name, layer_data in source_dict.items()
        }

    def get_stats(self) -> Dict:
        """Get parallel processing statistics, including average speedup and efficiency."""
        avg_speedup = (self.stats['total_speedup'] /
                       max(1, self.stats['parallel_jobs']))

        return {
            **self.stats,
            'average_speedup': avg_speedup,
            'parallel_efficiency': (self.stats['parallel_jobs'] /
                                    max(1, self.stats['parallel_jobs'] + self.stats['sequential_fallbacks']))
        }


# Process-wide ParallelEncoder; populated lazily by get_parallel_encoder()
_global_parallel_encoder = None


def get_parallel_encoder() -> ParallelEncoder:
    """
    Return the shared ParallelEncoder.

    A default-configured encoder is built on the first call and reused by all
    subsequent calls.
    """
    global _global_parallel_encoder
    encoder = _global_parallel_encoder
    if encoder is None:
        encoder = ParallelEncoder()
        _global_parallel_encoder = encoder
    return encoder


class StreamingEncoder:
    """
    Elegant streaming encoder for arbitrarily long compositions.

    Audio is processed in chunks of ``chunk_size`` samples that start every
    ``hop_size`` samples, so consecutive chunks share ``overlap_samples``
    samples. Inside that shared region the earlier chunk is faded out and the
    later one faded in with complementary raised-cosine ramps (they sum to 1),
    and the two are overlap-added. For an encoder that processes each sample
    independently of its neighbours, the result matches encoding the whole
    signal in one pass.

    Per-chunk working memory is bounded by ``chunk_size``. In accumulation mode
    (no ``output_callback``) the full-length encoded signals are still
    returned, so those result arrays grow with the composition length.
    """

    def __init__(self, chunk_size: int = 8192, overlap_ratio: float = 0.25):
        """
        Initialize streaming encoder.

        Args:
            chunk_size: Samples per chunk (~170ms at 48kHz)
            overlap_ratio: Fraction of each chunk shared with the next one
                (0.25 = 25% overlap). Must lie in [0, 0.5]; beyond 0.5 the
                fade-in and fade-out ramps of a single chunk would collide.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap_ratio``
                is outside [0, 0.5].
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0.0 <= overlap_ratio <= 0.5:
            raise ValueError(f"overlap_ratio must be within [0, 0.5], got {overlap_ratio}")

        self.chunk_size = chunk_size
        self.overlap_samples = int(chunk_size * overlap_ratio)
        self.hop_size = chunk_size - self.overlap_samples

        # Cross-fade window applied to each encoded chunk before overlap-add
        self.window = self._create_overlap_window()

        self.stats = {
            'chunks_processed': 0,
            'total_samples_processed': 0,
            'memory_peak_mb': 0,
            'processing_time': 0.0
        }

        logger.info(f"StreamingEncoder initialized: {chunk_size} samples/chunk, "
                    f"{overlap_ratio*100:.0f}% overlap")

    def _create_overlap_window(self) -> np.ndarray:
        """
        Create the per-chunk cross-fade window.

        The first ``overlap_samples`` values rise as sin^2 and the last
        ``overlap_samples`` values fall as 1 - sin^2 (= cos^2) over the same
        phases, so a chunk's fade-out plus the next chunk's fade-in is exactly
        1 at every overlapped sample. The middle of the window is 1.
        """
        window = np.ones(self.chunk_size)
        n_overlap = self.overlap_samples

        if n_overlap > 0:
            # Sample the ramp at bin centres so neither end is exactly 0 or 1
            phase = (np.arange(n_overlap) + 0.5) / n_overlap * (np.pi / 2)
            fade_in = np.sin(phase) ** 2
            window[:n_overlap] = fade_in
            window[-n_overlap:] = 1.0 - fade_in

        return window

    @staticmethod
    def _num_samples(audio_data: np.ndarray) -> int:
        """Sample count of 1-D (samples,) or multi-channel (channels, samples) audio."""
        return len(audio_data) if len(audio_data.shape) == 1 else audio_data.shape[1]

    def _extract_chunk(self, audio_data: np.ndarray, start: int, end: int) -> np.ndarray:
        """Slice samples [start, end) along the last axis, zero-padding to chunk_size."""
        if len(audio_data.shape) == 1:
            chunk = audio_data[start:end]
        else:
            chunk = audio_data[:, start:end]

        if chunk.shape[-1] < self.chunk_size:
            padded = np.zeros(chunk.shape[:-1] + (self.chunk_size,))
            padded[..., :chunk.shape[-1]] = chunk
            chunk = padded

        return chunk

    @staticmethod
    def _memory_probe():
        """Return a psutil.Process for RSS tracking, or None if psutil is unavailable."""
        try:
            import psutil  # optional third-party dependency
        except ImportError:
            return None
        return psutil.Process()

    def encode_streaming(self, source_dict: Dict[str, Dict],
                         codec_settings: Dict,
                         output_callback: Optional[Callable] = None) -> Dict[str, np.ndarray]:
        """
        Encode large compositions using streaming processing.

        Args:
            source_dict: Dictionary of {layer_name: {audio_data, metadata}}
            codec_settings: Codec configuration
            output_callback: Optional callable invoked as
                ``output_callback(layer_name, start_sample, encoded_chunk)``
                for consecutive, non-overlapping output chunks. For short
                compositions (<= 2 chunks) it is called once per layer with
                ``start_sample == 0`` and the whole encoded layer.

        Returns:
            Dictionary of {layer_name: complete_encoded_signals} when no
            callback is given, otherwise None. For streamed compositions the
            arrays are float32 and span the longest layer.
        """
        start_time = time.time()

        if not source_dict:
            return {} if output_callback is None else None

        # Find maximum length across all sources
        max_samples = max(self._num_samples(layer_data['audio_data'])
                          for layer_data in source_dict.values())

        # If composition is short, use regular encoding
        if max_samples <= self.chunk_size * 2:
            logger.info("Composition is short, using regular encoding")
            results = get_parallel_encoder().encode_layers_parallel(source_dict, codec_settings)
            if output_callback is None:
                return results
            for layer_name, encoded in results.items():
                output_callback(layer_name, 0, encoded)
            return None

        # Result arrays are owned by the caller, so they are allocated directly
        # rather than borrowed (and never returned) from the shared memory pool.
        complete_results = {} if output_callback is None else None

        # Tail (fade-out region) of each layer's previous chunk
        prev_overlap = {}

        # The first chunk has no predecessor to cross-fade with, so its start
        # must not be faded in.
        first_window = self.window.copy()
        first_window[:self.overlap_samples] = 1.0

        process = self._memory_probe()
        chunks_this_call = 0
        samples_this_call = 0

        for chunk_start in range(0, max_samples, self.hop_size):
            chunk_end = min(chunk_start + self.chunk_size, max_samples)
            # Samples this chunk contributes to the output (the rest is overlap)
            output_len = min(self.hop_size, chunk_end - chunk_start)
            # Another chunk follows whenever the next hop still starts inside the signal
            has_next_chunk = chunk_start + self.hop_size < max_samples
            window = first_window if chunk_start == 0 else self.window

            chunk_sources = {
                layer_name: {
                    'audio_data': self._extract_chunk(layer_data['audio_data'], chunk_start, chunk_end),
                    'metadata': layer_data['metadata']
                }
                for layer_name, layer_data in source_dict.items()
            }

            # Encode this chunk in parallel
            chunk_results = get_parallel_encoder().encode_layers_parallel(
                chunk_sources, codec_settings
            )

            # Apply windowing and overlap-add
            for layer_name, encoded_chunk in chunk_results.items():
                windowed_chunk = encoded_chunk * window[np.newaxis, :]

                previous_tail = prev_overlap.pop(layer_name, None)
                if previous_tail is not None:
                    windowed_chunk[:, :previous_tail.shape[1]] += previous_tail

                if has_next_chunk:
                    prev_overlap[layer_name] = windowed_chunk[:, self.hop_size:].copy()

                output_chunk = windowed_chunk[:, :output_len]
                if output_callback is not None:
                    # Streaming mode - send chunk to callback
                    output_callback(layer_name, chunk_start, output_chunk)
                else:
                    # Accumulation mode - store in complete result
                    if layer_name not in complete_results:
                        complete_results[layer_name] = np.zeros(
                            (output_chunk.shape[0], max_samples), dtype=np.float32
                        )
                    complete_results[layer_name][:, chunk_start:chunk_start + output_len] = output_chunk

            # Update statistics (count each output sample once, ignoring overlap)
            self.stats['chunks_processed'] += 1
            self.stats['total_samples_processed'] += output_len
            chunks_this_call += 1
            samples_this_call += output_len

            # Memory usage tracking (skipped when psutil is not installed)
            if process is not None:
                memory_mb = process.memory_info().rss / 1024 / 1024
                self.stats['memory_peak_mb'] = max(self.stats['memory_peak_mb'], memory_mb)

        # Accumulate, consistent with the other cumulative counters
        elapsed = time.time() - start_time
        self.stats['processing_time'] += elapsed

        logger.info(f"Streaming encoding completed: {chunks_this_call} chunks, "
                    f"{samples_this_call} samples, "
                    f"{self.stats['memory_peak_mb']:.1f}MB peak memory, "
                    f"{elapsed:.2f}s total")

        return complete_results

    def estimate_memory_usage(self, source_dict: Dict, codec_settings: Dict) -> Dict[str, float]:
        """
        Estimate memory usage for streaming vs regular encoding.

        Only the encoded float32 signals are counted (input audio and
        intermediate copies are not).

        Returns:
            Dictionary with memory estimates in MB, the relative reduction and
            the duration of the longest layer in seconds
        """
        total_samples = max(
            (self._num_samples(layer_data['audio_data']) for layer_data in source_dict.values()),
            default=0
        )

        n_layers = len(source_dict)
        n_channels = (codec_settings.get('order', 1) + 1) ** 2

        # Regular encoding memory (everything in memory at once)
        regular_mb = (total_samples * n_layers * n_channels * 4) / (1024 * 1024)  # 4 bytes per float32

        # Streaming encoding memory (only chunks in memory)
        streaming_mb = (self.chunk_size * n_layers * n_channels * 4) / (1024 * 1024)

        return {
            'regular_encoding_mb': regular_mb,
            'streaming_encoding_mb': streaming_mb,
            'memory_reduction': (regular_mb - streaming_mb) / regular_mb if regular_mb > 0 else 0,
            'total_duration_seconds': total_samples / codec_settings.get('sample_rate', 48000)
        }

    def get_stats(self) -> Dict:
        """Get cumulative streaming encoder statistics."""
        avg_chunk_time = (self.stats['processing_time'] /
                          max(1, self.stats['chunks_processed']))

        return {
            **self.stats,
            'average_chunk_time_ms': avg_chunk_time * 1000,
            'samples_per_second': (self.stats['total_samples_processed'] /
                                   max(0.001, self.stats['processing_time']))
        }


# Process-wide StreamingEncoder; populated lazily by get_streaming_encoder()
_global_streaming_encoder = None


def get_streaming_encoder() -> StreamingEncoder:
    """
    Return the shared StreamingEncoder.

    A default-configured encoder is built on the first call and reused by all
    subsequent calls.
    """
    global _global_streaming_encoder
    encoder = _global_streaming_encoder
    if encoder is None:
        encoder = StreamingEncoder()
        _global_streaming_encoder = encoder
    return encoder


def apply_early_reflections(ambi_signals: np.ndarray, source_position: Tuple[float, float, float],
                         room_dimensions: Tuple[float, float, float], reflection_coefficients: Dict[str, float],
                         sample_rate: int, max_order: int = 1) -> np.ndarray:
    """
    Add first-order shoebox early reflections to ambisonic signals.

    Six image sources are built by mirroring the source across the planes of
    a room spanning [0, width] x [0, height] x [0, length]. For each image the
    W channel (channel 0) is delayed by its propagation time, scaled by the
    surface coefficient divided by the path length (path length clamped to at
    least 1 m), weighted by the spherical-harmonic coefficients returned by
    ``compute_all_spherical_harmonics`` for the image direction, and mixed into
    the channels. If the result peaks above 0.99 it is scaled down to 0.99.

    Distances and directions are measured from the coordinate origin, i.e. the
    listener is implicitly at (0, 0, 0).

    Args:
        ambi_signals: Ambisonic signals, shape (n_channels, n_samples)
        source_position: Source position (x, y, z) in meters
        room_dimensions: Room dimensions (width, height, length) in meters,
            along x, y and z respectively
        reflection_coefficients: Coefficient per surface ('left', 'right',
            'floor', 'ceiling', 'front', 'back'); missing surfaces use 0.5
        sample_rate: Sample rate in Hz
        max_order: Reflections are skipped when < 1. Any value >= 1 currently
            produces only the six first-order reflections.

    Returns:
        A new float32 array of shape (n_channels, n_samples) with reflections
        added, or ``ambi_signals`` itself when ``max_order < 1``.
    """
    if max_order < 1:
        return ambi_signals

    # Imported once per call instead of once per reflection
    from .math_utils import compute_all_spherical_harmonics

    n_channels, n_samples = ambi_signals.shape
    sh_order = int(math.sqrt(n_channels)) - 1
    sx, sy, sz = source_position
    width, height, length = room_dimensions
    c = 343.0  # Speed of sound in m/s

    # Image source positions: the source mirrored across each boundary plane
    reflections = [
        ((-sx, sy, sz), 'left'),                 # Left wall (x = 0)
        ((2 * width - sx, sy, sz), 'right'),     # Right wall (x = width)
        ((sx, -sy, sz), 'floor'),                # Floor (y = 0)
        ((sx, 2 * height - sy, sz), 'ceiling'),  # Ceiling (y = height)
        ((sx, sy, -sz), 'front'),                # Front wall (z = 0)
        ((sx, sy, 2 * length - sz), 'back')      # Back wall (z = length)
    ]

    w_channel = ambi_signals[0]  # Reflections are derived from the W channel only

    memory_pool = get_memory_pool()
    output_signals = memory_pool.acquire('reflection_buffers', n_channels, n_samples)
    try:
        output_signals[:] = ambi_signals  # Start from the dry signal

        for (img_x, img_y, img_z), surface in reflections:
            distance = math.sqrt(img_x**2 + img_y**2 + img_z**2)
            delay_samples = int(distance / c * sample_rate)
            if delay_samples >= n_samples:
                continue  # Reflection arrives after the end of the buffer

            # Combined attenuation: surface reflection + distance
            attenuation = reflection_coefficients.get(surface, 0.5) / max(1.0, distance)

            # Direction of the image source
            azimuth = math.atan2(img_x, img_z)
            elevation = math.atan2(img_y, math.sqrt(img_x**2 + img_z**2))
            sh_coeffs = compute_all_spherical_harmonics(sh_order, azimuth, elevation)

            delayed_block = memory_pool.acquire('temp_buffers', 1, n_samples)
            try:
                delayed_w = delayed_block[0]
                delayed_w[delay_samples:] = w_channel[:n_samples - delay_samples]

                for ch in range(min(len(sh_coeffs), n_channels)):
                    output_signals[ch] += delayed_w * sh_coeffs[ch] * attenuation
            finally:
                # Always hand the scratch buffer back, even if encoding failed
                memory_pool.release(delayed_block)

        # Prevent clipping
        max_val = np.max(np.abs(output_signals)) if output_signals.size else 0.0
        if max_val > 0.99:
            output_signals *= (0.99 / max_val)

        # Copy out before the pooled buffer is released below
        return output_signals.copy()
    finally:
        memory_pool.release(output_signals)


def apply_diffuse_reverberation(ambi_signals: np.ndarray, rt60: float, sample_rate: int,
                              room_volume: float, early_reflection_delay: int = 0) -> np.ndarray:
    """
    Add spatial reverberation to ambisonic signals.

    Uses ``SpatialReverbEngine`` from ``instruments.effects.spatial_reverb``
    when it can be imported. The room preset is chosen from ``rt60``:
    < 0.6 s studio, < 1.0 s live_room, < 2.0 s chamber, < 3.5 s hall,
    otherwise cathedral. When ``room_volume`` is positive, the preset is
    copied and given estimated dimensions and the requested RT60. If an
    ImportError occurs, ``_apply_fallback_reverb`` is used instead.

    Args:
        ambi_signals: Ambisonic signals with direct path and early reflections
        rt60: Reverberation time (time for level to drop by 60dB) in seconds
        sample_rate: Sample rate in Hz
        room_volume: Room volume in cubic meters (used for room size estimation)
        early_reflection_delay: Currently unused; kept for API compatibility

    Returns:
        Ambisonic signals with spatial reverberation, shape (n_channels, n_samples)
    """
    try:
        # Import the professional spatial reverb engine
        from instruments.effects.spatial_reverb import SpatialReverbEngine

        # Determine room type based on RT60
        if rt60 < 0.6:
            room_preset = "studio"
        elif rt60 < 1.0:
            room_preset = "live_room"
        elif rt60 < 2.0:
            room_preset = "chamber"
        elif rt60 < 3.5:
            room_preset = "hall"
        else:
            room_preset = "cathedral"

        # One engine serves both the preset lookup and the processing
        reverb_engine = SpatialReverbEngine(sample_rate)

        custom_room = None
        if room_volume > 0:
            # Estimate dimensions from the volume: fixed 3 m ceiling and a
            # floor 1.2x longer than wide, so width * 3 * length == room_volume.
            assumed_height = 3.0
            aspect_ratio = 1.2
            floor_area = room_volume / assumed_height
            width = math.sqrt(floor_area / aspect_ratio)
            length = width * aspect_ratio

            custom_room = reverb_engine.ROOM_PRESETS[room_preset].copy()
            custom_room["dimensions"] = (width, assumed_height, length)
            custom_room["rt60"] = rt60

        return reverb_engine.apply_spatial_reverb(
            ambi_signals,
            room_preset=room_preset,
            wet_level=0.25,  # Conservative wet level for backward compatibility
            custom_room=custom_room
        )

    except ImportError:
        # Intentional graceful degradation when the professional engine (or
        # one of its dependencies) is unavailable
        logger.warning("Professional spatial reverb not available, using fallback implementation")
        return _apply_fallback_reverb(ambi_signals, rt60, sample_rate)


def _apply_fallback_reverb(ambi_signals: np.ndarray, rt60: float, sample_rate: int) -> np.ndarray:
    """
    Fallback reverb implementation - basic but functional.

    Used when the professional reverb library is unavailable. Every channel is
    convolved (causally) with a reproducible, exponentially decaying noise tail
    and the wet result is ADDED to the dry signal. The W channel (0) gets a
    stronger tail; higher channels are attenuated by their ACN order
    ``floor(sqrt(ch))``.

    Args:
        ambi_signals: Ambisonic signals, shape (n_channels, n_samples).
        rt60: Reverberation time in seconds; the tail is ``rt60 * sample_rate``
            samples long and decays by 60 dB over that length.
        sample_rate: Sample rate in Hz.

    Returns:
        Dry + reverberant signals with the same shape as ``ambi_signals``,
        peak-limited to 0.99. If the tail length is not positive or the input
        is empty, ``ambi_signals`` itself is returned unchanged.
    """
    # Imported locally so this fallback does not rely on a module-level import.
    from scipy import signal

    n_channels, n_samples = ambi_signals.shape
    reverb_samples = int(rt60 * sample_rate)

    if reverb_samples <= 0 or ambi_signals.size == 0:
        return ambi_signals

    # Amplitude envelope reaching -60 dB (ln(1000) ~= 6.91) at the end of the tail
    decay = np.exp(-6.91 * np.arange(reverb_samples) / reverb_samples)

    # Start from a floating-point copy of the dry signal; the wet tail is added on top.
    output_signals = ambi_signals.astype(np.result_type(ambi_signals.dtype, np.float32))

    for ch in range(n_channels):
        # Reproducible noise per channel from a private generator:
        # RandomState(seed).randn(...) yields the same values as
        # np.random.seed(seed); np.random.randn(...) without touching global state.
        noise = np.random.RandomState(ch).randn(reverb_samples)

        if ch == 0:
            # W (omnidirectional) channel gets a stronger presence
            reverb_tail = noise * decay * 0.2
        else:
            # Channel-dependent decay by ambisonic order (ACN: order = floor(sqrt(ch)))
            order = math.isqrt(ch)
            reverb_tail = noise * decay * math.exp(-0.5 * order) * 0.1

        # Causal convolution: keep the first n_samples of the full result so the
        # reverb starts at each input sample rather than before it.
        # signal.convolve picks direct or FFT convolution automatically.
        wet = signal.convolve(ambi_signals[ch], reverb_tail, mode='full')[:n_samples]
        output_signals[ch] += wet

    # Peak-limit to avoid clipping
    max_val = np.max(np.abs(output_signals))
    if max_val > 0.99:
        output_signals *= (0.99 / max_val)

    return output_signals


class SHACCodec:
    """
    Spherical Harmonic Audio Codec (SHAC) main class.
    
    This class provides a complete API for encoding, processing, and decoding
    spatial audio using spherical harmonics.
    """
    
    def __init__(self, order: int = 3, sample_rate: int = 48000,
                 normalization: AmbisonicNormalization = AmbisonicNormalization.SN3D,
                 channel_ordering: AmbisonicOrdering = AmbisonicOrdering.ACN):
        """
        Create an empty codec for a given ambisonic order and sample rate.

        Args:
            order: Ambisonic order; the codec carries (order + 1) ** 2 channels
            sample_rate: Audio sample rate in Hz
            normalization: Spherical harmonic normalization convention
            channel_ordering: Channel ordering convention
        """
        # Format description
        self.order = order
        self.sample_rate = sample_rate
        self.normalization = normalization
        self.channel_ordering = channel_ordering
        self.n_channels = (order + 1) * (order + 1)

        # Content: raw mono sources, encoded layers, file-facing metadata and
        # runtime-only state (muted, gain, ...), all keyed by source/layer id
        self.sources = {}
        self.layers = {}
        self.layer_metadata = {}
        self._internal_metadata = {}

        # Optional rendering configuration, filled in by the set_* methods
        self.hrtf_database = None
        self.binaural_renderer = None
        self.room = None

        # Listener head orientation as (yaw, pitch, roll)
        self.listener_orientation = (0.0, 0.0, 0.0)

        # Real-time processing state
        self.frame_size = 1024
        self.processing_queue = queue.Queue()
        self.output_queue = queue.Queue()
        self.processing_thread = None
        self.running = False

        logger.info(f"SHAC Codec initialized: order {order}, {self.n_channels} channels")
    
    def add_mono_source(self, source_id: str, audio: np.ndarray, position: Tuple[float, float, float],
                        attributes: Optional[SourceAttributes] = None) -> None:
        """
        Register a mono source and encode it into an ambisonic layer.

        The raw audio/position/attributes go to ``self.sources``, the encoded
        signals to ``self.layers``, the file-facing metadata (Cartesian position
        and type) to ``self.layer_metadata`` and the runtime state (mute/gain)
        to ``self._internal_metadata``.

        Args:
            source_id: Unique identifier for the source
            audio: Mono audio signal, shape (n_samples,)
            position: (azimuth, elevation, distance) in radians and meters
            attributes: Optional source attributes; when omitted, the source
                record gets ``SourceAttributes(position)`` while the internal
                state keeps ``None``

        Raises:
            ValueError: If ``audio`` is not 1-dimensional.
        """
        if audio.ndim != 1:
            raise ValueError("Audio must be mono (1-dimensional)")

        source_attributes = SourceAttributes(position) if attributes is None else attributes
        self.sources[source_id] = {
            'audio': audio,
            'position': position,
            'attributes': source_attributes,
        }

        self.layers[source_id] = encode_mono_source(audio, position, self.order, self.normalization)

        # SHAC file metadata stores the position as plain Cartesian floats
        xyz = convert_to_cartesian(position)
        self.layer_metadata[source_id] = {
            'position': [float(xyz[axis]) for axis in range(3)],
            'type': 'mono_source',
        }

        self._internal_metadata[source_id] = {
            'muted': False,
            'gain': 1.0,
            'current_gain': 1.0,
            'attributes': attributes,
        }
    
    
    def add_ambisonic_layer(self, layer_id: str, ambi_signals: np.ndarray, metadata: Optional[Dict] = None) -> None:
        """
        Add an already-encoded ambisonic layer to the codec.

        Runtime-only keys (``muted``, ``gain``, ``current_gain``) are split out of
        ``metadata`` into the internal state; everything else is kept as the
        layer's file metadata, with ``type`` defaulting to ``'ambisonic_layer'``.

        Args:
            layer_id: Unique identifier for the layer
            ambi_signals: Ambisonic signals, shape (n_channels, n_samples)
            metadata: Optional metadata for the layer

        Raises:
            ValueError: If ``ambi_signals`` is not 2-dimensional or its channel
                count differs from ``self.n_channels``.
        """
        # A 1-D array whose length happens to equal n_channels would otherwise
        # pass the channel check and fail later in process() at shape[1].
        if ambi_signals.ndim != 2:
            raise ValueError(
                f"Ambisonic signals must be 2-dimensional (n_channels, n_samples), "
                f"got shape {ambi_signals.shape}"
            )
        if ambi_signals.shape[0] != self.n_channels:
            raise ValueError(f"Expected {self.n_channels} channels, got {ambi_signals.shape[0]}")

        supplied = metadata or {}

        self.layers[layer_id] = ambi_signals

        # Only non-runtime fields are written to SHAC files
        runtime_keys = ('muted', 'gain', 'current_gain')
        clean_metadata = {k: v for k, v in supplied.items() if k not in runtime_keys}
        clean_metadata.setdefault('type', 'ambisonic_layer')
        self.layer_metadata[layer_id] = clean_metadata

        self._internal_metadata[layer_id] = {
            'muted': supplied.get('muted', False),
            'gain': supplied.get('gain', 1.0),
            'current_gain': supplied.get('current_gain', 1.0),
            'attributes': None,
        }
    
    def update_source_position(self, source_id: str, position: Tuple[float, float, float]) -> None:
        """
        Move an existing source and re-encode its ambisonic layer.

        Args:
            source_id: Source identifier
            position: New (azimuth, elevation, distance) in radians and meters

        Raises:
            ValueError: If ``source_id`` is not a known source.
        """
        if source_id not in self.sources:
            raise ValueError(f"Source not found: {source_id}")

        source_data = self.sources[source_id]
        source_data['position'] = position

        # File-facing metadata holds a Cartesian [x, y, z] list of Python floats
        # (the format add_mono_source writes), not the spherical input tuple.
        if source_id in self.layer_metadata:
            cartesian_pos = convert_to_cartesian(position)
            self.layer_metadata[source_id]['position'] = [
                float(cartesian_pos[0]), float(cartesian_pos[1]), float(cartesian_pos[2])
            ]

        # Re-encode from the stored raw audio at the new position
        if 'audio' in source_data:
            # Mono source
            self.layers[source_id] = encode_mono_source(
                source_data['audio'], position, self.order, self.normalization
            )
        elif 'left_audio' in source_data and 'right_audio' in source_data:
            # Stereo source; 0.35 is used when no width was stored
            width = source_data.get('width', 0.35)
            self.layers[source_id] = encode_stereo_source(
                source_data['left_audio'], source_data['right_audio'],
                position, width, self.order, self.normalization
            )
    
    def update_listener_orientation(self, yaw: float, pitch: float, roll: float) -> None:
        """
        Store a new listener head orientation.

        Args:
            yaw: Rotation around vertical axis (positive = left) in radians
            pitch: Rotation around side axis (positive = up) in radians
            roll: Rotation around front axis (positive = tilt right) in radians
        """
        orientation = (yaw, pitch, roll)
        self.listener_orientation = orientation
    
    def set_source_gain(self, source_id: str, gain: float) -> None:
        """
        Record a new target gain for a source or layer.

        The value is stored under the layer's internal ``'gain'`` entry; unknown
        ids are silently ignored.

        Args:
            source_id: Source or layer identifier
            gain: Gain value (1.0 = unity gain)
        """
        layer_state = self._internal_metadata.get(source_id)
        if layer_state is not None:
            layer_state['gain'] = gain
    
    def mute_source(self, source_id: str, muted: bool = True) -> None:
        """
        Set the mute flag of a source or layer; unknown ids are ignored.

        Args:
            source_id: Source or layer identifier
            muted: True to mute, False to unmute
        """
        try:
            layer_state = self._internal_metadata[source_id]
        except KeyError:
            return
        layer_state['muted'] = muted
    
    def set_room_model(self, room_dimensions: Tuple[float, float, float],
                       reflection_coefficients: Dict[str, float], rt60: float) -> None:
        """
        Configure the shoebox room used for reflections and reverberation.

        Args:
            room_dimensions: (width, height, length) in meters
            reflection_coefficients: Coefficients for each surface
            rt60: Reverberation time in seconds
        """
        width, height, length = room_dimensions
        self.room = {
            'dimensions': room_dimensions,
            'reflection_coefficients': reflection_coefficients,
            'rt60': rt60,
            # Volume in cubic meters, consumed by the diffuse reverb stage
            'volume': width * height * length,
        }
    
    def set_binaural_renderer(self, hrtf_database: Union[str, Dict],
                              interpolation_method: HRTFInterpolationMethod = HRTFInterpolationMethod.SPHERICAL) -> None:
        """
        Configure binaural rendering.

        Args:
            hrtf_database: Path to an HRTF database (loaded via
                load_hrtf_database) or an already-loaded HRTF dictionary
            interpolation_method: HRTF interpolation method
        """
        self.hrtf_database = (
            load_hrtf_database(hrtf_database)
            if isinstance(hrtf_database, str)
            else hrtf_database
        )

        renderer_config = {}
        renderer_config['interpolation_method'] = interpolation_method
        renderer_config['nearfield_compensation'] = True
        renderer_config['crossfade_time'] = 0.1
        self.binaural_renderer = renderer_config
    
    def process(self) -> np.ndarray:
        """
        Mix all unmuted layers into a single ambisonic signal.

        Each unmuted layer is scaled by its target gain (``'gain'`` in the
        internal metadata, as set by set_source_gain or layer metadata) and
        summed. With a room model set, early reflections are added for layers
        that have an entry in ``self.sources``, followed by diffuse
        reverberation. The mix is peak-limited to 0.99.

        Returns:
            Processed ambisonic signals, shape (n_channels, n_samples), where
            n_samples is the length of the longest unmuted layer. An empty
            (n_channels, 0) array is returned when there is nothing to mix.
        """
        active_layers = [
            (layer_id, ambi_signals, self._internal_metadata[layer_id])
            for layer_id, ambi_signals in self.layers.items()
            if not self._internal_metadata[layer_id]['muted']
        ]
        max_samples = max((ambi.shape[1] for _, ambi, _ in active_layers), default=0)
        if max_samples == 0:
            return np.zeros((self.n_channels, 0))

        memory_pool = get_memory_pool()
        mix_buffer = memory_pool.acquire('output_buffers', self.n_channels, max_samples)
        room_reflections = None
        try:
            # The accumulation below assumes zeroed buffers; clear defensively in
            # case the pool hands back a previously used one.
            mix_buffer.fill(0.0)
            if self.room:
                room_reflections = memory_pool.acquire('temp_buffers', self.n_channels, max_samples)
                room_reflections.fill(0.0)

            for layer_id, ambi_signals, metadata in active_layers:
                # Apply the target gain and record it as the gain now in effect
                # (previously 'current_gain' was never updated, so gain changes
                # made via set_source_gain had no effect).
                gain = metadata['gain']
                metadata['current_gain'] = gain
                n_samples = ambi_signals.shape[1]

                mix_buffer[:, :n_samples] += ambi_signals * gain

                if room_reflections is not None and layer_id in self.sources:
                    cart_pos = convert_to_cartesian(self.sources[layer_id]['position'])
                    source_reflections = apply_early_reflections(
                        ambi_signals,
                        cart_pos,
                        self.room['dimensions'],
                        self.room['reflection_coefficients'],
                        self.sample_rate
                    )
                    room_reflections[:, :n_samples] += source_reflections * gain

            output_signals = mix_buffer
            if room_reflections is not None:
                mix_buffer += room_reflections
                # The reverb may return a new array rather than mix_buffer, so
                # keep mix_buffer's reference for releasing it to the pool.
                output_signals = apply_diffuse_reverberation(
                    mix_buffer,
                    self.room['rt60'],
                    self.sample_rate,
                    self.room['volume']
                )

            max_val = np.max(np.abs(output_signals))
            if max_val > 0.99:
                output_signals *= (0.99 / max_val)

            # Copy so the result never aliases a pooled buffer
            return output_signals.copy()
        finally:
            # Always hand the pooled buffers (and only those) back to the pool
            if room_reflections is not None:
                memory_pool.release(room_reflections)
            memory_pool.release(mix_buffer)
    
    def rotate(self, ambi_signals: np.ndarray, yaw: float, pitch: float, roll: float) -> np.ndarray:
        """
        Rotate an ambisonic sound field by delegating to rotate_ambisonics.

        Args:
            ambi_signals: Ambisonic signals to rotate
            yaw: Rotation around vertical axis in radians
            pitch: Rotation around side axis in radians
            roll: Rotation around front axis in radians

        Returns:
            The rotated ambisonic signals
        """
        rotated = rotate_ambisonics(ambi_signals, yaw, pitch, roll)
        return rotated
    
    def set_listener_rotation(self, rotation: Tuple[float, float, float]) -> None:
        """
        Set the listener's head rotation.

        Equivalent to ``update_listener_orientation(*rotation)``: the value is
        stored as a fresh ``(yaw, pitch, roll)`` tuple, so a list or array passed
        by the caller is not aliased by the codec.

        Args:
            rotation: Sequence of (yaw, pitch, roll) in radians

        Raises:
            ValueError: If ``rotation`` does not contain exactly three values.
        """
        yaw, pitch, roll = rotation
        self.update_listener_orientation(yaw, pitch, roll)
    
    def binauralize(self, ambi_signals: np.ndarray) -> np.ndarray:
        """
        Render ambisonic signals to binaural stereo.

        If no HRTF database has been configured, the default one from
        load_hrtf_database() is loaded and cached on the codec first.
        NOTE(review): listener_orientation is not applied here.

        Args:
            ambi_signals: Ambisonic signals to binauralize

        Returns:
            Binaural stereo signals, shape (2, n_samples)
        """
        database = self.hrtf_database
        if database is None:
            database = load_hrtf_database()
            self.hrtf_database = database
            logger.info("Using synthetic HRTF database for binauralization")

        return binauralize_ambisonics(ambi_signals, database)
    
    def save_to_file(self, filename: str, bit_depth: int = 32) -> None:
        """
        Write every unmuted layer to a SHAC file.

        Layers are stored individually (no premixed "main" layer) so each source
        keeps its own spatial position for interactive playback.

        Args:
            filename: Output filename
            bit_depth: Bit depth (16 or 32)
        """
        writer = SHACFileWriter(self.order, self.sample_rate, self.normalization)

        unmuted_ids = (
            layer_id for layer_id in self.layers
            if not self._internal_metadata[layer_id]['muted']
        )
        for layer_id in unmuted_ids:
            writer.add_layer(layer_id, self.layers[layer_id], self.layer_metadata[layer_id])

        writer.write_file(filename, bit_depth)
        logger.info(f"Saved SHAC file: {filename}")
    
    def load_from_file(self, filename: str) -> None:
        """
        Replace the codec's content with the layers stored in a SHAC file.

        Order, sample rate and channel count are taken from the file. All
        previously added sources and layers, including their internal
        gain/mute state, are discarded. This stops stale mono sources from
        contributing positions or reflections to the loaded layers.

        Args:
            filename: Input filename

        Raises:
            ValueError: If a layer's channel count does not match the file's
                channel count (raised by add_ambisonic_layer).
        """
        reader = SHACFileReader(filename)
        file_info = reader.get_file_info()
        layer_names = list(reader.get_layer_names())

        # Read everything before touching codec state so a read failure leaves
        # the codec unchanged.
        loaded_layers = [
            (name, reader.read_layer(name), reader.get_layer_metadata(name))
            for name in layer_names
        ]

        # Update codec parameters (needed before add_ambisonic_layer validates shapes)
        self.order = file_info['order']
        self.sample_rate = file_info['sample_rate']
        self.n_channels = file_info['n_channels']

        # Drop all previous content, not just the layers
        self.sources = {}
        self.layers = {}
        self.layer_metadata = {}
        self._internal_metadata = {}

        for name, layer_audio, layer_metadata in loaded_layers:
            self.add_ambisonic_layer(name, layer_audio, layer_metadata)

        logger.info(f"Loaded SHAC file: {filename} with {len(layer_names)} layers")
    
    def start_realtime_processing(self, callback: Callable[[np.ndarray], None]) -> None:
        """
        Start the background real-time processing thread.

        The daemon thread polls ``self.processing_queue`` with a 0.1 s timeout.
        It runs each non-None item through process_frame and passes the result
        to ``callback``. Errors raised while processing a frame or inside the
        callback are logged with their traceback, and the loop keeps running.
        Calling this while already running does nothing.

        Args:
            callback: Function to call with processed audio frames
        """
        if self.running:
            return

        self.running = True

        def processing_loop():
            while self.running:
                try:
                    frame_data = self.processing_queue.get(timeout=0.1)
                except queue.Empty:
                    continue  # nothing queued; re-check the running flag

                if frame_data is None:
                    continue

                try:
                    callback(self.process_frame(frame_data))
                except Exception as e:
                    # Keep the thread alive, but preserve the traceback for debugging
                    logger.exception("Error in processing thread: %s", e)

        self.processing_thread = threading.Thread(
            target=processing_loop, name="SHACCodecProcessing", daemon=True
        )
        self.processing_thread.start()
    
    def stop_realtime_processing(self) -> None:
        """Clear the running flag and wait up to one second for the worker thread."""
        self.running = False
        worker = self.processing_thread
        if not worker:
            return
        worker.join(timeout=1.0)
        self.processing_thread = None
    
    def process_frame(self, frame_data: Dict) -> np.ndarray:
        """Apply control updates from one real-time frame and return an output frame.

        Args:
            frame_data: Dictionary that may contain ``'sources'`` (mapping of
                source id to ``position``/``gain``/``muted`` updates) and
                ``'listener'`` (with an ``orientation`` (yaw, pitch, roll)).

        Returns:
            Output frame of shape (n_channels, frame_size); currently all zeros.
        """
        frame = np.zeros((self.n_channels, self.frame_size))

        # Per-source control updates; unknown source ids are skipped
        for source_id, updates in frame_data.get('sources', {}).items():
            if source_id not in self.sources:
                continue
            if 'position' in updates:
                self.update_source_position(source_id, updates['position'])
            if 'gain' in updates:
                self.set_source_gain(source_id, updates['gain'])
            if 'muted' in updates:
                self.mute_source(source_id, updates['muted'])

        # Listener head orientation
        listener = frame_data.get('listener', {})
        if 'orientation' in listener:
            yaw, pitch, roll = listener['orientation']
            self.update_listener_orientation(yaw, pitch, roll)

        # NOTE: 'audio_data' is accepted but not processed yet; the returned
        # frame is a placeholder.
        return frame