Why is the Gabor wavelet not implemented in CWT? #637

Open
kaz0120 opened this issue Jun 8, 2022 · 2 comments

kaz0120 commented Jun 8, 2022

Hi there, and sorry in advance for my poor English.
I want to apply a CWT to my signal data in Python, and I found this library, PyWavelets.
I used the Morlet mother wavelet and got great results. Next, I found the improved Gabor wavelet.
The improved Gabor wavelet (IGW) makes the relationship between the time and frequency domains easy to understand, so I wanted to use it.
I wrote the IGW code from scratch and it works, but computing the coefficients is far too slow. So I referred to pywt.cwt and rewrote the code as below. However, I can't work out how to implement the calculation, because the IGW does not use a scale parameter. Why does pywt.cwt compute the integral of the wavelet? Is it a faster method? Where can I see its algorithm? Please help me...

```python
import math
import numbers
from typing import List, Optional, Tuple, Union

import numpy as np
import pywt
import scipy.fft
from pywt import ContinuousWavelet


class ExpandingWavelet(object):
    fftmodule = scipy.fft
    next_fast_len = staticmethod(scipy.fft.next_fast_len)

    def __new__(cls, name: str, *args, **kwargs):
        if "gabor" not in name.lower():
            # any other name: return PyWavelets' own ContinuousWavelet
            # (__init__ below is skipped because the result is not an ExpandingWavelet)
            return ContinuousWavelet(name)
        return super().__new__(cls)

    def __init__(self, name: str):
        # "gabor" -> default sigma 0.8909, "gabor0.5" -> sigma 0.5
        sigma = name.lower().rsplit("gabor")[-1]
        self._gabor_sigma = float(sigma) if sigma else 0.8909
        self.complex_cwt = True
        self.upper_bound = 8.0
        self.lower_bound = -8.0

    def gabor_wavefun(
        self,
        level: int = 8,
        length: Optional[int] = None,
        f: numbers.Number = 1
    ) -> Tuple[np.ndarray, np.ndarray]:
        """ improved Gabor wavelet: Ji Z., Yan S., Bao J., "An improved Gabor wavelet and
        its complete transforms," 2015 IEEE ICSPCC, 2015, DOI: 10.1109/ICSPCC.2015.7338925
        """
        if not isinstance(level, int):
            raise TypeError("``level`` must be an integer, not {}".format(type(level)))
        maxlen = 2**level
        if length is None:
            length = maxlen
        x = np.linspace(self.lower_bound, self.upper_bound, maxlen)
        sigma = self._gabor_sigma
        _coef = 1/(sigma * np.sqrt(2 * np.pi))
        # vectorized over x instead of a Python list comprehension
        psi = _coef * np.exp(-(f*x)**2/(2 * sigma**2) + 2j * np.pi * (f*x))
        if not self.complex_cwt:
            psi = psi.real
        if maxlen == length:
            return (psi, x)
        # subsample; the last valid index is maxlen - 1
        _idxs = np.linspace(0, maxlen - 1, length).astype(int)
        return (psi[_idxs], x[_idxs])

    wavefun = gabor_wavefun

    def integrate(self, precision: int = 8) -> Tuple[np.ndarray, np.ndarray]:
        """ call ``pywt.integrate_wavelet`` on the sampled (psi, x) pair
        """
        return pywt.integrate_wavelet(self.wavefun(level=precision), precision=precision)

    def cwt(
        self,
        data: np.ndarray,
        freqs: Union[numbers.Number, List[numbers.Number], np.ndarray],
        precision: int = 10,
        sampling_period: float = 1,
        method: str = "conv",
        axis: int = -1
    ) -> Tuple[np.ndarray, ...]:
        """ Continuous Wavelet Transform
        """
        data = np.asarray(data)
        dtype_data = np.result_type(data.dtype, np.float32)  # promote integer input to float
        data = data.astype(dtype_data, copy=False)
        dtype_cplx = np.result_type(dtype_data, np.complex64)
        freqs = np.atleast_1d(np.asarray(freqs, dtype=float))  # scalar or list -> 1-D array

        # explicit checks instead of `assert ..., exec(...)`, which `python -O` strips out
        if data.ndim != 1:
            raise ValueError(f"Invalid shape for ``data``. Only 1-D data is accepted, not {data.ndim=}")
        if freqs.min() <= 0:
            raise ValueError("Invalid value for ``freqs``. ``freqs`` must be greater than 0 (freqs > 0).")
        if not isinstance(axis, int):
            raise ValueError(f"Invalid value for ``axis``. An integer is required, not {axis=}")
        if not isinstance(precision, int):
            raise ValueError(f"Invalid value for ``precision``. An integer is required, not {precision=}")
        if method not in ("fft", "conv"):
            raise ValueError(f"Invalid value for ``method``. Only `fft` or `conv` is supported, not {method=}")

        dtype_out = dtype_cplx if self.complex_cwt else dtype_data
        out = np.empty((freqs.size,) + data.shape, dtype=dtype_out)

        (integ_psi, x) = self.integrate(precision=precision)
        integ_psi = np.conj(integ_psi) if self.complex_cwt else integ_psi

        # convert integ_psi, x to the same precision as the data
        dtype_psi = dtype_cplx if integ_psi.dtype.kind == "c" else dtype_data
        integ_psi = np.asarray(integ_psi, dtype=dtype_psi)
        x = np.asarray(x, dtype=data.real.dtype)
        x_scale = x[-1] - x[0]
        x_step = x[1] - x[0]

        for i, freq in enumerate(freqs):
            # TODO (the part I can't implement): compute ``coef`` for this ``freq``

            if out.dtype.kind != "c":
                coef = coef.real
            # trim the full convolution back to the length of ``data``
            d = (coef.shape[-1] - data.shape[-1]) / 2.
            if d > 0:
                coef = coef[..., math.floor(d):-math.ceil(d)]
            elif d < 0:
                raise ValueError(f"Selected freq of {freq} is too high.")

            out[i, ...] = coef

        # The improved Gabor wavelet maps time and frequency directly; no scale axis is needed
        frequencies = freqs
        return (out, frequencies)
```
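
One possible way to fill in the missing loop body, written as a standalone function rather than a method. It assumes, as in the later Numba code, that the wavelet at frequency f is f * g(f * (t - tau)), i.e. a scale of 1/f seconds or 1/(f * dt) samples with 1/scale normalization. Under that assumption the pywt.cwt-style "integrate, convolve, difference" loop needs no sqrt(scale) factor, because each difference of the integrated wavelet already equals the integral of f * conj(g(f * (t - tau))) over one sample interval. The names `igw` and `igw_cwt` are hypothetical, and unlike `jit_gabor` below this version conjugates the wavelet and includes dt (it approximates a continuous-time integral).

```python
import math

import numpy as np


def igw(u, sigma=0.8909):
    """Improved Gabor mother wavelet g(u), same formula as gabor_wavefun with f=1."""
    return np.exp(-u**2 / (2 * sigma**2) + 2j * np.pi * u) / (sigma * np.sqrt(2 * np.pi))


def igw_cwt(data, freqs, dt, sigma=0.8909, precision=10, bound=8.0):
    """W[i, k] ~ integral of data(t) * conj(f_i * g(f_i * (t - t_k))) dt (hypothetical helper)."""
    data = np.asarray(data, dtype=float)
    freqs = np.atleast_1d(np.asarray(freqs, dtype=float))
    # Tabulate g once on [-bound, bound] and integrate it with a cumulative sum,
    # conjugated so that the convolution below correlates the data with conj(g).
    x = np.linspace(-bound, bound, 2**precision)
    step = x[1] - x[0]
    int_psi = np.conj(np.cumsum(igw(x, sigma)) * step)
    out = np.empty((freqs.size, data.size), dtype=complex)
    for i, f in enumerate(freqs):
        scale = 1.0 / (f * dt)  # samples per wavelet unit (pywt-style scale)
        # table indices sampled every 1/scale wavelet units (astype(int) floors them)
        j = (np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)).astype(int)
        j = j[j < int_psi.size]
        conv = np.convolve(data, int_psi[j][::-1])
        # differencing gives conj(g) integrated over each sample interval, which
        # already carries the 1/scale normalization of f * g(f * t)
        coef = -np.diff(conv)
        d = (coef.size - data.size) / 2.0
        if d < 0:
            raise ValueError(f"Selected freq of {f} is too high for dt={dt}.")
        if d > 0:
            coef = coef[math.floor(d):coef.size - math.ceil(d)]
        out[i] = coef
    return out
```

As a sanity check, for a unit-amplitude cosine at f0 the magnitude |W(f0, tau)| should come out close to 0.5 away from the signal edges, while rows at frequencies far from f0 stay near zero.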

kaz0120 commented Jun 24, 2022

Here is my scratch code using Numba. Could you help me make it faster?
The improved Gabor wavelet does not require any scale transformation, which makes it easy to use and understand.
But without Numba, it takes far too long...
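
For reference, the wavelet family in `jit_gabor_t` / `jit_gabor_f_tau_t` can be rewritten in terms of an ordinary CWT scale (read directly off the code, assuming f > 0; W is written with the conjugate as in pywt.cwt, which gives the same magnitude as the unconjugated sum for real signals):

$$
g(t) = \frac{1}{\sigma\sqrt{2\pi}}\, e^{-t^2/(2\sigma^2)}\, e^{2\pi i t},
\qquad
\psi_{f,\tau}(t) = |f|\, g\big(f\,(t-\tau)\big)
$$

$$
a = \frac{1}{f}
\quad\Longrightarrow\quad
\psi_{f,\tau}(t) = \frac{1}{a}\, g\!\left(\frac{t-\tau}{a}\right),
\qquad
W(f,\tau) = \int x(t)\, \psi_{f,\tau}^{*}(t)\, dt
$$

So the scale is still present, just fixed to a = 1/f (g oscillates at 1 cycle per unit of t), and the normalization is 1/a instead of the 1/sqrt(a) used by pywt.cwt. Because the Gaussian envelope of each ψ_{f,τ} then integrates to 1, a real sinusoid A cos(2π f0 t) gives |W(f0, τ)| ≈ A/2 away from the edges for any f0 (with σ = 0.8909 the negative-frequency term is negligible).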

```python
import numpy as np
from numba import complex128, njit, prange


@njit("c16(f8, f8)", fastmath=False)
def jit_gabor_t(
    t: float,
    sigma: float = 0.8909
) -> complex:
    # improved Gabor mother wavelet g(t)
    return 1/(sigma * np.sqrt(2 * np.pi)) * np.exp( -t**2/(2 * sigma**2) + 2 * np.pi * 1j * t )


@njit("c16(f8, f8, f8)", fastmath=False)
def jit_gabor_f_tau_t(
    f: float,
    tau: float,
    t: float
) -> complex:
    # wavelet at frequency f centred at tau: |f| * g(f * (t - tau))
    return abs(f) * jit_gabor_t(t=f * (t - tau), sigma=0.8909)


@njit("f8(f8, f8, f8)", fastmath=False)
def jit_calc_wavelet_window_width(
    f: float,
    sigma: float,
    amp: float = 0.005,
) -> float:
    # half-width in time where the Gaussian envelope is >= amp (relative to its peak)
    return 1/f * sigma * np.sqrt(-2 * np.log(amp))


@njit("c16(f8[:], f8[:], f8, f8)", fastmath=False)
def jit_gabor_J(
    sig: np.ndarray,
    time: np.ndarray,
    f: float,
    tau: float
) -> complex:
    # index of the sample closest to tau
    tmp = time - tau
    for i in prange(len(tmp)):
        tmp[i] = abs(tmp[i])
    tauidx = np.argmin(tmp)
    # sum only over the samples inside the wavelet window
    wavelet_window_width = jit_calc_wavelet_window_width(f=f, sigma=0.8909, amp=0.005)
    _indexes = np.where(  (time[tauidx] - wavelet_window_width <= time)
                        & (time <= time[tauidx] + wavelet_window_width) )
    idxs = np.arange(len(time))[_indexes]
    min_, max_ = idxs.min(), idxs.max()
    out = 0. + 0.j
    for i in range(min_, max_+1):
        out += sig[i] * jit_gabor_f_tau_t(f=f, tau=tau, t=time[i])
    return out


@njit("c16[:,:](f8[:], f8[:], f8[:])")
def jit_gabor(
    time: np.ndarray,
    sig: np.ndarray,
    freqs: np.ndarray
) -> np.ndarray:
    # direct summation: one windowed sum per (frequency, time) pair
    N_t = len(time)
    N_f = len(freqs)
    out = np.zeros((N_f, N_t), dtype=complex128)
    for fi in range(N_f):
        for ti in range(N_t):
            out[fi,ti] = jit_gabor_J(sig=sig, time=time, f=freqs[fi], tau=time[ti])
    return out
```

Sample:

```python
import time

import matplotlib.pyplot as plt
import numpy as np


N_f = 20
dt = 0.01  # sampling interval [s]
t = np.arange(-1, 1, dt)  # 200 samples
sig = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2) * np.exp(1j*2*np.pi*2*(t-0.4)))
freqs = np.arange(1, N_f).astype(float)  # 1, 2, ..., 19 Hz

start = time.time()
res = jit_gabor(time=t, sig=sig, freqs=freqs)
fin = time.time() - start
print(fin, "sec")

fig, ax = plt.subplots()
# extent maps image columns/rows to time [s] and frequency [Hz]
ax.imshow(np.abs(res), aspect='auto', origin='lower',
          extent=(t[0], t[-1], freqs[0], freqs[-1]))
ax.twinx().plot(t, sig)
ax.set(xlabel="Time [sec]", ylabel="Frequency [Hz]")
plt.show()
```
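
Because the time grid is uniform, the inner sum of `jit_gabor` is a correlation of the signal with one sampled, truncated wavelet per frequency, so each row can be computed with a single `np.convolve` call in plain NumPy. A sketch assuming a uniform sampling interval `dt`; `igw_kernel` and `gabor_vectorized` are hypothetical names, and the window cut-off mirrors `jit_calc_wavelet_window_width` up to floating-point rounding at the window edges. Like `jit_gabor`, it uses the unconjugated wavelet and no dt factor.

```python
import numpy as np


def igw_kernel(f, dt, sigma=0.8909, amp=0.005):
    """Samples f * g(f * k * dt) for integers k with |k * dt| <= the window half-width."""
    # same half-width rule as jit_calc_wavelet_window_width
    width = sigma * np.sqrt(-2 * np.log(amp)) / f
    m = int(np.floor(width / dt))
    u = f * np.arange(-m, m + 1) * dt
    g = np.exp(-u**2 / (2 * sigma**2) + 2j * np.pi * u) / (sigma * np.sqrt(2 * np.pi))
    return f * g, m


def gabor_vectorized(sig, freqs, dt, sigma=0.8909, amp=0.005):
    """out[i, k] = sum_n sig[n] * f_i * g(f_i * (t_n - t_k)), like jit_gabor on a uniform grid."""
    sig = np.asarray(sig, dtype=float)
    out = np.empty((len(freqs), sig.size), dtype=complex)
    for i, f in enumerate(freqs):
        kern, m = igw_kernel(f, dt, sigma, amp)
        # correlating with kern == convolving with kern reversed; the full
        # convolution is shifted by m samples relative to the input
        full = np.convolve(sig, kern[::-1])
        out[i] = full[m:m + sig.size]
    return out
```

With the sample signal above, `gabor_vectorized(sig, freqs, dt)` should return an array of the same shape as `res`, (19, 200), with closely matching values. For long signals, `scipy.signal.fftconvolve(sig, kern[::-1])` can replace the `np.convolve` call, since both return the full convolution.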

grlee77 commented Jun 29, 2022

Hi @kaz0120, unfortunately the original CWT contributor to PyWavelets is no longer active with the project and I am more familiar with the discrete transforms. In general it is mostly @rgommers and myself doing basic maintenance of the library at the moment, but neither of us has the bandwidth to develop new features.

That said, I have looked at the code briefly in the past and did some research regarding your question about the integral. There is a summary with links to more info in this comment:
#531 (comment)
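
The role of the integral can be sketched as follows. This is a reading of the loop in pywt.cwt, not official documentation, with the scale s measured in samples and a unit sampling interval. With Ψ the running integral of the conjugated wavelet, the chain rule turns the CWT kernel into a time derivative of Ψ:

$$
\Psi(u) = \int_{-\infty}^{u} \psi^{*}(v)\, dv,
\qquad
\frac{1}{\sqrt{s}}\,\psi^{*}\!\left(\frac{t-b}{s}\right) = \sqrt{s}\,\frac{d}{dt}\,\Psi\!\left(\frac{t-b}{s}\right)
$$

Treating x as constant on each sample interval [n, n+1) and defining F(b) = Σ_n x[n] Ψ((n - b)/s):

$$
W(s,b) = \frac{1}{\sqrt{s}}\int x(t)\,\psi^{*}\!\left(\frac{t-b}{s}\right)dt
\approx \sqrt{s}\sum_{n} x[n]\left[\Psi\!\left(\frac{n+1-b}{s}\right)-\Psi\!\left(\frac{n-b}{s}\right)\right]
= -\sqrt{s}\,\big(F(b)-F(b-1)\big)
$$

F is a convolution of the data with the resampled Ψ table, and F(b) - F(b-1) is `np.diff` of that convolution, which corresponds to the `-np.sqrt(scale) * np.diff(conv)` step in pywt.cwt. As far as one can tell from this, the benefit is accuracy and reuse rather than raw speed: a single table of Ψ, computed once at 2**precision points, serves every scale by re-indexing, and each difference is the wavelet integrated over a whole sample interval instead of a point sample, which matters at small scales where only a few samples cover the wavelet. The cost per scale is still one convolution.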
