"""
Spatial Oscillator Module for Atabey Symphony

This module provides oscillator synthesis functions enhanced with 3D spatial positioning.
Based on Weather-Tune's oscillator.js but adapted for SHAC's spatial audio capabilities.

The oscillators generate musical tones that can be positioned in 3D space,
with parameters driven by weather data. The spatial positioning enhances
the musical experience without simulating weather sounds.
"""

import numpy as np
import math
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass


@dataclass
class OscillatorConfig:
    """Configuration for a spatial oscillator.

    Bundles the synthesis parameters (pitch, waveform, level, shaping) with
    the position that is stored alongside the generated signal.
    """
    # Base pitch in Hz, before detuning is applied.
    frequency: float = 440.0
    # Waveform name: "sine", "triangle", "sawtooth" or "square".
    # Any other value is rendered as a sine wave.
    type: str = "sine"
    # Linear gain multiplied onto the generated waveform.
    volume: float = 0.5
    # Pitch offset in cents (1200 cents = one octave).
    detune: float = 0.0
    position: Tuple[float, float, float] = (0.0, 0.0, 3.0)  # (azimuth, elevation, distance)
    # Optional ADSR settings with keys 'attack', 'decay', 'release' (seconds)
    # and 'sustain' (level). None or an empty dict disables the envelope.
    envelope: Optional[Dict[str, float]] = None
    # Optional filter settings with keys 'type', 'frequency' and 'Q'.
    # NOTE(review): the 'type' value is a string (e.g. 'lowpass'), which the
    # Dict[str, float] annotation does not reflect.
    filter: Optional[Dict[str, float]] = None
    # Length in seconds; a falsy value (None or 0) yields one second of audio.
    duration: Optional[float] = None
    
    
@dataclass
class SpatialMovement:
    """Movement pattern for oscillator position.

    Consumed by SpatialOscillator.create_movement_path, which turns the
    pattern into one position per audio sample.
    """
    # Any value other than "circular", "arc" or "random" is treated as static.
    pattern: str = "static"  # "static", "circular", "arc", "random"
    # For "circular" and "arc" this is the number of cycles per second;
    # the "random" pattern does not read it.
    speed: float = 0.1  # Movement speed
    # Size of the offset added to the centre coordinates.
    radius: float = 1.0  # Movement radius
    # Centre of the motion as (azimuth, elevation, distance).
    center: Tuple[float, float, float] = (0.0, 0.0, 3.0)
    

