I am trying to calculate the bispectrum score (BSS) for an array of audio data frames; the definition of this feature can be found here:
And the formulas can be found here:
The implementation I am using is:
import numpy as np
import numba as nb
from time import time
def bispectrum(*u, ntheta=None, kmin=None, kmax=None,
               diagnostics=True, error=False,
               nsamples=None, sample_thresh=None,
               compute_fft=True, exclude_upper=False, use_pyfftw=False,
               bench=False, progress=False, **kwargs):
    """
    Estimate the binned bispectrum and bicoherence of 2D or 3D data.

    Parameters
    ----------
    *u : np.ndarray
        Either one scalar field or three vector components. The shape and
        dimensionality are taken from the first array, which must be 2D
        or 3D.
    ntheta : int, optional
        Number of angle bins on [0, pi). When given, the outputs gain a
        leading angle axis and ``theta`` is added to the result.
    kmin, kmax : int, optional
        Smallest and largest wavenumber bin. Default to 1 and
        ``max(shape)/2``.
    diagnostics : bool, optional
        Append ``counts`` and ``omega`` to the result.
    error : bool, optional
        Append the standard error of the mean to the result.
    nsamples : int, float or np.ndarray, optional
        Sampling setting per (k1, k2) bin, broadcast to a (dim, dim) array
        when scalar. Integers are sample counts; other values are scaled
        by the number of available pairs in the kernel. ``None`` disables
        sampling.
    sample_thresh : int, optional
        Threshold forwarded to the kernel's sampling decision. Defaults to
        the largest int64; it is forced to that value when ``nsamples`` is
        None.
    compute_fft : bool, optional
        If False, the inputs are used as already-transformed data and are
        only truncated to ``shape[-1]//2+1`` along the last axis.
    exclude_upper : bool, optional
        Forwarded to the kernel as its ``exclude`` flag.
    use_pyfftw : bool, optional
        Transform with ``_fftn`` (pyfftw) instead of ``np.fft.rfftn``.
    bench : bool, optional
        Print the elapsed time.
    progress : bool, optional
        Forwarded to the kernel to show a progress bar.
    **kwargs
        Forwarded to the FFT routine.

    Returns
    -------
    tuple
        ``(B, b, kn)`` followed, in this order, by ``theta`` if ``ntheta``
        is given, ``counts`` and ``omega`` if ``diagnostics`` is set, and
        ``stderr`` if ``error`` is set. ``B`` is the accumulated bispectrum
        divided by the sample counts and ``b`` is ``|sum| / norm``.

    Raises
    ------
    ValueError
        If the number of input fields is not 1 or 3, or the data is not
        2D or 3D.
    """
    # Validate before touching u[0], so a call without any field raises
    # the intended ValueError rather than an IndexError.
    ncomp = len(u)
    if ncomp not in [1, 3]:
        raise ValueError("Pass either 1 scalar field or 3 vector components.")
    ndim = u[0].ndim
    if ndim not in [2, 3]:
        raise ValueError("Data must be 2D or 3D.")
    shape = nb.typed.List(u[0].shape)

    # Geometry of output image
    kmax = int(max(shape)/2) if kmax is None else int(kmax)
    kmin = 1 if kmin is None else int(kmin)
    kn = np.arange(kmin, kmax+1, 1, dtype=int)
    dim = kn.size
    if ntheta is not None:
        # linspace guarantees exactly ntheta angles; arange with a float
        # step can yield one extra element through rounding.
        theta = np.linspace(0, np.pi, int(ntheta), endpoint=False)
        # ...make costheta monotonically increase
        costheta = np.flip(np.cos(theta))
        # theta = 0 should be included
        costheta[-1] += 1e-5
    else:
        theta = None
        costheta = np.array([1.])

    if bench:
        t0 = time()

    # Get binned radial coordinates of FFT
    kv = np.meshgrid(*([np.fft.fftfreq(Ni).astype(np.float32)*Ni
                        for Ni in shape]), indexing="ij")
    kr = np.zeros_like(kv[0])
    for i in range(ndim):
        kr[...] += kv[i]**2
    kr[...] = np.sqrt(kr)
    kcoords = nb.typed.List()
    for i in range(ndim):
        temp = kv[i].astype(np.int16).ravel()
        kcoords.append(temp)
    del kv, temp
    kbins = np.arange(int(np.ceil(kr.max())))
    kbinned = (np.digitize(kr, kbins)-1).astype(np.int16)
    del kr

    # Enumerate indices in each bin; the k2 list is restricted to the
    # first shape[-1]//2+1 entries of the last axis.
    k1bins, k2bins = nb.typed.List(), nb.typed.List()
    for ki in kn:
        mask = kbinned == ki
        temp1 = np.where(mask)
        temp2 = np.where(mask[..., :shape[-1]//2+1])
        k1bins.append(np.ravel_multi_index(temp1, shape))
        k2bins.append(np.ravel_multi_index(temp2, shape))
    del kbinned

    # FFT
    ffts = []
    for i in range(ncomp):
        if compute_fft:
            temp = u[i]
            if use_pyfftw:
                fft = _fftn(temp, **kwargs)
            else:
                fft = np.fft.rfftn(temp, **kwargs)
            del temp
        else:
            fft = u[i][..., :shape[-1]//2+1]
        ffts.append(fft)
        del fft

    # Sampling settings
    int64_max = np.iinfo(np.int64).max
    if nsamples is None:
        nsamples = int64_max
        sample_thresh = int64_max
    elif sample_thresh is None:
        sample_thresh = int64_max

    # Sampling mask
    if np.issubdtype(type(nsamples), np.integer):
        nsamples = np.full((dim, dim), nsamples, dtype=np.int_)
    elif np.issubdtype(type(nsamples), np.floating):
        nsamples = np.full((dim, dim), nsamples)
    elif type(nsamples) is np.ndarray:
        if np.issubdtype(nsamples.dtype, np.integer):
            nsamples = nsamples.astype(np.int_)

    # Run main loop; ndim was validated above, so it is either 2 or 3.
    compute_point = _compute_point2D if ndim == 2 else _compute_point3D
    args = (k1bins, k2bins, kn, costheta, kcoords,
            nsamples, sample_thresh, ndim, dim, shape,
            progress, exclude_upper, error, compute_point, *ffts)
    B, norm, omega, counts, stderr = _compute_bispectrum(*args)

    # Set zero values to nan values for division
    mask = counts == 0.
    norm[mask] = np.nan
    counts[mask] = np.nan

    # Get bicoherence and average bispectrum
    b = np.abs(B) / norm
    B.real /= counts
    B.imag /= counts

    # Prepare diagnostics
    if error:
        stderr[counts <= 1.] = np.nan

    # Switch back to theta monotonically increasing
    if ntheta is not None:
        B[...] = np.flip(B, axis=0)
        b[...] = np.flip(b, axis=0)
        if diagnostics:
            counts[...] = np.flip(counts, axis=0)
        if error:
            stderr[...] = np.flip(stderr, axis=0)
    else:
        B, b = B[0], b[0]
        if diagnostics:
            counts = counts[0]
        if error:
            stderr = stderr[0]

    if bench:
        print(f"Time: {time() - t0:.04f} s")

    result = [B, b, kn]
    if ntheta is not None:
        result.append(theta)
    if diagnostics:
        result.extend([counts, omega])
    if error:
        result.append(stderr)
    return tuple(result)
def _fftn(image, overwrite_input=False, threads=-1, **kwargs):
    """
    Calculate N-dimensional fft of image with pyfftw.
    See pyfftw.builders.fftn for kwargs documentation.

    Parameters
    ----------
    image : np.ndarray
        Real (float32/float64) or complex (complex64/complex128) image.
        Real input is transformed with pyfftw's ``rfftn`` builder,
        complex input with its ``fftn`` builder.
    overwrite_input : bool, optional
        Specify whether input data can be destroyed.
        This is useful for reducing memory usage.
        See pyfftw.builders.fftn for more.
    threads : int, optional
        Thread count handed to the pyfftw builder unchanged.
    **kwargs
        Forwarded to the pyfftw builder.

    Returns
    -------
    fft : np.ndarray
        The transform produced by the pyfftw builder.

    Raises
    ------
    ValueError
        If ``image`` has a dtype other than the four listed above.
    """
    # Classify the dtype before importing pyfftw so unsupported input
    # fails fast with a ValueError that names the offending dtype.
    if image.dtype in [np.complex64, np.complex128]:
        dtype, real_input = 'complex128', False
    elif image.dtype in [np.float32, np.float64]:
        dtype, real_input = 'float64', True
    else:
        raise ValueError(f"{image.dtype} is unrecognized data type.")

    import pyfftw
    fftn = pyfftw.builders.rfftn if real_input else pyfftw.builders.fftn

    # Plan on an aligned work buffer first, then copy the data in.
    a = pyfftw.empty_aligned(image.shape, dtype=dtype)
    f = fftn(a, threads=threads, overwrite_input=overwrite_input, **kwargs)
    a[...] = image
    fft = f()
    del a, fftn
    return fft
@nb.njit(parallel=True)
def _compute_bispectrum(k1bins, k2bins, kn, costheta, kcoords, nsamples,
                        sample_thresh, ndim, dim, shape, progress,
                        exclude, error, compute_point, *ffts):
    """
    Accumulate bispectrum sums over every (k1, k2) wavenumber bin pair.

    For each pair of bins the point kernel ``compute_point`` fills
    per-sample buffers, which are then reduced into the output arrays by
    ``_fill_sum`` (single angle bin) or ``_fill_binned_sum`` (several
    angle bins).

    Parameters
    ----------
    k1bins, k2bins : sequence of index arrays
        Flattened grid indices belonging to each wavenumber bin.
    kn : array
        Wavenumber of each bin.
    costheta : array
        Angle-bin edges in cos(theta); its size sets the number of angle
        bins.
    kcoords : sequence of arrays
        Wavevector components, forwarded to ``compute_point``.
    nsamples : 2D array
        Sampling setting for each (i, j) bin pair.
    sample_thresh : int
        Threshold used in the sampling decision below.
    ndim : int
        Not used inside this function.
    dim : int
        Number of wavenumber bins.
    shape : sequence of int
        Grid shape; its maximum defines ``knyq``.
    progress : bool
        Print a progress bar after each outer iteration.
    exclude : bool
        With a single angle bin, skip pairs with ``k1 + k2 > knyq``.
    error : bool
        Allocate and fill the standard-error array.
    compute_point : callable
        Point kernel that fills the per-sample buffers.
    *ffts : arrays
        One or three transformed fields.

    Returns
    -------
    bispec, binorm, omega, counts, stderr
        Summed bispectrum, summed normalisation, number of available
        index pairs per bin pair, summed counts, and standard error.
        ``stderr`` is a (1, 1, 1) placeholder when ``error`` is False.
    """
    knyq = max(shape) // 2
    ntheta = costheta.size
    nffts = len(ffts)
    # Outputs start as NaN so bin pairs that are never visited stay marked.
    bispec = np.full((ntheta, dim, dim), np.nan+1.j*np.nan, dtype=np.complex128)
    binorm = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    counts = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    omega = np.zeros((dim, dim), dtype=np.int64)
    if error:
        stderr = np.full((ntheta, dim, dim), np.nan, dtype=np.float64)
    else:
        stderr = np.zeros((1, 1, 1), dtype=np.float64)
    for i in range(dim):
        k1 = kn[i]
        k1ind = k1bins[i]
        nk1 = k1ind.size
        # With a single field only j <= i is visited; the fill helpers
        # write the mirrored [j, i] entry in that case.
        dim2 = dim if nffts > 1 else i+1
        for j in range(dim2):
            k2 = kn[j]
            if ntheta == 1 and (exclude and k1 + k2 > knyq):
                continue
            k2ind = k2bins[j]
            nk2 = k2ind.size
            # Integer settings are used as sample counts; anything else is
            # scaled by the number of available pairs (at least 1).
            nsamp = nsamples[i, j]
            nsamp = int(nsamp) if type(nsamp) is np.int64 \
                else max(int(nsamp*nk1*nk2), 1)
            if nsamp < nk1*nk2 or nsamp > sample_thresh:
                # Random flattened pair indices in [0, nk1*nk2).
                samp = np.random.randint(0, nk1*nk2, size=nsamp)
                count = nsamp
            else:
                # Enumerate every pair exactly once.
                samp = np.arange(nk1*nk2)
                count = nk1*nk2
            # Zero-initialised per-sample buffers filled by compute_point.
            bispecbuf = np.zeros(count, dtype=np.complex128)
            binormbuf = np.zeros(count, dtype=np.float64)
            cthetabuf = np.zeros(count, dtype=np.float64) if ntheta > 1 \
                else np.array([0.], dtype=np.float64)
            countbuf = np.zeros(count, dtype=np.float64)
            compute_point(k1ind, k2ind, kcoords, ntheta,
                          nk1, nk2, shape, samp, count,
                          bispecbuf, binormbuf, cthetabuf, countbuf,
                          *ffts)
            if ntheta == 1:
                _fill_sum(i, j, bispec, binorm, counts, stderr,
                          bispecbuf, binormbuf, countbuf, nffts, error)
            else:
                # Assign each sample to an angle bin via its cos(theta).
                binned = np.searchsorted(costheta, cthetabuf)
                _fill_binned_sum(i, j, ntheta, binned, bispec, binorm,
                                 counts, stderr, bispecbuf, binormbuf,
                                 countbuf, nffts, error)
            omega[i, j] = nk1*nk2
            if nffts == 1:
                omega[j, i] = nk1*nk2
        if progress:
            # Printing is done in object mode, outside compiled code.
            with nb.objmode():
                _printProgressBar(i, dim-1)
    return bispec, binorm, omega, counts, stderr
@nb.njit(parallel=True, cache=True)
def _fill_sum(i, j, bispec, binorm, counts, stderr,
              bispecbuf, binormbuf, countbuf, nffts, error):
    """
    Reduce the per-sample buffers into angle bin 0 of the output arrays.

    Writes the summed bispectrum, normalisation and count at ``[0, i, j]``
    and, when there is a single field (``nffts == 1``), also at the
    mirrored ``[0, j, i]``. If ``error`` is set and more than one sample
    was counted, the standard error of the mean is stored the same way.
    """
    mirror = nffts == 1

    total = bispecbuf.sum()
    weight = binormbuf.sum()
    nvalid = countbuf.sum()

    bispec[0, i, j] = total
    binorm[0, i, j] = weight
    counts[0, i, j] = nvalid
    if mirror:
        bispec[0, j, i] = total
        binorm[0, j, i] = weight
        counts[0, j, i] = nvalid

    if error and nvalid > 1:
        # Squared deviation of every buffer entry from the mean sample.
        spread = np.abs(bispecbuf - (total / nvalid))**2
        sem = np.sqrt(spread.sum() / (nvalid*(nvalid - 1)))
        stderr[0, i, j] = sem
        if mirror:
            stderr[0, j, i] = sem
@nb.njit(parallel=True, cache=True)
def _fill_binned_sum(i, j, ntheta, binned, bispec, binorm, counts,
                     stderr, bispecbuf, binormbuf, countbuf, nffts, error):
    """
    Reduce the per-sample buffers into every angle bin of the outputs.

    ``binned`` holds the angle-bin index of each sample. Sums per bin are
    written to ``[:, i, j]`` and, when there is a single field
    (``nffts == 1``), also to the mirrored ``[:, j, i]``. If ``error`` is
    set, a standard error of the mean is stored per angle bin.
    """
    # Weighted bincounts give the per-angle-bin sums of each buffer.
    N = np.bincount(binned, weights=countbuf, minlength=ntheta)
    norm = np.bincount(binned, weights=binormbuf, minlength=ntheta)
    # The complex sum is assembled from separate real and imaginary sums.
    value = np.bincount(binned, weights=bispecbuf.real, minlength=ntheta) +\
        1.j*np.bincount(binned, weights=bispecbuf.imag, minlength=ntheta)
    bispec[:, i, j] = value
    binorm[:, i, j] = norm
    counts[:, i, j] = N
    if nffts == 1:
        bispec[:, j, i] = value
        binorm[:, j, i] = norm
        counts[:, j, i] = N
    if error:
        variance = np.zeros_like(countbuf)
        for n in range(ntheta):
            # Bins with at most one counted sample keep zero variance.
            if N[n] > 1:
                idxs = np.where(binned == n)
                mean = value[n] / N[n]
                # NOTE(review): every buffer entry assigned to bin n
                # contributes here, including entries whose countbuf is
                # 0 - confirm this is intended.
                variance[idxs] = np.abs(bispecbuf[idxs] - mean)**2 / (N[n]*(N[n]-1))
        # Sum the per-sample terms per angle bin, then take the root.
        err = np.sqrt(np.bincount(binned, weights=variance, minlength=ntheta))
        stderr[:, i, j] = err
        if nffts == 1:
            stderr[:, j, i] = err
@nb.njit(parallel=True, cache=True)
def _compute_point3D(k1ind, k2ind, kcoords, ntheta, nk1, nk2, shape,
                     samp, count, bispecbuf, binormbuf,
                     cthetabuf, countbuf, *ffts):
    """
    Fill the per-sample buffers for one (k1, k2) bin pair of 3D data.

    For each of the ``count`` entries of ``samp`` a pair of wavevectors is
    looked up, k3 = k1 + k2 is formed, and the triple product of the
    transformed fields is stored in ``bispecbuf`` with its modulus in
    ``binormbuf`` and a 1 in ``countbuf``. Samples whose k3 exceeds half
    the grid size on any axis are skipped and their buffer entries are
    left untouched. When ``ntheta > 1`` the cosine of the angle between
    k1 and k2 is stored in ``cthetabuf``.
    """
    kx, ky, kz = kcoords[0], kcoords[1], kcoords[2]
    Nx, Ny, Nz = shape[0], shape[1], shape[2]
    nffts = len(ffts)
    # A single field is used for all three factors of the product.
    fft1, fft2, fft3 = [ffts[0], ffts[0], ffts[0]] if nffts == 1 else ffts
    for idx in nb.prange(count):
        # Decode the flattened pair index into one k1 and one k2 grid index.
        n, m = k1ind[samp[idx] % nk1], k2ind[samp[idx] // nk1]
        k1x, k1y, k1z = kx[n], ky[n], kz[n]
        k2x, k2y, k2z = kx[m], ky[m], kz[m]
        k3x, k3y, k3z = k1x+k2x, k1y+k2y, k1z+k2z
        if np.abs(k3x) > Nx//2 or np.abs(k3y) > Ny//2 or np.abs(k3z) > Nz//2:
            continue
        # A negative last-axis component is read from the conjugate of the
        # mirrored mode; presumably the transforms hold only the
        # non-negative half of the last axis - confirm against the caller.
        s1 = fft1[k1x, k1y, k1z] if k1z >= 0 \
            else np.conj(fft1[-k1x, -k1y, -k1z])
        s2 = fft2[k2x, k2y, k2z] if k2z >= 0 \
            else np.conj(fft2[-k2x, -k2y, -k2z])
        # The third factor enters conjugated.
        s3 = np.conj(fft3[k3x, k3y, k3z]) if k3z >= 0 \
            else fft3[-k3x, -k3y, -k3z]
        sample = s1*s2*s3
        norm = np.abs(sample)
        bispecbuf[idx] = sample
        binormbuf[idx] = norm
        countbuf[idx] = 1
        if ntheta > 1:
            # Cosine of the angle between k1 and k2.
            k1dotk2 = k1x*k2x+k1y*k2y+k1z*k2z
            k1norm, k2norm = np.sqrt(k1x**2+k1y**2+k1z**2), np.sqrt(k2x**2+k2y**2+k2z**2)
            costheta = k1dotk2 / (k1norm*k2norm)
            cthetabuf[idx] = costheta
@nb.njit(parallel=True, cache=True)
def _compute_point2D(k1ind, k2ind, kcoords, ntheta, nk1, nk2, shape,
                     samp, count, bispecbuf, binormbuf,
                     cthetabuf, countbuf, *ffts):
    """
    Fill the per-sample buffers for one (k1, k2) bin pair of 2D data.

    For each of the ``count`` entries of ``samp`` a pair of wavevectors is
    looked up, k3 = k1 + k2 is formed, and the triple product of the
    transformed fields is stored in ``bispecbuf`` with its modulus in
    ``binormbuf`` and a 1 in ``countbuf``. Samples whose k3 exceeds half
    the grid size on either axis are skipped and their buffer entries are
    left untouched. When ``ntheta > 1`` the cosine of the angle between
    k1 and k2 is stored in ``cthetabuf``.
    """
    kx, ky = kcoords[0], kcoords[1]
    Nx, Ny = shape[0], shape[1]
    nffts = len(ffts)
    # A single field is used for all three factors of the product.
    fft1, fft2, fft3 = [ffts[0], ffts[0], ffts[0]] if nffts == 1 else ffts
    for idx in nb.prange(count):
        # Decode the flattened pair index into one k1 and one k2 grid index.
        n, m = k1ind[samp[idx] % nk1], k2ind[samp[idx] // nk1]
        k1x, k1y = kx[n], ky[n]
        k2x, k2y = kx[m], ky[m]
        k3x, k3y = k1x+k2x, k1y+k2y
        if np.abs(k3x) > Nx//2 or np.abs(k3y) > Ny//2:
            continue
        # A negative last-axis component is read from the conjugate of the
        # mirrored mode; presumably the transforms hold only the
        # non-negative half of the last axis - confirm against the caller.
        s1 = fft1[k1x, k1y] if k1y >= 0 else np.conj(fft1[-k1x, -k1y])
        s2 = fft2[k2x, k2y] if k2y >= 0 else np.conj(fft2[-k2x, -k2y])
        # The third factor enters conjugated.
        s3 = np.conj(fft3[k3x, k3y]) if k3y >= 0 else fft3[-k3x, -k3y]
        sample = s1*s2*s3
        norm = np.abs(sample)
        bispecbuf[idx] = sample
        binormbuf[idx] = norm
        countbuf[idx] = 1
        if ntheta > 1:
            # Cosine of the angle between k1 and k2.
            k1dotk2 = k1x*k2x+k1y*k2y
            k1norm, k2norm = np.sqrt(k1x**2+k1y**2), np.sqrt(k2x**2+k2y**2)
            costheta = k1dotk2 / (k1norm*k2norm)
            cthetabuf[idx] = costheta
@nb.jit(forceobj=True, cache=True)
def _printProgressBar(iteration, total, prefix='', suffix='', decimals=1,
                      length=50, fill='█', printEnd="\r"):
    """
    Call in a loop to create terminal progress bar
    Adapted from
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console

    Parameters
    ----------
    iteration : int
        Current step.
    total : int
        Final step. A value of 0 is drawn as a completed bar.
    prefix, suffix : str, optional
        Text before and after the bar. An empty prefix is replaced by
        ``(iteration/total)``.
    decimals : int, optional
        Number of decimals shown in the percentage.
    length : int, optional
        Width of the bar in characters.
    fill : str, optional
        Character used for the completed part of the bar.
    printEnd : str, optional
        Line terminator passed to ``print``.
    """
    prefix = '(%d/%d)' % (iteration, total) if prefix == '' else prefix
    if total:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    else:
        # A zero-step job (e.g. a single wavenumber bin, where the caller
        # passes dim-1 == 0) would otherwise divide by zero.
        fraction = 1.0
        filledLength = int(length)
    percent = str("%."+str(decimals)+"f") % (100 * fraction)
    bar = fill * filledLength + '-' * (length - filledLength)
    prog = '\r%s |%s| %s%s %s' % (prefix, bar, percent, '%', suffix)
    print(prog, end=printEnd, flush=True)
    # Finish the line once the last step has been drawn.
    if iteration == total:
        print()
Is this a correct implementation, or is there a better, more optimized way to calculate the bispectrum score (BSS) for an array of audio data fragments?
Thank you!