Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] TorchBiphasicAxonMapSpatial #617

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30,864 changes: 30,864 additions & 0 deletions pulse2percept/models/_horsager2009.cpp

Large diffs are not rendered by default.

32,214 changes: 32,214 additions & 0 deletions pulse2percept/models/_nanduri2012.cpp

Large diffs are not rendered by default.

250 changes: 236 additions & 14 deletions pulse2percept/models/granley2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial
import numpy as np
import sys
import torch

from . import AxonMapSpatial, Model
from ..implants import ProsthesisSystem, ElectrodeArray
Expand All @@ -22,6 +23,14 @@
except ImportError:
has_jax = False

try:
import torch
import torch.nn as nn
import torch.optim as optim
has_torch = True
except ImportError:
has_torch = False


def cond_jit(fn, static_argnums=None):
""" Conditional decorator for jax jit"""
Expand Down Expand Up @@ -302,6 +311,7 @@ class BiphasicAxonMapSpatial(AxonMapSpatial):

def __init__(self, **params):
super(BiphasicAxonMapSpatial, self).__init__(**params)
self.torchmodel = None
if self.bright_model is None:
self.bright_model = DefaultBrightModel()
if self.size_model is None:
Expand Down Expand Up @@ -354,7 +364,7 @@ def __setattr__(self, name, value):
pass
# Check whether the attribute is a part of any
# bright/size/streak model
if name not in ['bright_model', 'size_model', 'streak_model', 'is_built', '_is_built']:
if name not in ['bright_model', 'size_model', 'streak_model', 'is_built', '_is_built', 'torchmodel']:
try:
for m in [self.bright_model, self.size_model, self.streak_model]:
if hasattr(m, name):
Expand Down Expand Up @@ -388,6 +398,8 @@ def get_default_params(self):
# Callable model used to modulate percept streak length with amplitude,
# frequency, and pulse duration
'streak_model': None,

'torchmodel': None,
}
return {**base_params, **params}

Expand All @@ -402,15 +414,30 @@ def _build(self):
raise ImportError("Engine was chosen as jax, but jax is not installed. "
"You can install it with 'pip install \"jax[cpu]\"' for cpu "
"or following https://github.com/google/jax#installation for gpu")
if self.engine == 'torch' and not has_torch:
raise ImportError("Engine was chosen as torch, but torch is not installed. "
"You can install it with 'pip install torch'")

super(BiphasicAxonMapSpatial, self)._build()

if self.engine == 'jax':
# Clear previously cached functions
self._predict_spatial_jax = jit(self._predict_spatial_jax)
self._predict_spatial_batched = jit(self._predict_spatial_batched)
# Cache axon_contrib for fast access later
self.axon_contrib = jax.device_put(
jnp.array(self.axon_contrib), jax.devices()[0])

if self.engine == 'torch':
# Convert functions to TorchScript for optimization
# self.is_built = True
self.torchmodel = TorchBiphasicAxonMapSpatial(self)
self.axon_contrib = torch.tensor(self.axon_contrib).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) # move axon_contrib to the appropriate torch device (e.g., GPU)
self._predict_spatial = torch.compile(self._predict_spatial)
self.predict_percept = torch.compile(self.predict_percept)
self._predict_spatial_batched = torch.compile(self._predict_spatial_batched)



def _predict_spatial(self, earray, stim):
"""Predicts the percept"""
Expand Down Expand Up @@ -440,7 +467,37 @@ def _predict_spatial(self, earray, stim):
x = np.array(x, dtype=np.float32)
y = np.array(y, dtype=np.float32)

if self.engine != 'jax':
if self.engine == 'jax':
return self._predict_spatial_jax(elec_params[:, :3], x, y)

elif self.engine == 'torch':
elec_params_torch = torch.tensor(elec_params, dtype=torch.get_default_dtype())
x_torch = torch.tensor(x, dtype=torch.get_default_dtype())
y_torch = torch.tensor(y, dtype=torch.get_default_dtype())

bright_effects = self.bright_model(
elec_params_torch[:, 0], elec_params_torch[:, 1], elec_params_torch[:, 2])
size_effects = self.size_model(
elec_params_torch[:, 0], elec_params_torch[:, 1], elec_params_torch[:, 2])
streak_effects = self.streak_model(
elec_params_torch[:, 0], elec_params_torch[:, 1], elec_params_torch[:, 2])

