Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH][REF] Torch Scoreboard #612

Merged
merged 16 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ h5py
docutils<0.18
jax[cpu]<=0.4.3
neuropythy
pickleshare
pickleshare
torch
62 changes: 59 additions & 3 deletions pulse2percept/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy, copy
import numpy as np
import multiprocessing
import torch

from ..implants import ProsthesisSystem
from ..stimuli import Stimulus
Expand Down Expand Up @@ -250,7 +251,9 @@ def get_default_params(self):
'verbose': True,
# default to 2d model. 3d models should override this
'ndim' : [2],
'n_threads': multiprocessing.cpu_count()
'n_threads': multiprocessing.cpu_count(),
# default to cpu, can force cuda if using torch on gpu
'device': 'cpu'
}
return params

Expand Down Expand Up @@ -285,8 +288,8 @@ def build(self, **build_params):
self.grid = Grid2D(self.xrange, self.yrange, step=self.xystep,
grid_type=self.grid_type)
self.grid.build(self.vfmap)
self.is_built = True # this is so that torch models don't need to manually set is_built in order to access the model grid
self._build()
self.is_built = True
return self

@abstractmethod
Expand Down Expand Up @@ -529,7 +532,9 @@ def get_default_params(self):
'thresh_percept': 0,
# True: print status messages, False: silent
'verbose': True,
'n_threads': multiprocessing.cpu_count()
'n_threads': multiprocessing.cpu_count(),
# default to cpu, can force cuda if using torch on gpu
'device': 'cpu'
}
return params

Expand Down Expand Up @@ -1046,3 +1051,54 @@ def is_built(self):
if self.has_time:
_is_built &= self.temporal.is_built
return _is_built


class TorchBaseModel(torch.nn.Module, metaclass=ABCMeta):
def __init__(self, p2pmodel):
"""
Base class constructor for common logic

Subclasses should call this constructor and then should
read in any relevant information from the p2pmodel (which is NOT stored),
including model parameters and the spatial grid.

TODO: Can we move spatial grid reading into this?

Parameters
----------
p2pmodel : pulse2percept.models.Model
The pulse2percept model to wrap

"""
super().__init__()
if not p2pmodel.is_built:
p2pmodel.build()
self.device = torch.device(p2pmodel.device)

def forward(self, stim, e_locs, model_params=None):
"""
Forward pass of the model

Parameters
----------
stim : torch.Tensor
The stimulation tensor for each electrode.
Shape (n_time, n_elecs) or (n_time, n_elecs, 3) for biphasic models
e_locs : torch.Tensor
The locations of the electrodes
model_params : Tensor, optional
The model parameters to use. If None, will use the default parameters
Each subclass should use this, if provided, instead of parameters from the p2pmodel,
so that it is possible to differentiate wrt the parameters.

Only parameters that make sense to differentiate wrt and don't require
rebuild should go here.
For example, Scoreboard model would take in torch.tensor([rho])

Returns
-------
torch.Tensor
The predicted percept, with dimensions (n_time, n_pixels)

"""
raise NotImplementedError
111 changes: 87 additions & 24 deletions pulse2percept/models/cortex/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""`CortexSpatial`, `ScoreboardSpatial`, `ScoreboardModel`"""

from ..base import Model, SpatialModel
from ..base import Model, SpatialModel, TorchBaseModel
from ...topography import Polimeni2006Map
from .._beyeler2019 import fast_scoreboard, fast_scoreboard_3d
from ...utils.constants import ZORDER
import numpy as np
import torch

class CortexSpatial(SpatialModel):
"""Abstract base class for cortical models
Expand Down Expand Up @@ -167,6 +168,54 @@ def plot3D(self, style='scatter', ax=None, **kwargs):
ax.view_init(elev=20, azim=110)
return ax

class TorchScoreboardSpatial(TorchBaseModel):
def __init__(self, p2pmodel):
super().__init__(p2pmodel)
self.rho = torch.tensor(p2pmodel.rho, device=self.device)
self.shape = p2pmodel.grid.shape
self.regions = p2pmodel.regions
# whether to let current spread between regions
self.separate = 0
self.boundary = 0
if p2pmodel.vfmap.split_map:
self.separate = 1
self.boundary = p2pmodel.vfmap.left_offset/2
self.locs = {}
for region in self.regions:
x = torch.tensor(p2pmodel.grid[region].x.ravel(),device=self.device)
y = torch.tensor(p2pmodel.grid[region].y.ravel(),device=self.device)
if p2pmodel.grid[region].z is not None:
z = torch.tensor(p2pmodel.grid[region].z.ravel(),device=self.device)
self.locs[region] = torch.stack([x, y, z], axis=-1)
else:
self.locs[region] = torch.stack([x, y], axis=-1)

def forward(self, amps, e_locs, model_params=None):
"""Predicts the percept
Parameters
----------
amps: (t, n_elecs) shaped tensor
Amplitude for each electrode in the implant
e_locs: (n_elecs, dim) shaped tensor
electrode location (x, y, [z optional]) for each electrode
model_params: tensor, optional
rho parameter for current spread
"""
if model_params is None:
rho = self.rho
else:
rho = model_params[0]

