0

I am trying to calculate the Bispectrum score (BSS) for the audio data frame array; the definition of this feature can be found here:

enter image description here

And the formulas can be found here:

enter image description here

The implementation I am using is:

import numpy as np
import numba as nb
from time import time


def bispectrum(*u, ntheta=None, kmin=None, kmax=None,
               diagnostics=True, error=False,
               nsamples=None, sample_thresh=None,
               compute_fft=True, exclude_upper=False, use_pyfftw=False,
               bench=False, progress=False, **kwargs):
    """
    Estimate the bispectrum and bicoherence of 2D or 3D data.

    Fourier modes are grouped into integer wavenumber shells; for each
    pair of shells (k1, k2) the triple products accumulated by the
    compute kernel are summed, optionally split into bins of the angle
    between the two wavevectors.

    Parameters
    ----------
    *u : np.ndarray
        Either one scalar field or three vector components, each 2D or
        3D. When `compute_fft` is False the arrays are used as given and
        only sliced to the first ``shape[-1]//2 + 1`` entries of the last
        axis (presumably they are then full FFTs -- confirm with caller).
    ntheta : int, optional
        Number of angle bins, spaced ``pi/ntheta`` apart on [0, pi).
        None disables angular binning.
    kmin, kmax : int, optional
        Smallest and largest wavenumber shell. Defaults are 1 and
        ``int(max(shape)/2)``.
    diagnostics : bool, optional
        If True, append `counts` and `omega` to the result.
    error : bool, optional
        If True, the standard error is computed; it is only returned
        when `diagnostics` is also True.
    nsamples : int, float or np.ndarray, optional
        Sampling setting per shell pair. Scalars are broadcast to a
        ``(dim, dim)`` array; integer arrays are cast to ``np.int_``.
        None means no sub-sampling (both limits set to the int64 max).
    sample_thresh : int, optional
        Threshold forwarded to the compute loop; defaults to the int64
        max.
    compute_fft : bool, optional
        If True, transform each input with ``np.fft.rfftn`` (or pyfftw).
    exclude_upper : bool, optional
        Forwarded to the compute loop as its ``exclude`` flag.
    use_pyfftw : bool, optional
        If True, use the pyfftw-based `_fftn` helper for the FFT.
    bench : bool, optional
        If True, print the elapsed wall-clock time.
    progress : bool, optional
        If True, the compute loop prints a progress bar.
    **kwargs
        Passed through to the FFT routine.

    Returns
    -------
    tuple
        ``(B, b, kn)`` followed by ``theta`` if `ntheta` is given, then
        ``counts, omega`` if `diagnostics`, then ``stderr`` if `error`
        (and `diagnostics`). ``B`` is the summed bispectrum divided by
        the sample count, ``b`` is ``|sum| / norm`` (the bicoherence),
        and ``kn`` holds the wavenumber shells.

    Raises
    ------
    ValueError
        If the number of fields is not 1 or 3, or the data is not 2D/3D.
    """
    # Validate before touching u[0] so that an empty call reports the
    # documented ValueError rather than an IndexError.
    ncomp = len(u)
    if ncomp not in [1, 3]:
        raise ValueError("Pass either 1 scalar field or 3 vector components.")

    ndim = u[0].ndim
    if ndim not in [2, 3]:
        raise ValueError("Data must be 2D or 3D.")

    shape = nb.typed.List(u[0].shape)

    # Geometry of output image
    kmax = int(max(shape)/2) if kmax is None else int(kmax)
    kmin = 1 if kmin is None else int(kmin)
    kn = np.arange(kmin, kmax+1, 1, dtype=int)
    dim = kn.size
    theta = np.arange(0, np.pi, np.pi/ntheta) if ntheta is not None else None
    # ...make costheta monotonically increase
    costheta = np.flip(np.cos(theta)) if theta is not None else np.array([1.])

    # theta = 0 should be included
    if theta is not None:
        costheta[-1] += 1e-5

    if bench:
        t0 = time()

    # Get binned radial coordinates of FFT
    kv = np.meshgrid(*([np.fft.fftfreq(Ni).astype(np.float32)*Ni
                        for Ni in shape]), indexing="ij")
    kr = np.zeros_like(kv[0])
    for i in range(ndim):
        kr[...] += kv[i]**2
    kr[...] = np.sqrt(kr)

    # Flattened integer wavevector components, one array per axis
    kcoords = nb.typed.List()
    for i in range(ndim):
        temp = kv[i].astype(np.int16).ravel()
        kcoords.append(temp)

    del kv, temp

    kbins = np.arange(int(np.ceil(kr.max())))
    kbinned = (np.digitize(kr, kbins)-1).astype(np.int16)

    del kr

    # Enumerate indices in each bin; k2 is restricted to the half of the
    # last axis that the real FFT stores.
    k1bins, k2bins = nb.typed.List(), nb.typed.List()
    for ki in kn:
        mask = kbinned == ki
        temp1 = np.where(mask)
        temp2 = np.where(mask[..., :shape[-1]//2+1])
        k1bins.append(np.ravel_multi_index(temp1, shape))
        k2bins.append(np.ravel_multi_index(temp2, shape))

    del kbinned

    # FFT
    ffts = []
    for i in range(ncomp):
        if compute_fft:
            temp = u[i]
            if use_pyfftw:
                fft = _fftn(temp, **kwargs)
            else:
                fft = np.fft.rfftn(temp, **kwargs)
            del temp
        else:
            fft = u[i][..., :shape[-1]//2+1]
        ffts.append(fft)
        del fft

    # Sampling settings
    if sample_thresh is None:
        sample_thresh = np.iinfo(np.int64).max
    if nsamples is None:
        nsamples = np.iinfo(np.int64).max
        sample_thresh = np.iinfo(np.int64).max

    # Sampling mask
    if np.issubdtype(type(nsamples), np.integer):
        nsamples = np.full((dim, dim), nsamples, dtype=np.int_)
    elif np.issubdtype(type(nsamples), np.floating):
        nsamples = np.full((dim, dim), nsamples)
    elif type(nsamples) is np.ndarray:
        if np.issubdtype(nsamples.dtype, np.integer):
            nsamples = nsamples.astype(np.int_)

    # Run main loop. ndim was validated above, so an explicit choice
    # between the two kernels replaces the former eval() lookup.
    compute_point = _compute_point2D if ndim == 2 else _compute_point3D
    args = (k1bins, k2bins, kn, costheta, kcoords,
            nsamples, sample_thresh, ndim, dim, shape,
            progress, exclude_upper, error, compute_point, *ffts)
    B, norm, omega, counts, stderr = _compute_bispectrum(*args)

    # Set zero values to nan values for division
    mask = counts == 0.
    norm[mask] = np.nan
    counts[mask] = np.nan

    # Get bicoherence and average bispectrum
    b = np.abs(B) / norm
    B.real /= counts
    B.imag /= counts

    # Prepare diagnostics
    if error:
        stderr[counts <= 1.] = np.nan

    # Switch back to theta monotonically increasing
    if ntheta is not None:
        B[...] = np.flip(B, axis=0)
        b[...] = np.flip(b, axis=0)
        if diagnostics:
            counts[...] = np.flip(counts, axis=0)
            if error:
                stderr[...] = np.flip(stderr, axis=0)
    else:
        # No angular binning: drop the length-1 theta axis
        B, b = B[0], b[0]
        if diagnostics:
            counts = counts[0]
            if error:
                stderr = stderr[0]

    if bench:
        print(f"Time: {time() - t0:.04f} s")

    result = [B, b, kn]
    if ntheta is not None:
        result.append(theta)
    if diagnostics:
        result.extend([counts, omega])
        if error:
            result.append(stderr)

    return tuple(result)


def _fftn(image, overwrite_input=False, threads=-1, **kwargs):
    """
    Calculate N-dimensional fft of image with pyfftw.
    See pyfftw.builders.fftn for kwargs documentation.
    Parameters
    ----------
    image : np.ndarray
        Real or complex-valued 2D or 3D image
    overwrite_input : bool, optional
        Specify whether input data can be destroyed.
        This is useful for reducing memory usage.
        See pyfftw.builders.fftn for more.
    threads : int, optional
        Number of threads for pyfftw to use. Default
        is number of cores.
    Returns
    -------
    fft : np.ndarray
        The fft. Will be the shape of the input image
        or the user specified shape.
    """
    import pyfftw

    if image.dtype in [np.complex64, np.complex128]:
        dtype = 'complex128'
        fftn = pyfftw.builders.fftn
    elif image.dtype in [np.float32, np.float64]:
        dtype = 'float64'
        fftn = pyfftw.builders.rfftn
    else:
        raise ValueError(f"{data.dtype} is unrecognized data type.")

    a = pyfftw.empty_aligned(image.shape, dtype=dtype)
    f = fftn(a, threads=threads, overwrite_input=overwrite_input, **kwargs)
    a[...] = image
    fft = f()

    del a, fftn

    return fft


@nb.njit(parallel=True)
def _compute_bispectrum(k1bins, k2bins, kn, costheta, kcoords, nsamples,
                        sample_thresh, ndim, dim, shape, progress,
                        exclude, error, compute_point, *ffts):
    """
    Accumulate bispectrum sums for every pair of wavenumber bins.

    For each bin pair (i, j) the function builds a list of mode-pair
    indices (all of them, or a random sample), lets `compute_point` fill
    per-sample buffers, and reduces those buffers into the output arrays
    through `_fill_sum` or `_fill_binned_sum`.

    Parameters
    ----------
    k1bins, k2bins : sequence of index arrays
        Per-bin index arrays for the first and second wavevector.
    kn : array
        Wavenumber of each bin.
    costheta : array
        Bin edges in cos(theta); its size is the number of angle bins.
    kcoords : sequence of arrays
        Forwarded unchanged to `compute_point`.
    nsamples : 2D array
        Per-pair sampling setting, indexed as ``nsamples[i, j]``.
    sample_thresh : int
        Threshold used in the sampling decision below.
    ndim : int
        Not used inside this function.
    dim : int
        Number of wavenumber bins.
    shape : sequence of int
        Grid shape; ``max(shape) // 2`` is used as the Nyquist limit.
    progress : bool
        If True, print a progress bar after each outer iteration.
    exclude : bool
        If True (and there is a single angle bin), skip pairs with
        ``k1 + k2`` above the Nyquist limit.
    error : bool
        If True, allocate and fill a full-size standard-error array.
    compute_point : callable
        Kernel that fills the per-sample buffers.
    *ffts : arrays
        One or more FFT arrays forwarded to `compute_point`.

    Returns
    -------
    bispec, binorm, omega, counts, stderr
        `bispec`, `binorm`, `counts` have shape ``(ntheta, dim, dim)``
        and start as NaN; `omega` is ``(dim, dim)`` int64; `stderr` is
        ``(ntheta, dim, dim)`` when `error` else a ``(1, 1, 1)`` dummy.
    """
    knyq = max(shape) // 2
    ntheta = costheta.size
    nffts = len(ffts)
    # Outputs start as NaN so that pairs never visited stay NaN.
    bispec = np.full((ntheta, dim, dim), np.nan+1.j*np.nan, dtype=np.complex128)
    binorm = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    counts = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    omega = np.zeros((dim, dim), dtype=np.int64)
    if error:
        stderr = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    else:
        # Placeholder array; a same-rank float array is still returned.
        stderr = np.zeros((1, 1, 1), dtype=np.float64)
    for i in range(dim):
        k1 = kn[i]
        k1ind = k1bins[i]
        nk1 = k1ind.size
        # With a single FFT only j <= i is visited; omega is mirrored
        # below for that case.
        dim2 = dim if nffts > 1 else i+1
        for j in range(dim2):
            k2 = kn[j]
            if ntheta == 1 and (exclude and k1 + k2 > knyq):
                continue
            k2ind = k2bins[j]
            nk2 = k2ind.size
            nsamp = nsamples[i, j]
            # An int64 entry is used as an absolute sample count; any
            # other type is scaled by the number of pairs (at least 1).
            nsamp = int(nsamp) if type(nsamp) is np.int64 \
                else max(int(nsamp*nk1*nk2), 1)
            if nsamp < nk1*nk2 or nsamp > sample_thresh:
                # Random pair indices; randint may repeat values.
                samp = np.random.randint(0, nk1*nk2, size=nsamp)
                count = nsamp
            else:
                # Enumerate every (k1, k2) index pair.
                samp = np.arange(nk1*nk2)
                count = nk1*nk2
            # Per-sample scratch buffers filled by compute_point.
            bispecbuf = np.zeros(count, dtype=np.complex128)
            binormbuf = np.zeros(count, dtype=np.float64)
            # The cos(theta) buffer is only full-size with angle binning.
            cthetabuf = np.zeros(count, dtype=np.float64) if ntheta > 1 \
                else np.array([0.], dtype=np.float64)
            countbuf = np.zeros(count, dtype=np.float64)
            compute_point(k1ind, k2ind, kcoords, ntheta,
                          nk1, nk2, shape, samp, count,
                          bispecbuf, binormbuf, cthetabuf, countbuf,
                          *ffts)
            if ntheta == 1:
                _fill_sum(i, j, bispec, binorm, counts, stderr,
                          bispecbuf, binormbuf, countbuf, nffts, error)
            else:
                # Map each sample's cos(theta) to an angle-bin index.
                binned = np.searchsorted(costheta, cthetabuf)
                _fill_binned_sum(i, j, ntheta, binned, bispec, binorm,
                                 counts, stderr, bispecbuf, binormbuf,
                                 countbuf, nffts, error)
            # omega records the total number of index pairs in this bin
            # pair, independent of how many were sampled.
            omega[i, j] = nk1*nk2
            if nffts == 1:
                omega[j, i] = nk1*nk2
        if progress:
            # Printing is done in object mode, outside nopython code.
            with nb.objmode():
                _printProgressBar(i, dim-1)
    return bispec, binorm, omega, counts, stderr


@nb.njit(parallel=True, cache=True)
def _fill_sum(i, j, bispec, binorm, counts, stderr,
              bispecbuf, binormbuf, countbuf, nffts, error):
    """
    Reduce the per-sample buffers into angle slot 0 of the outputs.

    The sums of `bispecbuf`, `binormbuf` and `countbuf` are written to
    ``[0, i, j]`` of `bispec`, `binorm` and `counts`. When there is a
    single FFT (``nffts == 1``) the same values are mirrored to
    ``[0, j, i]``. If `error` is set and more than one sample was
    counted, the standard error of the mean is stored in `stderr` too.
    """
    mirror = nffts == 1

    total = bispecbuf.sum()
    total_norm = binormbuf.sum()
    nhits = countbuf.sum()

    bispec[0, i, j] = total
    binorm[0, i, j] = total_norm
    counts[0, i, j] = nhits
    if mirror:
        bispec[0, j, i] = total
        binorm[0, j, i] = total_norm
        counts[0, j, i] = nhits

    if error and nhits > 1:
        # Squared deviation of every buffer entry from the mean sample
        sq_dev = np.abs(bispecbuf - (total / nhits))**2
        sem = np.sqrt(sq_dev.sum() / (nhits*(nhits - 1)))
        stderr[0, i, j] = sem
        if mirror:
            stderr[0, j, i] = sem


@nb.njit(parallel=True, cache=True)
def _fill_binned_sum(i, j, ntheta, binned, bispec, binorm, counts,
                     stderr, bispecbuf, binormbuf, countbuf, nffts, error):
    """
    Reduce the per-sample buffers into per-angle-bin sums.

    `binned` gives the angle-bin index of every sample. Weighted
    bincounts of the buffers are written to ``[:, i, j]`` of `bispec`,
    `binorm` and `counts`, and mirrored to ``[:, j, i]`` when there is a
    single FFT (``nffts == 1``). If `error` is set, a per-bin standard
    error of the mean is written to `stderr` in the same way.
    """
    mirror = nffts == 1

    # bincount is applied to the real and imaginary parts separately
    # and the two results are recombined into a complex array.
    re_part = np.bincount(binned, weights=bispecbuf.real, minlength=ntheta)
    im_part = np.bincount(binned, weights=bispecbuf.imag, minlength=ntheta)
    totals = re_part + 1.j*im_part
    norm_totals = np.bincount(binned, weights=binormbuf, minlength=ntheta)
    hits = np.bincount(binned, weights=countbuf, minlength=ntheta)

    bispec[:, i, j] = totals
    binorm[:, i, j] = norm_totals
    counts[:, i, j] = hits
    if mirror:
        bispec[:, j, i] = totals
        binorm[:, j, i] = norm_totals
        counts[:, j, i] = hits

    if error:
        # Per-sample contribution to the variance of its own bin's mean
        contrib = np.zeros_like(countbuf)
        for t in range(ntheta):
            if hits[t] > 1:
                members = np.where(binned == t)
                centre = totals[t] / hits[t]
                contrib[members] = np.abs(bispecbuf[members] - centre)**2 / (hits[t]*(hits[t]-1))
        sem = np.sqrt(np.bincount(binned, weights=contrib, minlength=ntheta))
        stderr[:, i, j] = sem
        if mirror:
            stderr[:, j, i] = sem


@nb.njit(parallel=True, cache=True)
def _compute_point3D(k1ind, k2ind, kcoords, ntheta, nk1, nk2, shape,
                     samp, count, bispecbuf, binormbuf,
                     cthetabuf, countbuf, *ffts):
    """
    Fill the per-sample buffers for one (k1, k2) bin pair in 3D.

    For each of the `count` entries of `samp`, a pair of flat indices
    (one from `k1ind`, one from `k2ind`) is decoded, the wavevectors
    k1, k2 and k3 = k1 + k2 are looked up, and the triple product
    ``F1(k1) * F2(k2) * conj(F3(k3))`` is stored in `bispecbuf` with its
    magnitude in `binormbuf` and a 1 in `countbuf`. With ``ntheta > 1``
    the cosine of the angle between k1 and k2 goes to `cthetabuf`.

    Samples whose k3 falls outside the grid are skipped, leaving their
    buffer entries untouched (the caller is expected to pre-zero them).
    With one FFT the same array is used for all three factors.
    """
    kx, ky, kz = kcoords[0], kcoords[1], kcoords[2]
    Nx, Ny, Nz = shape[0], shape[1], shape[2]
    nffts = len(ffts)
    fft1, fft2, fft3 = [ffts[0], ffts[0], ffts[0]] if nffts == 1 else ffts
    for idx in nb.prange(count):
        # samp encodes the pair as (k1 position) + nk1 * (k2 position).
        n, m = k1ind[samp[idx] % nk1], k2ind[samp[idx] // nk1]
        k1x, k1y, k1z = kx[n], ky[n], kz[n]
        k2x, k2y, k2z = kx[m], ky[m], kz[m]
        k3x, k3y, k3z = k1x+k2x, k1y+k2y, k1z+k2z
        # Skip triangles whose third side is beyond half the grid size.
        if np.abs(k3x) > Nx//2 or np.abs(k3y) > Ny//2 or np.abs(k3z) > Nz//2:
            continue
        # For a negative last-axis component the value is read as the
        # conjugate at the negated wavevector (presumably because the
        # FFT arrays hold only the non-negative half of that axis).
        s1 = fft1[k1x, k1y, k1z] if k1z >= 0 \
            else np.conj(fft1[-k1x, -k1y, -k1z])
        s2 = fft2[k2x, k2y, k2z] if k2z >= 0 \
            else np.conj(fft2[-k2x, -k2y, -k2z])
        # s3 is conj(F3(k3)); for negative k3z that equals F3(-k3).
        s3 = np.conj(fft3[k3x, k3y, k3z]) if k3z >= 0 \
            else fft3[-k3x, -k3y, -k3z]
        sample = s1*s2*s3
        norm = np.abs(sample)
        bispecbuf[idx] = sample
        binormbuf[idx] = norm
        countbuf[idx] = 1
        if ntheta > 1:
            # cos(theta) = (k1 . k2) / (|k1| |k2|)
            k1dotk2 = k1x*k2x+k1y*k2y+k1z*k2z
            k1norm, k2norm = np.sqrt(k1x**2+k1y**2+k1z**2), np.sqrt(k2x**2+k2y**2+k2z**2)
            costheta = k1dotk2 / (k1norm*k2norm)
            cthetabuf[idx] = costheta


@nb.njit(parallel=True, cache=True)
def _compute_point2D(k1ind, k2ind, kcoords, ntheta, nk1, nk2, shape,
                     samp, count, bispecbuf, binormbuf,
                     cthetabuf, countbuf, *ffts):
    """
    Fill the per-sample buffers for one (k1, k2) bin pair in 2D.

    For each of the `count` entries of `samp`, a pair of flat indices
    (one from `k1ind`, one from `k2ind`) is decoded, the wavevectors
    k1, k2 and k3 = k1 + k2 are looked up, and the triple product
    ``F1(k1) * F2(k2) * conj(F3(k3))`` is stored in `bispecbuf` with its
    magnitude in `binormbuf` and a 1 in `countbuf`. With ``ntheta > 1``
    the cosine of the angle between k1 and k2 goes to `cthetabuf`.

    Samples whose k3 falls outside the grid are skipped, leaving their
    buffer entries untouched (the caller is expected to pre-zero them).
    With one FFT the same array is used for all three factors.
    """
    kx, ky = kcoords[0], kcoords[1]
    Nx, Ny = shape[0], shape[1]
    nffts = len(ffts)
    fft1, fft2, fft3 = [ffts[0], ffts[0], ffts[0]] if nffts == 1 else ffts
    for idx in nb.prange(count):
        # samp encodes the pair as (k1 position) + nk1 * (k2 position).
        n, m = k1ind[samp[idx] % nk1], k2ind[samp[idx] // nk1]
        k1x, k1y = kx[n], ky[n]
        k2x, k2y = kx[m], ky[m]
        k3x, k3y = k1x+k2x, k1y+k2y
        # Skip triangles whose third side is beyond half the grid size.
        if np.abs(k3x) > Nx//2 or np.abs(k3y) > Ny//2:
            continue
        # For a negative last-axis component the value is read as the
        # conjugate at the negated wavevector (presumably because the
        # FFT arrays hold only the non-negative half of that axis).
        s1 = fft1[k1x, k1y] if k1y >= 0 else np.conj(fft1[-k1x, -k1y])
        s2 = fft2[k2x, k2y] if k2y >= 0 else np.conj(fft2[-k2x, -k2y])
        # s3 is conj(F3(k3)); for negative k3y that equals F3(-k3).
        s3 = np.conj(fft3[k3x, k3y]) if k3y >= 0 else fft3[-k3x, -k3y]
        sample = s1*s2*s3
        norm = np.abs(sample)
        bispecbuf[idx] = sample
        binormbuf[idx] = norm
        countbuf[idx] = 1
        if ntheta > 1:
            # cos(theta) = (k1 . k2) / (|k1| |k2|)
            k1dotk2 = k1x*k2x+k1y*k2y
            k1norm, k2norm = np.sqrt(k1x**2+k1y**2), np.sqrt(k2x**2+k2y**2)
            costheta = k1dotk2 / (k1norm*k2norm)
            cthetabuf[idx] = costheta


@nb.jit(forceobj=True, cache=True)
def _printProgressBar(iteration, total, prefix='', suffix='', decimals=1,
                      length=50, fill='█', printEnd="\r"):
    """
    Call in a loop to create terminal progress bar
    Adapted from
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console

    Parameters
    ----------
    iteration : int
        Current step.
    total : int
        Value of `iteration` at which the bar is complete. A total of
        zero is drawn as a full, 100% bar instead of raising
        ZeroDivisionError.
    prefix : str, optional
        Text before the bar; defaults to ``"(iteration/total)"``.
    suffix : str, optional
        Text after the percentage.
    decimals : int, optional
        Number of decimals shown in the percentage.
    length : int, optional
        Width of the bar in characters.
    fill : str, optional
        Character used for the filled part of the bar.
    printEnd : str, optional
        Line terminator passed to ``print``.
    """
    if total == 0:
        # A loop with a single step reports (0, 0): nothing remains,
        # so show a completed bar rather than dividing by zero.
        fraction = 1.0
        filledLength = length
    else:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    prefix = '(%d/%d)' % (iteration, total) if prefix == '' else prefix
    percent = str("%."+str(decimals)+"f") % (100 * fraction)
    bar = fill * filledLength + '-' * (length - filledLength)
    prog = '\r%s |%s| %s%s %s' % (prefix, bar, percent, '%', suffix)
    print(prog, end=printEnd, flush=True)
    # Move to a fresh line once the final step has been drawn.
    if iteration == total:
        print()

Is this a correct implementation, or is there a better, more optimized method to calculate the Bispectrum score (BSS) for an audio data fragment array?

Thank you!

Aaditya Ura
  • 10,695
  • 7
  • 44
  • 70

0 Answers0