"""
Binaural Rendering Module

This module contains functions for binaural rendering of ambisonic signals
using Head-Related Transfer Functions (HRTFs).

This is a simplified version that provides basic 3D audio spatialization
without requiring external HRTF databases.
"""

import numpy as np
import math
import os
from typing import Dict, List, Tuple, Optional, Union, Any
import logging

from .utils import HRTFInterpolationMethod, SphericalCoord, Vector3, HRTFData

# Set up logging
# Module-level logger; used for the one-time synthetic-HRTF warning and for
# the order-truncation warning in binauralize_ambisonics.
logger = logging.getLogger(__name__)

# Flag for whether we've warned about synthetic HRTF use.
# generate_synthetic_hrtf_database() sets this to True right after logging its
# warning, so the warning is emitted at most once while this module is loaded.
_synthetic_hrtf_warning_shown: bool = False


def generate_synthetic_hrtf_database(sample_rate: int = 48000) -> Dict:
    """
    Generate a simple synthetic HRTF database when no measured data is available.
    This provides very basic 3D cues but is not as realistic as measured HRTFs.

    Directions are sampled on a 12 (azimuth) x 6 (elevation) grid at a fixed
    distance of 1 m. Both poles (elevation = +-pi/2) are sampled once per
    azimuth, so those directions appear multiple times. Each direction gets a
    pair of 128-tap impulse responses that carry only an interaural time
    difference (spherical-head ITD) and a fixed level difference (the later
    ear is scaled by 0.8). No elevation-dependent spectral cue is applied.

    The first call logs a warning recommending a measured (SOFA) HRTF set.
    Later calls stay silent.

    Args:
        sample_rate: The sample rate in Hz. Only used to convert the ITD from
            seconds to samples.

    Returns:
        A synthetic HRTF database compatible with the binauralizer, with keys
        'sample_rate', 'positions' (array of shape (72, 3) holding
        (azimuth, elevation, distance) rows), 'hrtf_dict' (mapping each
        position tuple to {'left': ndarray, 'right': ndarray}), 'max_order',
        'sh_hrtfs' (built by _create_sh_domain_hrtfs), 'convention' and
        'version'.
    """
    global _synthetic_hrtf_warning_shown

    if not _synthetic_hrtf_warning_shown:
        logger.warning("Using synthetic HRTF database. For better spatial audio quality, "
                     "consider using a real SOFA file.")
        _synthetic_hrtf_warning_shown = True

    # Direction grid, azimuth-major, at a fixed distance of 1 m
    azimuths = np.linspace(0, 2*np.pi, 12, endpoint=False)
    elevations = np.linspace(-np.pi/2, np.pi/2, 6)
    positions = np.array([(azimuth, elevation, 1.0)
                          for azimuth in azimuths
                          for elevation in elevations])

    hrtf_length = 128
    # Leading silence before the earlier ear's impulse
    center = hrtf_length // 4
    # Largest delay that keeps the later ear's impulse inside the filter. It is
    # only reached at very high sample rates (roughly above 370 kHz).
    max_itd_samples = hrtf_length - center - 1

    # Spherical head model: maximum ITD (source at the side), in samples
    head_radius = 0.0875  # meters
    speed_of_sound = 343.0  # m/s
    itd_max = head_radius / speed_of_sound * sample_rate

    # Gentle exponential decay shared by every impulse response
    envelope = np.exp(-0.5 * np.arange(hrtf_length) / hrtf_length)

    hrtf_dict = {}
    for azimuth, elevation, distance in positions:
        # ITD = (r/c) * sin(azimuth) * cos(elevation), in samples
        itd = itd_max * np.sin(azimuth) * np.cos(elevation)
        itd_samples = min(int(abs(itd)), max_itd_samples)

        left_hrtf = np.zeros(hrtf_length)
        right_hrtf = np.zeros(hrtf_length)

        # The far ear gets the delayed, attenuated impulse (ITD + ILD)
        if itd >= 0:  # sin(azimuth) >= 0 is treated as the right side here
            left_hrtf[center + itd_samples] = 0.8
            right_hrtf[center] = 1.0
        else:  # Sound is more to the left
            left_hrtf[center] = 1.0
            right_hrtf[center + itd_samples] = 0.8

        hrtf_dict[(azimuth, elevation, distance)] = {
            'left': left_hrtf * envelope,
            'right': right_hrtf * envelope,
        }

    # Create the full database structure
    hrtf_database = {
        'sample_rate': sample_rate,
        'positions': positions,
        'hrtf_dict': hrtf_dict,
        'max_order': 3,  # Reasonable default
        'sh_hrtfs': None,  # Filled in just below
        'convention': "Synthetic",
        'version': "1.0"
    }

    # Create SH-domain HRTFs
    hrtf_database['sh_hrtfs'] = _create_sh_domain_hrtfs(hrtf_database)

    return hrtf_database