amps = elec_params_torch[:, 1]

return fast_biphasic_axon_map(
amps.cpu().numpy(),
bright_effects.cpu().numpy(),
size_effects.cpu().numpy(),
streak_effects.cpu().numpy(),
x_torch.cpu().numpy(),
y_torch.cpu().numpy(),
self.axon_contrib.cpu().numpy(),
self.axon_idx_start.cpu().numpy().astype(np.uint32),
self.axon_idx_end.cpu().numpy().astype(np.uint32),
self.rho, self.thresh_percept,
self.n_threads)

else:
bright_effects = np.array(self.bright_model(elec_params[:, 0], elec_params[:, 1], elec_params[:, 2]),
dtype=np.float32).reshape((-1))
size_effects = np.array(self.size_model(elec_params[:, 0], elec_params[:, 1], elec_params[:, 2]),
Expand All @@ -459,8 +516,6 @@ def _predict_spatial(self, earray, stim):
self.axon_idx_end.astype(np.uint32),
self.rho, self.thresh_percept,
self.n_threads)
else:
return self._predict_spatial_jax(elec_params[:, :3], x, y)

def predict_one_point_jax(self, axon, eparams, x, y, rho):
""" Predicts the brightness contribution from each axon segment for each pixel"""
Expand Down Expand Up @@ -547,17 +602,34 @@ def _predict_spatial_batched(self, elec_params, x, y):
------------
resp : np.array() representing the resulting percepts, shape (batch_size, :, 1)
"""
bright_effects = self.bright_model(elec_params[:, :, 0],
elec_params[:, :, 1],
elec_params[:, :, 2])
size_effects = self.size_model(elec_params[:, :, 0],
elec_params[:, :, 1],
elec_params[:, :, 2])
streak_effects = self.streak_model(elec_params[:, :, 0],
elec_params[:, :, 1],
if self.engine == 'jax':
bright_effects = self.bright_model(elec_params[:, :, 0],
elec_params[:, :, 1],
elec_params[:, :, 2])
size_effects = self.size_model(elec_params[:, :, 0],
elec_params[:, :, 1],
elec_params[:, :, 2])
eparams = jnp.stack(
[bright_effects, size_effects, streak_effects], axis=2)
streak_effects = self.streak_model(elec_params[:, :, 0],
elec_params[:, :, 1],
elec_params[:, :, 2])
eparams = jnp.stack([bright_effects, size_effects, streak_effects], axis=2)

elif self.engine == 'torch':
elec_params_torch = torch.tensor(elec_params, dtype=torch.float32)
x_torch = torch.tensor(x, dtype=torch.float32)
y_torch = torch.tensor(y, dtype=torch.float32)
bright_effects = self.bright_model(elec_params_torch[:, :, 0],
elec_params_torch[:, :, 1],
elec_params_torch[:, :, 2])
size_effects = self.size_model(elec_params_torch[:, :, 0],
elec_params_torch[:, :, 1],
elec_params_torch[:, :, 2])
streak_effects = self.streak_model(elec_params_torch[:, :, 0],
elec_params_torch[:, :, 1],
elec_params_torch[:, :, 2])
eparams = torch.stack([bright_effects, size_effects, streak_effects], dim=2)
else:
raise ValueError("Unsupported engine type")

def predict_one(e_params):
return self.biphasic_axon_map_jax(e_params, x, y,
Expand Down Expand Up @@ -867,3 +939,153 @@ def predict_percept(self, implant, t_percept=None):
return None
resp = self.spatial.predict_percept(implant, t_percept=t_percept)
return resp




class TorchBiphasicAxonMapSpatial(nn.Module):
""" TorchBiphasicAxonMapSpatial

An AxonMapModel where phosphene brightness, size, and streak length scale
according to amplitude, frequency, and pulse duration using PyTorch.

All stimuli must be BiphasicPulseTrains.

This model is different than other spatial models in that it calculates
one representative percept from all time steps of the stimulus.