class SpatialOscillator:
    """
    Core oscillator class that generates audio with spatial positioning.
    Preserves Weather-Tune's musical logic while adding 3D positioning.
    """
    
    def __init__(self, sample_rate: int = 48000):
        self.sample_rate = sample_rate
        self.oscillators = {}
        self.active_sources = {}
        
    def create_oscillator(self, osc_id: str, config: OscillatorConfig) -> np.ndarray:
        """
        Create a basic oscillator with spatial parameters.
        
        Args:
            osc_id: Unique identifier for the oscillator
            config: Oscillator configuration
            
        Returns:
            Generated audio signal
        """
        # Generate time array
        if config.duration:
            num_samples = int(config.duration * self.sample_rate)
        else:
            num_samples = self.sample_rate  # Default 1 second
            
        t = np.linspace(0, num_samples / self.sample_rate, num_samples, endpoint=False)
        
        # Apply detuning (convert cents to frequency ratio)
        detune_ratio = 2 ** (config.detune / 1200.0)
        freq = config.frequency * detune_ratio
        
        # Generate oscillator based on type
        if config.type == "sine":
            signal = np.sin(2 * np.pi * freq * t)
        elif config.type == "triangle":
            signal = 2 * np.arcsin(np.sin(2 * np.pi * freq * t)) / np.pi
        elif config.type == "sawtooth":
            signal = 2 * (t * freq % 1) - 1
        elif config.type == "square":
            signal = np.sign(np.sin(2 * np.pi * freq * t))
        else:  # Default to sine
            signal = np.sin(2 * np.pi * freq * t)
            
        # Apply volume
        signal *= config.volume
        
        # Apply filter if specified
        if config.filter:
            signal = self._apply_filter(signal, config.filter)
            
        # Apply envelope if specified
        if config.envelope:
            signal = self._apply_envelope(signal, config.envelope)
            
        # Store oscillator data
        self.oscillators[osc_id] = {
            'config': config,
            'signal': signal,
            'position': config.position
        }
        
        return signal
    
    def create_rich_oscillator(self, osc_id: str, config: OscillatorConfig,
                            type1: str = "sine", type2: str = "triangle") -> np.ndarray:
        """
        Create a rich, layered oscillator with multiple components.
        Preserves Weather-Tune's approach to creating rich timbres.
        
        Args:
            osc_id: Unique identifier for the oscillator
            config: Base oscillator configuration
            type1: First oscillator type
            type2: Second oscillator type
            
        Returns:
            Combined audio signal
        """
        # Create master signal
        master_signal = np.zeros(int(config.duration * self.sample_rate) if config.duration 
                               else self.sample_rate)
        
        # Component configurations with spatial spread
        components = [
            # Primary oscillator - centered
            {
                'type': type1,
                'frequency': config.frequency,
                'detune': 0,
                'volume': 0.65 * config.volume,
                'position': config.position
            },
            # Detuned companion - slightly right
            {
                'type': type2,
                'frequency': config.frequency,
                'detune': config.detune,
                'volume': 0.55 * config.volume,
                'position': (config.position[0] + 0.2, config.position[1], config.position[2])
            },
            # Sub oscillator (octave down) - slightly left and below
            {
                'type': "sine",
                'frequency': config.frequency / 2,
                'detune': -config.detune / 2,
                'volume': 0.4 * config.volume,
                'position': (config.position[0] - 0.2, config.position[1] - 0.1, config.position[2])
            },
            # Upper harmonic - slightly above
            {
                'type': "sine",
                'frequency': config.frequency * 2,
                'detune': config.detune * 1.5,
                'volume': 0.15 * config.volume,
                'position': (config.position[0], config.position[1] + 0.1, config.position[2])
            }
        ]
        
        # Generate each component
        component_signals = []
        positions = []
        
        for i, comp in enumerate(components):
            if comp['volume'] <= 0:
                continue
                
            # Create component config
            comp_config = OscillatorConfig(
                frequency=comp['frequency'],
                type=comp['type'],
                volume=comp['volume'],
                detune=comp['detune'],
                position=comp['position'],
                envelope=config.envelope,
                filter=config.filter,
                duration=config.duration
            )
            
            # Generate component signal
            signal = self.create_oscillator(f"{osc_id}_comp_{i}", comp_config)
            component_signals.append(signal)
            positions.append(comp['position'])
            
        # Store rich oscillator data
        self.oscillators[osc_id] = {
            'config': config,
            'components': component_signals,
            'positions': positions,
            'type': 'rich'
        }
        
        return component_signals
    
    def create_dual_oscillator(self, osc_id: str, config: OscillatorConfig,
                            type1: str = "sine", type2: str = "triangle") -> np.ndarray:
        """
        Create a dual oscillator with detuning for richer sound.
        Spatially separates the two oscillators for width.
        
        Args:
            osc_id: Unique identifier for the oscillator
            config: Base oscillator configuration
            type1: First oscillator type
            type2: Second oscillator type
            
        Returns:
            List of [left_signal, right_signal] for spatial separation
        """
        # Create two oscillators with slight spatial separation
        osc1_config = OscillatorConfig(
            frequency=config.frequency,
            type=type1,
            volume=config.volume,
            detune=0,
            position=(config.position[0] - 0.1, config.position[1], config.position[2]),
            envelope=config.envelope,
            filter=config.filter,
            duration=config.duration
        )
        
        osc2_config = OscillatorConfig(
            frequency=config.frequency,
            type=type2,
            volume=config.volume,
            detune=config.detune,
            position=(config.position[0] + 0.1, config.position[1], config.position[2]),
            envelope=config.envelope,
            filter=config.filter,
            duration=config.duration
        )
        
        # Generate signals
        signal1 = self.create_oscillator(f"{osc_id}_1", osc1_config)
        signal2 = self.create_oscillator(f"{osc_id}_2", osc2_config)
        
        # Store dual oscillator data
        self.oscillators[osc_id] = {
            'config': config,
            'signals': [signal1, signal2],
            'positions': [osc1_config.position, osc2_config.position],
            'type': 'dual'
        }
        
        return [signal1, signal2]
    
    def create_pad(self, pad_id: str, config: OscillatorConfig,
                   filter_type: str = "lowpass", filter_freq: float = 1000) -> List[Tuple[np.ndarray, Tuple[float, float, float]]]:
        """
        Create a pad sound with slow evolution and rich harmonics.
        Returns multiple layers with their spatial positions.
        
        Args:
            pad_id: Unique identifier for the pad
            config: Base oscillator configuration
            filter_type: Type of filter to apply
            filter_freq: Filter frequency
            
        Returns:
            List of (signal, position) tuples for each layer
        """
        layers = []
        
        # Primary layer - centered
        primary_config = OscillatorConfig(
            frequency=config.frequency,
            volume=0.7 * config.volume,
            type="sine",
            detune=config.detune,
            position=config.position,
            envelope={
                'attack': config.envelope.get('attack', 1.5) if config.envelope else 1.5,
                'decay': config.envelope.get('decay', 2.0) if config.envelope else 2.0,
                'sustain': config.envelope.get('sustain', 0.7) if config.envelope else 0.7,
                'release': config.envelope.get('release', 3.0) if config.envelope else 3.0
            },
            filter={'type': filter_type, 'frequency': filter_freq, 'Q': 0.5},
            duration=config.duration
        )
        
        primary_signals = self.create_rich_oscillator(f"{pad_id}_primary", primary_config)
        
        # Higher octave layer - spread in upper space
        if config.frequency > 100:  # Only add if frequency is high enough
            high_config = OscillatorConfig(
                frequency=config.frequency * 2,
                volume=0.3 * config.volume,
                type="sine",
                detune=config.detune * 1.5,
                position=(config.position[0], config.position[1] + 0.2, config.position[2] + 0.5),
                envelope={
                    'attack': primary_config.envelope['attack'] * 1.3,
                    'decay': primary_config.envelope['decay'],
                    'sustain': primary_config.envelope['sustain'] * 0.8,
                    'release': primary_config.envelope['release']
                },
                filter={'type': filter_type, 'frequency': filter_freq * 1.5, 'Q': 0.3},
                duration=config.duration
            )
            
            high_signals = self.create_dual_oscillator(f"{pad_id}_high", high_config)
            
        # Lower octave layer - spread in lower space
        if config.frequency > 200:  # Only add if frequency is high enough
            low_config = OscillatorConfig(
                frequency=config.frequency / 2,
                volume=0.4 * config.volume,
                type="sine",
                detune=0,
                position=(config.position[0], config.position[1] - 0.2, config.position[2] + 0.7),
                envelope={
                    'attack': primary_config.envelope['attack'] * 1.7,
                    'decay': primary_config.envelope['decay'],
                    'sustain': primary_config.envelope['sustain'],
                    'release': primary_config.envelope['release'] * 1.2
                },
                filter={'type': filter_type, 'frequency': filter_freq * 0.7, 'Q': 0.7},
                duration=config.duration
            )
            
            low_signal = self.create_oscillator(f"{pad_id}_low", low_config)
            layers.append((low_signal, low_config.position))
        
        # Collect all layers with positions
        for i, signal in enumerate(primary_signals):
            pos = self.oscillators[f"{pad_id}_primary"]['positions'][i]
            layers.append((signal, pos))
            
        if config.frequency > 100:
            for i, signal in enumerate(high_signals):
                pos = self.oscillators[f"{pad_id}_high"]['positions'][i]
                layers.append((signal, pos))
        
        # Store pad data
        self.oscillators[pad_id] = {
            'config': config,
            'layers': layers,
            'type': 'pad'
        }
        
        return layers
    
    def add_lfo_modulation(self, target_id: str, lfo_type: str = "sine",
                         frequency: float = 0.1, depth: float = 10,
                         parameter: str = "frequency") -> np.ndarray:
        """
        Add LFO modulation to an existing oscillator parameter.
        
        Args:
            target_id: ID of the oscillator to modulate
            lfo_type: Type of LFO waveform
            frequency: LFO frequency in Hz
            depth: Modulation depth
            parameter: Parameter to modulate ("frequency", "position", "volume")
            
        Returns:
            LFO signal
        """
        if target_id not in self.oscillators:
            raise ValueError(f"Oscillator {target_id} does not exist")
            
        # Get target oscillator data
        target_data = self.oscillators[target_id]
        config = target_data['config']
        
        # Generate LFO signal
        duration = config.duration if config.duration else 1.0
        num_samples = int(duration * self.sample_rate)
        t = np.linspace(0, duration, num_samples, endpoint=False)
        
        if lfo_type == "sine":
            lfo_signal = np.sin(2 * np.pi * frequency * t) * depth
        elif lfo_type == "triangle":
            lfo_signal = 2 * np.arcsin(np.sin(2 * np.pi * frequency * t)) / np.pi * depth
        else:
            lfo_signal = np.sin(2 * np.pi * frequency * t) * depth
            
        # Store LFO data
        if 'lfos' not in target_data:
            target_data['lfos'] = {}
            
        target_data['lfos'][parameter] = {
            'signal': lfo_signal,
            'frequency': frequency,
            'depth': depth,
            'type': lfo_type
        }
        
        return lfo_signal
    
    def _apply_envelope(self, signal: np.ndarray, envelope: Dict[str, float]) -> np.ndarray:
        """
        Apply ADSR envelope to signal.
        Preserves Weather-Tune's envelope shaping approach.
        """
        attack = envelope.get('attack', 0.1)
        decay = envelope.get('decay', 0.2)
        sustain = envelope.get('sustain', 0.7)
        release = envelope.get('release', 0.3)
        
        num_samples = len(signal)
        sample_rate = self.sample_rate
        
        # Calculate sample counts for each stage
        attack_samples = int(attack * sample_rate)
        decay_samples = int(decay * sample_rate)
        release_samples = int(release * sample_rate)
        sustain_samples = num_samples - attack_samples - decay_samples - release_samples
        
        if sustain_samples < 0:
            # Adjust if total envelope time exceeds signal duration
            ratio = num_samples / (attack_samples + decay_samples + release_samples)
            attack_samples = int(attack_samples * ratio)
            decay_samples = int(decay_samples * ratio)
            release_samples = int(release_samples * ratio)
            sustain_samples = num_samples - attack_samples - decay_samples - release_samples
        
        # Create envelope
        envelope_signal = np.ones_like(signal)
        current_sample = 0
        
        # Attack phase - curved for more natural sound
        if attack_samples > 0:
            if attack < 0.01:
                # Very fast attack - linear
                envelope_signal[current_sample:current_sample + attack_samples] = \
                    np.linspace(0, 1, attack_samples)
            else:
                # Curved attack using exponential
                t = np.linspace(0, 1, attack_samples)
                envelope_signal[current_sample:current_sample + attack_samples] = \
                    1 - np.exp(-5 * t)  # Exponential curve
            current_sample += attack_samples
        
        # Decay phase - curved
        if decay_samples > 0:
            t = np.linspace(0, 1, decay_samples)
            envelope_signal[current_sample:current_sample + decay_samples] = \
                sustain + (1 - sustain) * np.exp(-5 * t)
            current_sample += decay_samples
        
        # Sustain phase
        if sustain_samples > 0:
            envelope_signal[current_sample:current_sample + sustain_samples] = sustain
            current_sample += sustain_samples
        
        # Release phase - curved
        if release_samples > 0 and current_sample < num_samples:
            remaining_samples = min(release_samples, num_samples - current_sample)
            t = np.linspace(0, 1, remaining_samples)
            envelope_signal[current_sample:current_sample + remaining_samples] = \
                sustain * np.exp(-5 * t)
        
        return signal * envelope_signal
    
    def _apply_filter(self, signal: np.ndarray, filter_config: Dict[str, float]) -> np.ndarray:
        """
        Apply filter to signal.
        Simple implementation using basic DSP techniques.
        """
        filter_type = filter_config.get('type', 'lowpass')
        frequency = filter_config.get('frequency', 2000)
        Q = filter_config.get('Q', 1)
        
        # For now, implement a simple RC-style lowpass filter
        if filter_type == 'lowpass':
            # Simple one-pole lowpass filter
            rc = 1.0 / (2.0 * np.pi * frequency)
            dt = 1.0 / self.sample_rate
            alpha = dt / (rc + dt)
            
            filtered = np.zeros_like(signal)
            filtered[0] = signal[0]
            
            for i in range(1, len(signal)):
                filtered[i] = alpha * signal[i] + (1 - alpha) * filtered[i-1]
                
            return filtered
        
        elif filter_type == 'highpass':
            # Simple one-pole highpass filter
            rc = 1.0 / (2.0 * np.pi * frequency)
            dt = 1.0 / self.sample_rate
            alpha = rc / (rc + dt)
            
            filtered = np.zeros_like(signal)
            filtered[0] = signal[0]
            
            for i in range(1, len(signal)):
                filtered[i] = alpha * (filtered[i-1] + signal[i] - signal[i-1])
                
            return filtered
        
        else:
            # For other filter types, return unfiltered for now
            return signal
    
    def create_movement_path(self, movement: SpatialMovement, duration: float) -> List[Tuple[float, float, float]]:
        """
        Create a movement path for dynamic positioning.
        
        Args:
            movement: Movement configuration
            duration: Duration of movement in seconds
            
        Returns:
            List of positions over time
        """
        num_samples = int(duration * self.sample_rate)
        positions = []
        
        for i in range(num_samples):
            t = i / self.sample_rate
            
            if movement.pattern == "circular":
                # Circular movement in azimuth
                angle = 2 * np.pi * movement.speed * t
                azimuth = movement.center[0] + movement.radius * np.sin(angle)
                elevation = movement.center[1]
                distance = movement.center[2]
                
            elif movement.pattern == "arc":
                # Arc movement in elevation
                progress = (movement.speed * t) % 1.0
                elevation = movement.center[1] + movement.radius * np.sin(np.pi * progress)
                azimuth = movement.center[0]
                distance = movement.center[2]
                
            elif movement.pattern == "random":
                # Smooth random movement
                # Use low-frequency noise for smooth motion
                noise_freq = 0.5  # Hz
                azimuth = movement.center[0] + movement.radius * 0.3 * np.sin(2 * np.pi * noise_freq * t)
                elevation = movement.center[1] + movement.radius * 0.2 * np.sin(2 * np.pi * noise_freq * 1.3 * t)
                distance = movement.center[2] + movement.radius * 0.1 * np.sin(2 * np.pi * noise_freq * 0.7 * t)
                
            else:  # static
                azimuth, elevation, distance = movement.center
            
            positions.append((azimuth, elevation, distance))
        
        return positions