# (npixels, nelecs)
tot_intensities = 0
for region in self.regions:
d2_el = torch.sum((self.locs[region][:, None, :] - e_locs[None, :, :] )**2, axis=-1)
intensities = amps[:, None, :] * torch.exp(-d2_el / (2 * rho**2)) # generate gaussian blobs for each electrode
if self.separate:
intensities *= torch.where((e_locs[None,:,0] < self.boundary) == (self.locs[region][:,None,0] < self.boundary), 1, 0) # ensure current cannot spread between hemispheres
intensities = torch.sum(intensities, axis=-1) # add up all gaussian blobs
tot_intensities += intensities
return tot_intensities

class ScoreboardSpatial(CortexSpatial):
"""Cortical adaptation of scoreboard model from [Beyeler2019]_
Expand Down Expand Up @@ -232,14 +281,20 @@ class ScoreboardSpatial(CortexSpatial):
"""
def __init__(self, **params):
super(ScoreboardSpatial, self).__init__(**params)
self.torchmodel = None

def _build(self):
if self.engine == 'torch':
self.torchmodel = TorchScoreboardSpatial(self)

def get_default_params(self):
"""Returns all settable parameters of the scoreboard model"""
base_params = super(ScoreboardSpatial, self).get_default_params()
params = {
# radial current spread
'rho': 200,
'ndim' : [2, 3]
'ndim' : [2, 3],
'engine': 'cython'
}
return {**base_params, **params}

Expand All @@ -258,28 +313,36 @@ def _predict_spatial(self, earray, stim):
if self.vfmap.split_map:
separate = 1
boundary = self.vfmap.left_offset/2
if self.vfmap.ndim == 3:
return np.sum([
fast_scoreboard_3d(stim.data, x_el, y_el, z_el,
self.grid[region].x.ravel(),
self.grid[region].y.ravel(),
self.grid[region].z.ravel(),
self.rho, self.thresh_percept,
separate, boundary,
self.n_threads)
for region in self.regions ],
axis = 0)
elif self.vfmap.ndim == 2:
return np.sum([
fast_scoreboard(stim.data, x_el, y_el,
self.grid[region].x.ravel(), self.grid[region].y.ravel(),
self.rho, self.thresh_percept,
separate, boundary,
self.n_threads)
for region in self.regions ],
axis = 0)
else:
raise ValueError("Invalid dimensionality of visual field map")
if self.engine == "torch":
if self.vfmap.ndim == 2:
e_locs = torch.tensor([(x,y) for x,y in zip(x_el, y_el)]).to(self.device)
amps = torch.tensor(stim.data).to(self.device)
return self.torchmodel(amps=amps.T, e_locs=e_locs).T.numpy()
else:
raise ValueError("Invalid dimensionality of visual field map")
elif self.engine == "cython":
if self.vfmap.ndim == 3:
return np.sum([
fast_scoreboard_3d(stim.data, x_el, y_el, z_el,
self.grid[region].x.ravel(),
self.grid[region].y.ravel(),
self.grid[region].z.ravel(),
self.rho, self.thresh_percept,
separate, boundary,
self.n_threads)
for region in self.regions ],
axis = 0)
elif self.vfmap.ndim == 2:
return np.sum([
fast_scoreboard(stim.data, x_el, y_el,
self.grid[region].x.ravel(), self.grid[region].y.ravel(),
self.rho, self.thresh_percept,
separate, boundary,
self.n_threads)
for region in self.regions ],
axis = 0)
else:
raise ValueError("Invalid dimensionality of visual field map")


class ScoreboardModel(Model):
Expand Down
32 changes: 18 additions & 14 deletions pulse2percept/models/cortex/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
@pytest.mark.parametrize('jitter_boundary', [True, False])
@pytest.mark.parametrize('regions',
[['v1'], ['v2'], ['v3'], ['v1', 'v2'], ['v2', 'v3'], ['v1', 'v3'], ['v1', 'v2', 'v3']])
def test_ScoreboardSpatial(ModelClass, jitter_boundary, regions):
@pytest.mark.parametrize('engine', ['cython', 'torch'])
def test_ScoreboardSpatial(ModelClass, jitter_boundary, regions, engine):
# ScoreboardSpatial automatically sets `regions`
vfmap = Polimeni2006Map(k=15, a=.5, b=90, jitter_boundary=jitter_boundary, regions=regions)
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, vfmap=vfmap).build()
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, vfmap=vfmap, engine=engine).build()
npt.assert_equal(model.regions, regions)
npt.assert_equal(model.vfmap.regions, regions)

Expand All @@ -35,7 +36,7 @@ def test_ScoreboardSpatial(ModelClass, jitter_boundary, regions):