def _create_sh_domain_hrtfs(hrtf_database: Dict) -> np.ndarray:
    """
    Create spherical harmonic domain representation of the HRTF database.
    
    Args:
        hrtf_database: Dictionary containing HRTF data
        
    Returns:
        SH-domain HRTF coefficients, shape (2, n_channels, hrtf_length)
    """
    max_order = hrtf_database['max_order']
    n_channels = (max_order + 1) ** 2
    
    # For this simplified version, create synthetic SH-domain HRTFs directly
    hrtf_length = 256
    
    # Create synthetic SH-domain HRTFs
    sh_hrtfs = np.zeros((2, n_channels, hrtf_length))
    
    # W channel (order 0)
    # Simple approximation: omnidirectional with ITD
    sh_hrtfs[0, 0] = np.hstack([np.zeros(5), np.exp(-np.arange(hrtf_length-5)/20)])
    sh_hrtfs[1, 0] = np.hstack([np.zeros(8), np.exp(-np.arange(hrtf_length-8)/20)])
    
    # Y channel (order 1, m=-1): front-back
    # Affect the coloration differences between front and back
    sh_hrtfs[0, 1] = np.hstack([np.zeros(5), 0.5 * np.exp(-np.arange(hrtf_length-5)/15)])
    sh_hrtfs[1, 1] = np.hstack([np.zeros(8), 0.5 * np.exp(-np.arange(hrtf_length-8)/15)])
    
    # Z channel (order 1, m=0): up-down
    # Affect the coloration differences between up and down
    sh_hrtfs[0, 2] = np.hstack([np.zeros(5), 0.3 * np.exp(-np.arange(hrtf_length-5)/10)])
    sh_hrtfs[1, 2] = np.hstack([np.zeros(8), 0.3 * np.exp(-np.arange(hrtf_length-8)/10)])
    
    # X channel (order 1, m=1): left-right
    # Strongest ILD component
    sh_hrtfs[0, 3] = np.hstack([np.zeros(5), -0.7 * np.exp(-np.arange(hrtf_length-5)/25)])
    sh_hrtfs[1, 3] = np.hstack([np.zeros(8), 0.7 * np.exp(-np.arange(hrtf_length-8)/25)])
    
    # Higher order components (simplified)
    for ch in range(4, n_channels):
        l = math.floor(math.sqrt(ch))
        # Reduce the amplitude with increasing order
        gain = 0.2 * (1.0 / l)
        
        # Add some randomness to the higher order components
        # In a real HRTF, these would have specific patterns
        np.random.seed(ch)  # For reproducibility
        sh_hrtfs[0, ch] = gain * np.random.randn(hrtf_length) * np.exp(-np.arange(hrtf_length)/10)
        sh_hrtfs[1, ch] = gain * np.random.randn(hrtf_length) * np.exp(-np.arange(hrtf_length)/10)
    
    # Normalize
    for ear in range(2):
        max_val = np.max(np.abs(sh_hrtfs[ear, 0]))
        if max_val > 0:
            sh_hrtfs[ear, 0] = sh_hrtfs[ear, 0] / max_val
    
    return sh_hrtfs


