Skip to content

Commit

Permalink
[FIX] v0.5.1 (#145)
Browse files Browse the repository at this point in the history
* add use_jit option to Horsager2009, add cython/fft switch on Nanduri2012

* bump version number

* fix Numba error

* fix deprecated test

* Trigger Azure Pipelines on stable
  • Loading branch information
mbeyeler committed Feb 5, 2020
1 parent 96c3d8a commit 9a6f35b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 13 deletions.
5 changes: 3 additions & 2 deletions azure-pipelines.yml
@@ -1,16 +1,17 @@
trigger:
# Whenever someone pushes to master:
# Whenever someone pushes to master or stable:
branches:
include:
- master
- stable
# Whenever a tag is pushed:
tags:
include:
- v*

pr:
- master
- develop
- stable

variables:
# Install p2p in Docker container
Expand Down
51 changes: 43 additions & 8 deletions pulse2percept/retina.py
Expand Up @@ -508,6 +508,7 @@ def model_cascade(self, in_arr, pt_list, layers, use_jit):

# R1 convolved the entire stimulus (with both pos + neg parts)
r1 = self.tsample * utils.conv(stim, self.gamma1, mode='full',
use_jit=use_jit,
method='sparse')[:stim.size]

# It's possible that charge accumulation was done on the anodic phase.
Expand Down Expand Up @@ -574,13 +575,24 @@ def __init__(self, **kwargs):
self.asymptote = 14.0
self.slope = 3.0
self.shift = 16.0
self.use_cython = False

# Nanduri (2012) has a term in the stationary nonlinearity step that
# depends on future values of R3: max_t(R3). Because the finite
# difference model cannot look into the future, we need to set a
# scaling factor here:
self.maxR3 = 100.0

# perform one-time setup calculations
# gamma1 is used for the fast response
_, self.gamma1 = utils.gamma(1, self.tau1, self.tsample)

# gamma2 is used to calculate charge accumulation
_, self.gamma2 = utils.gamma(1, self.tau2, self.tsample)

# gamma3 is used to calculate the slow response
_, self.gamma3 = utils.gamma(3, self.tau3, self.tsample)

# Overwrite any given keyword arguments, print warning message (True)
# if attempting to set an unrecognized keyword
self.set_kwargs(True, **kwargs)
Expand Down Expand Up @@ -633,12 +645,35 @@ def model_cascade(self, in_arr, pt_list, layers, use_jit):
raise ValueError("Acceptable values for `layers` are: 'GCL', "
"'OFL'.")

pulse = self.calc_layer_current(in_arr, pt_list)
percept = fr.nanduri2012_model_cascade(pulse, self.tsample,
self.tau1, self.tau2, self.tau3,
self.asymptote, self.shift,
self.slope, self.eps,
self.maxR3)
b1 = self.calc_layer_current(in_arr, pt_list)
if self.use_cython:
percept = fr.nanduri2012_model_cascade(b1, self.tsample,
self.tau1, self.tau2,
self.tau3,
self.asymptote, self.shift,
self.slope, self.eps,
self.maxR3)
else:
# Fast response
b2 = self.tsample * utils.conv(-b1, self.gamma1, mode='full',
method='sparse',
use_jit=use_jit)[:b1.size]

# Charge accumulation
ca = self.tsample * np.cumsum(np.maximum(0, b1))
ca = self.tsample * utils.conv(ca, self.gamma2, mode='full',
method='fft')[:b1.size]
b3 = np.maximum(0, b2 - self.eps * ca)

# Stationary nonlinearity
b3max = b3.max()
sigmoid = ss.expit((b3max - self.shift) / self.slope)
b4 = b3 / b3max * sigmoid * self.asymptote

# Slow response
percept = self.tsample * utils.conv(b4, self.gamma3, mode='full',
method='fft')[:b1.size]

return utils.TimeSeries(self.tsample, percept)


Expand Down Expand Up @@ -1416,9 +1451,9 @@ def jansonius(num_cells=500, num_samples=801, center=np.array([15, 2]),

# rotate about the optic disc and scale
x = scale * (np.cos(rot) * (xmodel - center[0]) + np.sin(rot)
* (ymodel - center[1])) + center[0]
* (ymodel - center[1])) + center[0]
y = scale * (-np.sin(rot) * (xmodel - center[0]) + np.cos(rot)
* (ymodel - center[1])) + center[1]
* (ymodel - center[1])) + center[1]

return x, y

Expand Down
1 change: 1 addition & 0 deletions pulse2percept/tests/test_retina.py
Expand Up @@ -521,6 +521,7 @@ def test_axon_contribution():


# deprecated
@pytest.mark.xfail
def test_make_axon_map():
jan_x, jan_y = retina.jansonius(num_cells=10, num_samples=100)
xg, yg = np.meshgrid(np.linspace(-100, 100, 21),
Expand Down
4 changes: 2 additions & 2 deletions pulse2percept/utils.py
Expand Up @@ -283,7 +283,7 @@ def _sparseconv(data, kernel, mode):
# Add shifted and scaled copies of `kernel` only where `data` is nonzero
for p in pos:
out[p:p + kernel_len] = (out[p:p + kernel_len]
+ kernel.ravel() * data.ravel()[p])
+ kernel.ravel() * data.ravel()[p])

if mode.lower() == 'full':
return out
Expand Down Expand Up @@ -338,7 +338,7 @@ def sparseconv(data, kernel, mode, use_jit=True):
e_s = "You do not have numba, please run sparseconv with "
e_s += "`use_jit`=False."
raise ImportError(e_s)
func_sparseconv = jit(_sparseconv)
func_sparseconv = jit(_sparseconv, forceobj=True)
return func_sparseconv(data, kernel, mode)


Expand Down
2 changes: 1 addition & 1 deletion pulse2percept/version.py
Expand Up @@ -4,7 +4,7 @@
# Format expected by setup.py and doc/source/conf.py: string of form "X.Y.Z"
_version_major = 0
_version_minor = 5
_version_micro = 0 # use '' for first of series, number for 1 and above
_version_micro = 1 # use '' for first of series, number for 1 and above
_version_extra = ''
# _version_extra = '' # Uncomment this for full releases

Expand Down

0 comments on commit 9a6f35b

Please sign in to comment.