# Converting ret <=> dva
vfmap = Polimeni2006Map(k=15, a=0.5, b=90, jitter_boundary=jitter_boundary, regions=regions)
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=1, vfmap=vfmap).build()
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=1, vfmap=vfmap, engine=engine).build()
npt.assert_equal(isinstance(model.vfmap, Polimeni2006Map), True)
if jitter_boundary:
npt.assert_equal(np.isnan(model.vfmap.dva_to_v1([0], [0])), False)
Expand Down Expand Up @@ -65,9 +66,10 @@ def test_ScoreboardSpatial(ModelClass, jitter_boundary, regions):
@pytest.mark.parametrize('ModelClass', [ScoreboardModel, ScoreboardSpatial])
@pytest.mark.parametrize('regions',
[['v1'], ['v2'], ['v3'], ['v1', 'v2'], ['v2', 'v3'], ['v1', 'v3'], ['v1', 'v2', 'v3']])
def test_predict_spatial(ModelClass, regions):
@pytest.mark.parametrize('engine', ['cython', 'torch'])
def test_predict_spatial(ModelClass, regions, engine):
# test that no current can spread between hemispheres
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.5, rho=100000, regions=regions).build()
model = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.5, rho=100000, regions=regions, engine=engine).build()
implant = Orion(x = 15000)
implant.stim = {e:5 for e in implant.electrode_names}
percept = model.predict_percept(implant)
Expand All @@ -77,7 +79,7 @@ def test_predict_spatial(ModelClass, regions):

# implant only in v1, shouldnt change with v2/v3
vfmap = Polimeni2006Map(k=15, a=0.5, b=90)
model = ModelClass(xrange=(-5, 0), yrange=(-3, 3), xystep=0.1, rho=400, vfmap=vfmap).build()
model = ModelClass(xrange=(-5, 0), yrange=(-3, 3), xystep=0.1, rho=400, vfmap=vfmap, engine=engine).build()
elecs = [79, 49, 19, 80, 50, 20, 90, 61, 31, 2, 72, 42, 12, 83, 53, 23, 93, 64, 34, 5, 75, 45, 15, 86, 56, 26, 96, 67, 37, 8, 68, 38]
implant = Cortivis(x=30000, y=0, rot=0, stim={str(i) : [1, 0] for i in elecs})
percept = model.predict_percept(implant)
Expand All @@ -94,7 +96,7 @@ def test_predict_spatial(ModelClass, regions):
if 'v1' in regions:
# make sure cortical representation is flipped
vfmap = Polimeni2006Map(k=15, a=0.5, b=90)
model = ModelClass(xrange=(-5, 0), yrange=(-3, 3), xystep=0.1, rho=400, vfmap=vfmap).build()
model = ModelClass(xrange=(-5, 0), yrange=(-3, 3), xystep=0.1, rho=400, vfmap=vfmap, engine=engine).build()
implant = Orion(x=30000, y=0, rot=0, stim={'40' : 1, '94' :5})
percept = model.predict_percept(implant)
half = model.grid.shape[0] // 2
Expand All @@ -103,11 +105,12 @@ def test_predict_spatial(ModelClass, regions):

@pytest.mark.parametrize('ModelClass', [ScoreboardModel, ScoreboardSpatial])
@pytest.mark.parametrize('regions', [['v1', 'v2'], ['v1', 'v3'], ['v2', 'v3']])
def test_predict_spatial_regionsum(ModelClass,regions):
@pytest.mark.parametrize('engine', ['cython', 'torch'])
def test_predict_spatial_regionsum(ModelClass,regions,engine):
print(regions)
model1 = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions[0]).build()
model2 = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions[1]).build()
model_both = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions).build()
model1 = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions[0], engine=engine).build()
model2 = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions[1], engine=engine).build()
model_both = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=10000, regions=regions, engine=engine).build()

implant = Orion(x = 10000, y=10000)
implant.stim = {e : 1 for e in implant.electrode_names}
Expand All @@ -121,11 +124,12 @@ def test_predict_spatial_regionsum(ModelClass,regions):

@pytest.mark.parametrize('ModelClass', [ScoreboardModel, ScoreboardSpatial])
@pytest.mark.parametrize('stimval', np.arange(0, 5, 1))
def test_eq_beyeler(ModelClass, stimval):
@pytest.mark.parametrize('engine', ['cython', 'torch'])
def test_eq_beyeler(ModelClass, stimval, engine):


vfmap = Watson2014Map()
cortex = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=200 * stimval, regions=['ret'], vfmap=vfmap).build()
cortex = ModelClass(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=200 * stimval, regions=['ret'], vfmap=vfmap, engine=engine).build()
retina = BeyelerScoreboard(xrange=(-3, 3), yrange=(-3, 3), xystep=0.1, rho=200 * stimval).build()

implant = ArgusII()
Expand All @@ -134,7 +138,7 @@ def test_eq_beyeler(ModelClass, stimval):
p1 = cortex.predict_percept(implant)
p2 = retina.predict_percept(implant)

npt.assert_equal(p1.data, p2.data)
npt.assert_almost_equal(p1.data, p2.data, 5)



Expand Down