def binauralize_ambisonics(ambi_signals: np.ndarray, hrtf_database: Union[str, Dict],
                          normalize: bool = True,
                          interpolation_method: HRTFInterpolationMethod = HRTFInterpolationMethod.SPHERICAL) -> np.ndarray:
    """
    Convert ambisonic signals to binaural stereo using HRTF convolution.

    Each ambisonic channel is convolved with its left/right SH-domain filter
    and the results are summed per ear. The convolution tail is discarded, so
    the output has the same number of samples as the input.

    Args:
        ambi_signals: Ambisonic signals, shape (n_channels, n_samples).
            Channels beyond what the HRTF set supports are dropped with a
            warning.
        hrtf_database: Dictionary with HRTF data, or a path string. In this
            simplified version a string is not read; a synthetic database is
            generated instead. If the dictionary has no 'sh_hrtfs' entry (or
            it is None), the SH-domain filters are created and stored back
            into that same dictionary.
        normalize: Whether to scale the output so its peak magnitude is 0.99
        interpolation_method: Accepted for API compatibility; not used by
            this implementation.

    Returns:
        Binaural stereo signals, shape (2, n_samples)

    Raises:
        ValueError: If ``ambi_signals`` is not two-dimensional.
    """
    ambi_signals = np.asarray(ambi_signals)
    if ambi_signals.ndim != 2:
        raise ValueError("ambi_signals must have shape (n_channels, n_samples), "
                         f"got {ambi_signals.shape}")
    n_channels, n_samples = ambi_signals.shape
    if n_samples == 0:
        # np.convolve rejects empty inputs; an empty signal renders to an empty output.
        return np.zeros((2, 0))
    order = math.floor(math.sqrt(n_channels)) - 1

    # Load HRTF data if needed
    if isinstance(hrtf_database, str):
        # For this simplified version, we'll just generate a synthetic database
        hrtf_data = generate_synthetic_hrtf_database()
    else:
        # Assume it's already loaded as a dictionary
        hrtf_data = hrtf_database

    # Ensure we have SH-domain HRTFs (cached in the database dict)
    if hrtf_data.get('sh_hrtfs') is None:
        hrtf_data['sh_hrtfs'] = _create_sh_domain_hrtfs(hrtf_data)

    # SH-domain filters, shape (2, available_channels, filter_length)
    sh_hrtfs = hrtf_data['sh_hrtfs']

    if sh_hrtfs.shape[1] < n_channels:
        logger.warning(f"HRTF database only supports up to order {math.floor(math.sqrt(sh_hrtfs.shape[1])) - 1}, " +
                      f"but got signals of order {order}. Truncating to available order.")
        # Truncate the ambisonic signals to the available order
        n_channels = sh_hrtfs.shape[1]
        ambi_signals = ambi_signals[:n_channels]

    # Apply SH-domain convolution: sum of per-channel convolutions for each ear
    binaural = np.zeros((2, n_samples + sh_hrtfs.shape[2] - 1))
    for ch in range(n_channels):
        binaural[0] += np.convolve(ambi_signals[ch], sh_hrtfs[0, ch])
        binaural[1] += np.convolve(ambi_signals[ch], sh_hrtfs[1, ch])

    # Drop the convolution tail to keep the original length
    binaural = binaural[:, :n_samples]

    if normalize:
        max_val = np.max(np.abs(binaural))
        if max_val > 0.0:
            binaural = binaural / max_val * 0.99

    return binaural


def load_hrtf_database(hrtf_path: str = "default") -> Dict:
    """
    Return an HRTF database ready for ``binauralize_ambisonics``.

    This simplified implementation never opens ``hrtf_path``. Every call
    builds and returns a fresh synthetic database, whatever path is given.

    Args:
        hrtf_path: Location of an HRTF database. Accepted only for interface
            compatibility and currently unused.

    Returns:
        The dictionary produced by ``generate_synthetic_hrtf_database()`` at
        its default sample rate.
    """
    database = generate_synthetic_hrtf_database()
    return database


def _nearest_hrtf_entry(hrtf_dict: Dict, azimuth: float, elevation: float,
                        distance: float) -> Dict:
    """
    Return the ``hrtf_dict`` entry whose (azimuth, elevation, distance) key is closest.

    The score is the great-circle angle between the two directions plus
    0.2 times the absolute log ratio of the distances. On ties, the first key
    in iteration order wins.

    Raises:
        ValueError: If no entry can be selected (empty dictionary, or a
            position that makes every score NaN).
    """
    # log(d1 / d2) is undefined for non-positive distances. Clamp both sides
    # so a source at the listener's position still resolves to a direction.
    min_distance = 1e-6
    source_distance = max(distance, min_distance)
    sin_el = np.sin(elevation)
    cos_el = np.cos(elevation)

    best_key = None
    best_score = float('inf')
    for key in hrtf_dict:
        az, el, dist = key
        cos_angle = sin_el * np.sin(el) + cos_el * np.cos(el) * np.cos(azimuth - az)
        angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        dist_factor = abs(np.log(source_distance / max(dist, min_distance)))
        score = angle + 0.2 * dist_factor
        if score < best_score:
            best_score = score
            best_key = key

    if best_key is None:
        raise ValueError("No HRTF entry matches the requested position")
    return hrtf_dict[best_key]


