# AGPL-3 Karl Semich 2022
import numpy as np

def fftfreq(freq_count, sample_rate = None, min_freq = None, max_freq = None,
            dc_offset = True, complex = True, sample_time = None,
            repetition_rate = None, repetition_time = None, repetition_samples = None,
            freq_sample_rate = None, freq_sample_time = None):
    '''
    Calculates and returns the frequency bins used to convert between the time
    and frequency domains with a discrete Fourier transform.

    With no optional arguments, this function should be equivalent to numpy.fft.fftfreq .

    Parameters:
        - freq_count: the number of frequency bins to generate
        - sample_rate: the time-domain sample rate in which to treat frequency parameters
        - min_freq: the minimum frequency the signal contains
          (None or 0 selects the default spacing)
        - max_freq: the maximum frequency the signal contains
        - dc_offset: whether or not to include a DC offset component (0 Hz)
        - complex: whether to generate negative frequencies; when False,
          freq_count is treated as a one-sided (rfft-style) bin count for a
          time-domain length of 2 * (freq_count - 1)
        - sample_time: sample_rate as the duration of a single sample
        - repetition_rate: min_freq as the repetition rate of a subsignal
        - repetition_time: min_freq as the period time of a subsignal
        - repetition_samples: min_freq as the period size of a subsignal in samples
        - freq_sample_rate: convert to or from a different frequency-domain sample rate
        - freq_sample_time: freq_sample_rate as the duration of a single sample

    At most one of sample_rate / sample_time may be given, at most one of
    freq_sample_rate / freq_sample_time, and at most one of min_freq /
    repetition_rate / repetition_time / repetition_samples.

    Returns a vector of sinusoid time scalings that can be used to perform or
    analyse a discrete Fourier transform. Frequencies are divided by
    sample_rate, so they are in cycles per time-domain sample. They are laid
    out as [0 (only if dc_offset), ascending positive frequencies, ascending
    negative frequencies], matching numpy's ordering.
    '''
    assert not sample_time or not sample_rate
    assert not freq_sample_time or not freq_sample_rate
    # min_freq and the three repetition_* parameters are alternative ways of
    # giving the same quantity, so at most one of the four may be set
    # (otherwise the precedence chain below would silently drop one).
    assert (min_freq, repetition_rate, repetition_time, repetition_samples).count(None) >= 3
    if sample_time is not None:
        sample_rate = 1 / sample_time
    if freq_sample_time is not None:
        freq_sample_rate = 1 / freq_sample_time
    # each rate falls back to the other, and finally to 1
    sample_rate = sample_rate or freq_sample_rate or 1
    freq_sample_rate = freq_sample_rate or sample_rate or 1
    if not dc_offset:
        # account for the omitted DC bin so the spacing matches a transform that has it
        freq_count += 1
    if not complex:
        # convert a one-sided bin count into the two-sided length it implies
        freq_count = int((freq_count - 1) * 2)
    if not min_freq:
        if repetition_rate:
            min_freq = repetition_rate
        elif repetition_time:
            min_freq = 1 / repetition_time
        elif repetition_samples:
            min_freq = freq_sample_rate / repetition_samples
        else:
            min_freq = freq_sample_rate / freq_count
    min_freq /= sample_rate
    if freq_count % 2 == 0:
        if max_freq is None:
            max_freq = freq_sample_rate / 2
        max_freq /= sample_rate
        if complex:
            # the Nyquist bin appears once, on the negative side, as numpy does
            neg_freqs = np.linspace(-max_freq, -min_freq, num=freq_count // 2, endpoint=True)
            pos_freqs = -neg_freqs[:0:-1]
        else:
            pos_freqs = np.linspace(min_freq, max_freq, num=freq_count // 2, endpoint=True)
            neg_freqs = pos_freqs[:0]
    else:
        if max_freq is None:
            max_freq = freq_sample_rate * (freq_count - 1) / 2 / freq_count
        max_freq /= sample_rate
        pos_freqs = np.linspace(min_freq, max_freq, num=freq_count // 2, endpoint=True)
        # odd lengths are symmetric: negative bins mirror the positive ones
        neg_freqs = -pos_freqs[::-1] if complex else pos_freqs[:0]
    return np.concatenate([
        np.array([0] if dc_offset else []),
        pos_freqs,
        neg_freqs
    ])

def create_freq2time(time_count = None, freqs = None):
    '''
    Build the matrix that applies an inverse Fourier transform when a vector
    of complex frequency magnitudes is multiplied on its left.

    Example
        time_data = spectrum @ create_freq2time(len(spectrum))

    Parameters:
        - time_count: length of the produced time vector; when freqs is
          given and time_count is falsy, len(freqs) is used
        - freqs: frequency bins (cycles per sample) to convert; when omitted,
          the bins of a conventional IDFT of length time_count are used

    Returns:
        - a matrix of shape (len(freqs), time_count) whose entry [k, t] is
          exp(2j * pi * freqs[k] * t) / len(freqs); the 1/len(freqs) factor
          follows numpy.fft.ifft's normalisation
    '''
    assert not (time_count is None and freqs is None)
    if freqs is not None:
        time_count = time_count or len(freqs)
    else:
        freqs = fftfreq(time_count)
    sample_offsets = np.arange(time_count)
    phases = 2j * np.pi * np.outer(freqs, sample_offsets)
    bin_total = len(freqs)
    return np.exp(phases) / bin_total

def create_time2freq(time_count = None, freqs = None):
    '''
    Build the matrix that applies a forward Fourier transform when a vector
    of time-series samples is multiplied on its left.

    The result is the Moore-Penrose pseudo-inverse (numpy.linalg.pinv) of
    create_freq2time(time_count, freqs). When time_count does not match the
    number of bins, applying it therefore yields the minimum-norm
    least-squares solution over all the data passed. For a one-off transform,
    numpy.linalg.lstsq is cheaper and more accurate than building this matrix.

    Example
        spectrum = time_data @ create_time2freq(len(time_data))

    Parameters:
        - time_count: length of the input time vector; defaults as in
          create_freq2time
        - freqs: frequency bins to produce; defaults to a traditional DFT
          for time_count

    Returns:
        - a discrete Fourier matrix of shape (time_count, len(freqs))
    '''
    return np.linalg.pinv(create_freq2time(time_count, freqs))

def peak_pair_idcs(freq_data):
    '''
    Find, along the last axis, the pair of adjacent bins whose summed
    magnitudes are largest.

    Bin 0 is never part of a pair. Presumably it is the DC component, as in
    the layout fftfreq produces with dc_offset=True; confirm against callers.
    NOTE(review): for a two-sided (complex) bin layout, the winning pair can
    straddle the boundary between the last positive bin and the first
    negative bin.

    Parameters:
        - freq_data: array_like of real or complex bin values; leading axes
          are treated as a batch

    Returns:
        - an integer array of shape freq_data.shape[:-1] + (2,) holding the
          indices (i, i + 1) of the peak pair, with i >= 1

    Raises:
        - ValueError: if the last axis has fewer than 3 bins, since no pair
          exists once bin 0 is excluded
    '''
    freq_heights = np.abs(freq_data) # magnitude of each (possibly complex) bin
    if freq_heights.ndim == 0 or freq_heights.shape[-1] < 3:
        raise ValueError('peak_pair_idcs needs at least 3 bins along the last axis')
    # paired_heights[..., j] is the combined height of bins j + 1 and j + 2
    paired_heights = freq_heights[..., 1:-1] + freq_heights[..., 2:]
    # shift back by the skipped bin 0, keeping a trailing axis for concatenation
    peak_idx = paired_heights.argmax(axis=-1)[..., np.newaxis] + 1
    return np.concatenate([peak_idx, peak_idx + 1], axis=-1)

def test():
    '''
    Self-test for the transform helpers; raises AssertionError on mismatch.

    First the 16-point matrices are checked against numpy.fft, both applied
    directly and by solving against the opposite-direction matrix. Then a
    spectrum is evaluated at a random time-domain rate, and the rescaled
    matrices are checked to reproduce and invert it.

    NOTE(review): time_rate is drawn from [0, 2). Values below 1 push bins
    past 0.5 cycles/sample, where they can alias (e.g. time_rate == 0.5 maps
    the -0.5 bin onto 0), so the rescaled matrices may be ill-conditioned on
    some draws. Confirm whether that range is intended.
    '''
    signal = np.random.random(16)
    inverse16 = create_freq2time(16)
    forward16 = create_time2freq(16)

    # the matrices must match numpy.fft when applied directly ...
    via_inverse = signal @ inverse16
    via_forward = signal @ forward16
    reference_ifft = np.fft.ifft(signal)
    reference_fft = np.fft.fft(signal)
    assert np.allclose(reference_ifft, via_inverse)
    assert np.allclose(reference_fft, via_forward)
    # ... and when used as the system to solve in the opposite direction
    assert np.allclose(reference_ifft, np.linalg.solve(forward16.T, signal))
    assert np.allclose(reference_fft, np.linalg.solve(inverse16.T, signal))

    # sample data at a differing rate
    time_rate = np.random.random() * 2
    freq_rate = 1.0
    count = len(signal)
    reference_freqs = np.fft.fftfreq(count)
    scaled_freqs = fftfreq(count, freq_sample_rate = freq_rate, sample_rate = time_rate)
    scaled_inverse = create_freq2time(freqs = scaled_freqs)
    scaled_forward = create_time2freq(freqs = scaled_freqs)
    # slow, direct evaluation of the inverse transform at the rescaled rate
    scaled_time = np.array([
        np.mean([
            amplitude * np.exp(2j * np.pi * freq * offset / time_rate)
            for amplitude, freq in zip(signal, reference_freqs)
        ])
        for offset in range(count)
    ])
    assert np.allclose(scaled_time, signal @ scaled_inverse)
    assert np.allclose(scaled_time, np.linalg.solve(scaled_forward.T, signal))

    # undoing the rescaling must recover the original spectrum and time data
    recovered_freq = scaled_time @ scaled_forward
    recovered_time = recovered_freq @ inverse16
    assert np.allclose(recovered_freq, signal)
    assert np.allclose(recovered_time, via_inverse)
    assert np.allclose(np.linalg.solve(scaled_inverse.T, scaled_time), signal)

if __name__ == '__main__':
    # run the self-test only when this file is executed as a script, not on import
    test()