Brightness, size, and streak length scaling are controlled by the parameters
bright_model, size_model, and streak model respectively. By default, these are
set to classes that implement Eqs 3-6 from Granley 2021. These models can be
individually customized by setting the bright_model, size_model, or streak_model
to any python callable with signature f(freq, amp, pdur)

Parameters
----------
bright_model: callable, optional
Model used to modulate percept brightness with amplitude, frequency,
and pulse duration
size_model: callable, optional
Model used to modulate percept size with amplitude, frequency, and
pulse duration
streak_model: callable, optional
Model used to modulate percept streak length with amplitude, frequency,
and pulse duration
do_thresholding: boolean
Use probabilistic sigmoid thresholding, default: False
**params: dict, optional
Additional parameters to customize the model
"""
def __init__(self, p2pmodel, implant, activity_regularizer=None, clip=None, amp_cutoff=True, **kwargs):
super().__init__()
# p2pmodel.build()

if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(device)

if not p2pmodel.is_built:
p2pmodel.build()
bundles = p2pmodel.grow_axon_bundles() # returns coordinates of axon bundles
axons = p2pmodel.find_closest_axon(bundles) # maps axon bundles to closest axon locations
axon_contrib = p2pmodel.calc_axon_sensitivity(axons, pad=True) # Assuming the function returns sensitivity values
axon_contrib = torch.tensor(axon_contrib, dtype=dtype) # Convert to tensor with default dtype
else:
axon_contrib = p2pmodel.axon_contrib

# Ensure models are callable
for model in [p2pmodel, self.bright_model, self.size_model, self.streak_model]:
if not isinstance(model, BiphasicAxonMapSpatial):
raise TypeError(f"{model} needs to be an instance of torch.nn.Module or callable")

if isinstance(p2pmodel, BiphasicAxonMapModel) or not isinstance(p2pmodel, BiphasicAxonMapSpatial):
raise ValueError("Must pass in a valid BiphasicAxonMapModelSpatial")

if p2pmodel.engine != 'torch':
raise ValueError("Engine Selection Conflict : Constructing TorchAxonMapSpatial with engine", p2pmodel.engine)

dtype = torch.get_default_dtype()


if type(axons) != list:
axons = [axons]
# Calculate axon sensitivity and convert to tensor
# Register as a buffer to ensure it moves with the model to GPU/CPU appropriately
self.register_buffer("axon_contrib", axon_contrib)
self.rho = torch.tensor(p2pmodel.rho, torch.get_default_dtype())

# Extract and convert implant electrode coordinates into tensors
self.elec_x = torch.tensor([implant[e].x for e in implant.electrodes], dtype=dtype)
self.elec_y = torch.tensor([implant[e].y for e in implant.electrodes], dtype=dtype)

# compute squared distances from each axon pixel to each electrode
d2_el = (self.axon_contrib[:, :, 0, None] - self.elec_x) ** 2 + \
(self.axon_contrib[:, :, 1, None] - self.elec_y) ** 2

self.register_buffer("d2_el", d2_el)

# Other parameters from initialization args
self.percept_shape = p2pmodel.grid.shape # p2pmodel has a grid attribute defining percept shape
self.thresh_percept = p2pmodel.thresh_percept



# Convert functions to TorchScript for optimization
''' self._predict_spatial_torch = torch.jit.script(self._predict_spatial_torch)
self._predict_spatial_batched_torch = torch.jit.script(self._predict_spatial_batched_torch) # use torch.compile to compile the function, and build is not needed
'''


def forward(self, inputs, like_jax=False):
freq = inputs[0][:, :, 0]
amp = inputs[0][:, :, 1]
pdur = inputs[0][:, :, 2]

rho = inputs[1][:, 0][:, None]
axlambda = inputs[1][:, 1][:, None]
a0 = inputs[1][:, 2][:, None]
a1 = inputs[1][:, 3][:, None]
a2 = inputs[1][:, 4][:, None]
a3 = inputs[1][:, 5][:, None]
a4 = inputs[1][:, 6][:, None]
a5 = inputs[1][:, 7][:, None]
a6 = inputs[1][:, 8][:, None]
a7 = inputs[1][:, 9][:, None]
a8 = inputs[1][:, 10][:, None]
a9 = inputs[1][:, 11][:, None]

scaled_amps = (a1 + a0*pdur) * amp

# bright
F_bright = a2 * scaled_amps + a3 * freq
if not like_jax: # like pyx impl.
F_bright = torch.where(amp > 0, F_bright, torch.zeros_like(F_bright))

# size
min_f_size = 10**2 / (rho**2)
F_size = a5 * scaled_amps + a6
F_size = torch.maximum(F_size, min_f_size)

# streak
min_f_streak = 10**2 / (axlambda ** 2)
F_streak = a9 - a7 * pdur ** a8
F_streak = torch.maximum(F_streak, min_f_streak)

# apply axon map
intensities = (
F_bright[:, None, None, :] * # 1, 1, 1, 225
torch.exp(
-self.d2_el[None, :, :, :] / # dist2el 1, 2401, 118, 225
(2. * rho**2 * F_size)[:, None, None, :] + # 1, 1, 1, 225
self.axon_contrib[None, :, :, 2, None] / # sens 1, 2401, 118, 1
(axlambda**2 * F_streak)[:, None, None, :] # 1, 1, 1, 225
) # 1, 2401, 118, 225, scaling between 0, 1
) # 1, 2401, 118, 225

# after summing up...
intensities = torch.max(torch.sum(intensities, axis=-1), axis=-1).values # sum over electrodes, max over segments
intensities = torch.where(intensities > self.thresh_percept, intensities, torch.zeros_like(intensities))

batched_percept_shape = tuple([-1] + list(self.percept_shape))
intensities = intensities.reshape(batched_percept_shape)
return intensities
23 changes: 19 additions & 4 deletions pulse2percept/models/tests/test_granley2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
except ImportError:
has_jax = False

try:
import torch
import torch.nn as nn
import torch.optim as optim
has_torch = True
except ImportError:
has_torch = False

def test_deepcopy_DefaultBrightModel():
original = DefaultBrightModel()
copied = copy.deepcopy(original)
Expand Down Expand Up @@ -166,11 +174,13 @@ def test_effects_models():
npt.assert_equal(hasattr(model, 'a9'), True)


@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax'))
@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax', 'torch'))
def test_biphasicAxonMapSpatial(engine):
if engine == 'jax' and not has_jax:
pytest.skip("Jax not installed")

if engine == 'torch' and not has_torch:
pytest.skip("Torch not installed")
# Lambda cannot be too small:
with pytest.raises(ValueError):
BiphasicAxonMapSpatial(axlambda=9).build()
Expand Down Expand Up @@ -269,10 +279,13 @@ def test_predict_spatial_jax():
p2 = model2.predict_percept(implant)
npt.assert_almost_equal(p1.data, p2.data, decimal=4)

@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax'))
@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax', 'torch'))
def test_predict_batched(engine):
if not has_jax:
pytest.skip("Jax not installed")

if not has_torch:
pytest.skip("Torch not installed")

# Allows mix of valid Stimulus types
stims = [{'A5' : BiphasicPulseTrain(25, 4, 0.45),
Expand All @@ -283,7 +296,7 @@ def test_predict_batched(engine):
model = BiphasicAxonMapModel(engine=engine, xystep=2)
model.build()
# Import error if we dont have jax
if engine != 'jax':
if (engine == 'jax' and not has_jax) or (engine == 'torch' and not has_torch):
with pytest.raises(ImportError):
model.predict_percept_batched(implant, stims)
return
Expand All @@ -298,10 +311,12 @@ def test_predict_batched(engine):
for p1, p2 in zip(percepts_batched, percepts_serial):
npt.assert_almost_equal(p1.data, p2.data)

@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax'))
@pytest.mark.parametrize('engine', ('serial', 'cython', 'jax', 'torch'))
def test_biphasicAxonMapModel(engine):
if engine == 'jax' and not has_jax:
pytest.skip("Jax not installed")
if engine == 'torch' and not has_torch:
pytest.skip("Torch not installed")
set_params = {'xystep': 2, 'engine': engine, 'rho': 432, 'axlambda': 20,
'n_axons': 9, 'n_ax_segments': 50,
'xrange': (-30, 30), 'yrange': (-20, 20),
Expand Down