def binauralize_mono_source(mono_signal: np.ndarray, position: SphericalCoord,
                           hrtf_database: Union[str, Dict],
                           interpolation_method: HRTFInterpolationMethod = HRTFInterpolationMethod.SPHERICAL) -> np.ndarray:
    """
    Render a mono sound source to binaural stereo at a specific position.

    This is a direct binaural rendering that bypasses the ambisonic encoding.
    The HRIR pair of the nearest database position is used as-is (no
    interpolation). Sources farther than 1 m are attenuated by 1/distance.

    Args:
        mono_signal: Mono audio signal, shape (n_samples,)
        position: Source position as (azimuth, elevation, distance). Azimuth
            is wrapped to [0, 2*pi) and elevation is clipped to
            [-pi/2, pi/2]. A non-positive distance is treated as a very small
            positive one when picking the nearest entry.
        hrtf_database: Path to HRTF database or dictionary with HRTF data.
            A path string is not read in this simplified version; a synthetic
            database is generated instead.
        interpolation_method: Accepted for API compatibility; not used by
            this implementation.

    Returns:
        Binaural stereo signal, shape (2, n_samples); row 0 is the left ear.

    Raises:
        ValueError: If no HRTF entry can be selected for the position.
    """
    # Load HRTF data if needed
    if isinstance(hrtf_database, str):
        hrtf_data = generate_synthetic_hrtf_database()
    else:
        hrtf_data = hrtf_database

    # Normalize position
    azimuth, elevation, distance = position
    azimuth = azimuth % (2 * np.pi)
    elevation = np.clip(elevation, -np.pi/2, np.pi/2)

    hrtf = _nearest_hrtf_entry(hrtf_data['hrtf_dict'], azimuth, elevation, distance)
    left_hrtf = hrtf['left']
    right_hrtf = hrtf['right']

    mono_signal = np.asarray(mono_signal)
    n_samples = len(mono_signal)
    if n_samples == 0:
        # np.convolve rejects empty inputs; an empty signal renders to an empty output.
        return np.zeros((2, 0))

    # Apply inverse-distance attenuation beyond 1 m
    if distance > 1.0:
        mono_signal = mono_signal * (1.0 / distance)

    # Convolve with the HRIRs and trim the tail to the original length
    left_channel = np.convolve(mono_signal, left_hrtf)[:n_samples]
    right_channel = np.convolve(mono_signal, right_hrtf)[:n_samples]

    return np.vstack([left_channel, right_channel])


def apply_frequency_dependent_effects(signal: np.ndarray, position: "SphericalCoord",
                                      axis: int = -1) -> np.ndarray:
    """
    Apply a crude distance-dependent low-pass filter (an air-absorption stand-in).

    Sources at distance <= 1 are returned unchanged (the same object). Farther
    sources go through the one-pole recursion
    ``y[i] = alpha * x[i] + (1 - alpha) * y[i-1]`` with ``y[0] = x[0]``, where
    ``alpha = 0.01 + 0.99 * (1 - distance / 50)``, floored at 0.01. Smoothing
    therefore increases with distance. The coefficient does not depend on the
    sample rate.

    Args:
        signal: Audio signal. Either 1-D, or multi-channel with time along
            ``axis``. The default (last axis) matches the module's
            (n_channels, n_samples) layout.
        position: Source position as (azimuth, elevation, distance). Only the
            distance is used.
        axis: Axis along which time runs and the recursion is applied.

    Returns:
        Processed audio signal with the same shape as ``signal``. Integer
        input is promoted to float64 so the recursion is not truncated at each
        step. Floating-point input keeps its dtype.
    """
    _, _, distance = position

    # No processing for close sources
    if distance <= 1.0:
        return signal

    data = np.asarray(signal)
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    # Put time first so each recursion step updates every channel at once
    samples = np.moveaxis(data, axis, 0)
    filtered = np.empty(samples.shape, dtype=out_dtype)
    if samples.shape[0] == 0:
        return np.moveaxis(filtered, 0, axis)

    # Distance-driven smoothing coefficient, floored to keep some signal
    alpha = 0.01 + 0.99 * (1.0 - distance / 50.0)
    if alpha < 0.01:
        alpha = 0.01
    feedback = 1.0 - alpha

    filtered[0] = samples[0]
    for i in range(1, samples.shape[0]):
        filtered[i] = alpha * samples[i] + feedback * filtered[i - 1]

    return np.moveaxis(filtered, 0, axis)